beswarm 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/aient/src/aient/core/request.py +8 -6
- beswarm/aient/src/aient/core/response.py +31 -22
- beswarm/aient/src/aient/core/test/test_geminimask.py +330 -0
- beswarm/aient/src/aient/core/utils.py +14 -7
- beswarm/aient/src/aient/plugins/excute_command.py +88 -19
- beswarm/tools/UIworker.py +145 -0
- beswarm/tools/__init__.py +10 -0
- beswarm/tools/click.py +456 -0
- {beswarm-0.1.33.dist-info → beswarm-0.1.35.dist-info}/METADATA +21 -4
- {beswarm-0.1.33.dist-info → beswarm-0.1.35.dist-info}/RECORD +12 -9
- {beswarm-0.1.33.dist-info → beswarm-0.1.35.dist-info}/WHEEL +1 -1
- {beswarm-0.1.33.dist-info → beswarm-0.1.35.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,10 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
|
|
31
31
|
model_dict = get_model_dict(provider)
|
32
32
|
original_model = model_dict[request.model]
|
33
33
|
|
34
|
-
|
34
|
+
if request.stream:
|
35
|
+
gemini_stream = "streamGenerateContent"
|
36
|
+
else:
|
37
|
+
gemini_stream = "generateContent"
|
35
38
|
url = provider['base_url']
|
36
39
|
parsed_url = urllib.parse.urlparse(url)
|
37
40
|
if "/v1beta" in parsed_url.path:
|
@@ -295,7 +298,10 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
|
|
295
298
|
if provider.get("project_id"):
|
296
299
|
project_id = provider.get("project_id")
|
297
300
|
|
298
|
-
|
301
|
+
if request.stream:
|
302
|
+
gemini_stream = "streamGenerateContent"
|
303
|
+
else:
|
304
|
+
gemini_stream = "generateContent"
|
299
305
|
model_dict = get_model_dict(provider)
|
300
306
|
original_model = model_dict[request.model]
|
301
307
|
search_tool = None
|
@@ -941,10 +947,6 @@ async def get_gpt_payload(request, engine, provider, api_key=None):
|
|
941
947
|
payload.pop("tools", None)
|
942
948
|
payload.pop("tool_choice", None)
|
943
949
|
|
944
|
-
# if "models.inference.ai.azure.com" in url:
|
945
|
-
# payload["stream"] = False
|
946
|
-
# payload.pop("stream_options", None)
|
947
|
-
|
948
950
|
if "api.x.ai" in url:
|
949
951
|
payload.pop("stream_options", None)
|
950
952
|
payload.pop("presence_penalty", None)
|
@@ -35,6 +35,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
35
35
|
promptTokenCount = 0
|
36
36
|
candidatesTokenCount = 0
|
37
37
|
totalTokenCount = 0
|
38
|
+
parts_json = ""
|
38
39
|
# line_index = 0
|
39
40
|
# last_text_line = 0
|
40
41
|
# if "thinking" in model:
|
@@ -63,24 +64,25 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
63
64
|
json_data = parse_json_safely( "{" + line + "}")
|
64
65
|
totalTokenCount = json_data.get('totalTokenCount', 0)
|
65
66
|
|
66
|
-
|
67
|
-
|
67
|
+
if (line and '"parts": [' in line or parts_json != "") and is_finish == False:
|
68
|
+
parts_json += line
|
69
|
+
if parts_json != "" and line and '],' == line.strip():
|
70
|
+
parts_json = "{" + parts_json.strip().rstrip(",} ]}") + "}]}"
|
68
71
|
try:
|
69
|
-
json_data = json.loads(
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
yield sse_string
|
72
|
+
json_data = json.loads(parts_json)
|
73
|
+
|
74
|
+
content = safe_get(json_data, "parts", 0, "text", default="")
|
75
|
+
|
76
|
+
is_thinking = safe_get(json_data, "parts", 0, "thought", default=False)
|
77
|
+
if is_thinking:
|
78
|
+
sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
|
79
|
+
yield sse_string
|
80
|
+
else:
|
81
|
+
sse_string = await generate_sse_response(timestamp, model, content=content)
|
82
|
+
yield sse_string
|
81
83
|
except json.JSONDecodeError:
|
82
|
-
logger.error(f"无法解析JSON: {
|
83
|
-
|
84
|
+
logger.error(f"无法解析JSON: {parts_json}")
|
85
|
+
parts_json = ""
|
84
86
|
|
85
87
|
if line and ('\"functionCall\": {' in line or revicing_function_call):
|
86
88
|
revicing_function_call = True
|
@@ -142,7 +144,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
|
|
142
144
|
|
143
145
|
if line and '\"text\": \"' in line and is_finish == False:
|
144
146
|
try:
|
145
|
-
json_data = json.loads( "{" + line + "}")
|
147
|
+
json_data = json.loads( "{" + line.strip().rstrip(",") + "}")
|
146
148
|
content = json_data.get('text', '')
|
147
149
|
sse_string = await generate_sse_response(timestamp, model, content=content)
|
148
150
|
yield sse_string
|
@@ -525,21 +527,28 @@ async def fetch_response(client, url, headers, payload, engine, model):
|
|
525
527
|
parsed_data = ast.literal_eval(str(response_json))
|
526
528
|
elif isinstance(response_json, list):
|
527
529
|
parsed_data = response_json
|
530
|
+
elif isinstance(response_json, dict):
|
531
|
+
parsed_data = [response_json]
|
528
532
|
else:
|
529
533
|
logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
|
530
534
|
parsed_data = response_json
|
531
535
|
# print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
|
532
536
|
content = ""
|
537
|
+
reasoning_content = ""
|
533
538
|
for item in parsed_data:
|
534
539
|
chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
|
540
|
+
is_think = safe_get(item, "candidates", 0, "content", "parts", 0, "thought", default=False)
|
535
541
|
# logger.info(f"chunk: {repr(chunk)}")
|
536
542
|
if chunk:
|
537
|
-
|
543
|
+
if is_think:
|
544
|
+
reasoning_content += chunk
|
545
|
+
else:
|
546
|
+
content += chunk
|
538
547
|
|
539
548
|
usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
|
540
|
-
prompt_tokens = usage_metadata
|
541
|
-
candidates_tokens = usage_metadata
|
542
|
-
total_tokens = usage_metadata
|
549
|
+
prompt_tokens = safe_get(usage_metadata, "promptTokenCount", default=0)
|
550
|
+
candidates_tokens = safe_get(usage_metadata, "candidatesTokenCount", default=0)
|
551
|
+
total_tokens = safe_get(usage_metadata, "totalTokenCount", default=0)
|
543
552
|
|
544
553
|
role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
|
545
554
|
if role == "model":
|
@@ -552,7 +561,7 @@ async def fetch_response(client, url, headers, payload, engine, model):
|
|
552
561
|
function_call_content = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "args", default=None)
|
553
562
|
|
554
563
|
timestamp = int(datetime.timestamp(datetime.now()))
|
555
|
-
yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
|
564
|
+
yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens, reasoning_content=reasoning_content)
|
556
565
|
|
557
566
|
elif engine == "claude":
|
558
567
|
response_json = response.json()
|
@@ -0,0 +1,330 @@
|
|
1
|
+
import httpx
|
2
|
+
import base64
|
3
|
+
import os # 用于处理 API 密钥
|
4
|
+
import asyncio
|
5
|
+
import json
|
6
|
+
import re
|
7
|
+
# python -m core.test.test_geminimask
|
8
|
+
from ..utils import get_image_message, safe_get
|
9
|
+
|
10
|
+
# --- 请替换为您的实际值 ---
|
11
|
+
MODEL_NAME = "gemini-2.5-flash-preview-04-17"
|
12
|
+
API_KEY = os.environ.get("GEMINI_API_KEY") # 从环境变量读取密钥
|
13
|
+
IMAGE_PATH = os.environ.get("IMAGE_PATH")
|
14
|
+
# --------------------------
|
15
|
+
|
16
|
+
# 检查 API 密钥是否存在
|
17
|
+
if not API_KEY:
|
18
|
+
raise ValueError("请设置 GOOGLE_API_KEY 环境变量或在代码中提供 API 密钥。")
|
19
|
+
|
20
|
+
# 确定图片的 MIME 类型
|
21
|
+
# 您可以根据文件扩展名进行猜测,或者使用更可靠的库如 python-magic
|
22
|
+
if IMAGE_PATH.lower().endswith(".png"):
|
23
|
+
IMAGE_MIME_TYPE = "image/png"
|
24
|
+
elif IMAGE_PATH.lower().endswith(".jpg") or IMAGE_PATH.lower().endswith(".jpeg"):
|
25
|
+
IMAGE_MIME_TYPE = "image/jpeg"
|
26
|
+
# 添加其他您需要支持的图片类型
|
27
|
+
else:
|
28
|
+
raise ValueError(f"不支持的图片格式: {IMAGE_PATH}")
|
29
|
+
|
30
|
+
# 读取图片文件并进行 Base64 编码
|
31
|
+
try:
|
32
|
+
with open(IMAGE_PATH, "rb") as image_file:
|
33
|
+
image_data = image_file.read()
|
34
|
+
base64_encoded_image = base64.b64encode(image_data).decode("utf-8")
|
35
|
+
# print(base64_encoded_image)
|
36
|
+
except FileNotFoundError:
|
37
|
+
print(f"错误:找不到图片文件 '{IMAGE_PATH}'")
|
38
|
+
exit()
|
39
|
+
|
40
|
+
# image_message = get_image_message(base64_encoded_image, "gemini")
|
41
|
+
image_message = asyncio.run(get_image_message(f"data:{IMAGE_MIME_TYPE};base64," + base64_encoded_image, "gemini"))
|
42
|
+
|
43
|
+
# 构建请求 URL
|
44
|
+
url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent"
|
45
|
+
|
46
|
+
prompt = "Give the segmentation masks for the Search box. Output a JSON list of segmentation masks where each entry contains the 2D bounding box in \"box_2d\" and the mask in \"mask\"."
|
47
|
+
# 定义查询参数 (API Key)
|
48
|
+
params = {"key": API_KEY}
|
49
|
+
|
50
|
+
# 定义请求头
|
51
|
+
headers = {"Content-Type": "application/json"}
|
52
|
+
|
53
|
+
# 定义请求体 (JSON payload)
|
54
|
+
payload = {
|
55
|
+
"contents": [
|
56
|
+
{
|
57
|
+
"parts": [
|
58
|
+
{
|
59
|
+
"text": prompt
|
60
|
+
},
|
61
|
+
image_message
|
62
|
+
]
|
63
|
+
}
|
64
|
+
],
|
65
|
+
"generationConfig": {"thinkingConfig": {"thinkingBudget": 0}},
|
66
|
+
}
|
67
|
+
|
68
|
+
# 发送 POST 请求
|
69
|
+
try:
|
70
|
+
with httpx.Client() as client:
|
71
|
+
response = client.post(url, params=params, headers=headers, json=payload, timeout=60.0) # 增加超时时间
|
72
|
+
response.raise_for_status() # 如果状态码不是 2xx,则抛出异常
|
73
|
+
|
74
|
+
# 您可以在这里添加代码来解析 response.json() 并提取分割掩码
|
75
|
+
text = safe_get(response.json(), "candidates", 0, "content", "parts", 0, "text")
|
76
|
+
# print(text)
|
77
|
+
# 例如: segmentation_masks = response.json().get('candidates', [{}])[0].get('content', {}).get('parts', [{}])[0].get('text')
|
78
|
+
|
79
|
+
except httpx.HTTPStatusError as exc:
|
80
|
+
print(f"HTTP 错误发生: {exc.response.status_code} - {exc.response.text}")
|
81
|
+
except httpx.RequestError as exc:
|
82
|
+
print(f"发送请求时出错: {exc}")
|
83
|
+
except Exception as e:
|
84
|
+
print(f"发生意外错误: {e}")
|
85
|
+
|
86
|
+
|
87
|
+
regex_pattern = r'(\[\s*\{.*?\}\s*\])' # 匹配包含至少一个对象的数组
|
88
|
+
|
89
|
+
# 使用 re.search 查找第一个匹配项,re.DOTALL 使点号能匹配换行符
|
90
|
+
match = re.search(regex_pattern, text, re.DOTALL)
|
91
|
+
|
92
|
+
if match:
|
93
|
+
# 提取匹配到的整个 JSON 数组字符串 (group 1 因为模式中有括号)
|
94
|
+
json_string = match.group(1)
|
95
|
+
|
96
|
+
try:
|
97
|
+
# 使用 json.loads() 解析字符串
|
98
|
+
parsed_data = json.loads(json_string)
|
99
|
+
# 使用 json.dumps 美化打印输出
|
100
|
+
print(json.dumps(parsed_data, indent=2, ensure_ascii=False))
|
101
|
+
|
102
|
+
# 例如,获取第一个元素的 label
|
103
|
+
if isinstance(parsed_data, list) and len(parsed_data) > 0:
|
104
|
+
first_item = parsed_data[0]
|
105
|
+
if isinstance(first_item, dict):
|
106
|
+
label = first_item.get('label')
|
107
|
+
print(f"\n第一个元素的 label 是: {label}")
|
108
|
+
|
109
|
+
except json.JSONDecodeError as e:
|
110
|
+
print(f"JSON 解析错误: {e}")
|
111
|
+
print(f"出错的字符串是: {json_string}")
|
112
|
+
else:
|
113
|
+
print("在文本中未找到匹配的 JSON 数组。")
|
114
|
+
|
115
|
+
|
116
|
+
import io
|
117
|
+
from PIL import Image, ImageDraw, ImageFont # pip install Pillow
|
118
|
+
|
119
|
+
def extract_box_and_mask_py(parsed_data):
|
120
|
+
"""
|
121
|
+
从已解析的 JSON 数据中提取边界框 (box_2d) 和 Base64 编码的掩码 (mask) 数据。
|
122
|
+
|
123
|
+
Args:
|
124
|
+
parsed_data (list): 一个包含字典的列表,每个字典至少包含 'box_2d' 和 'mask' 键。
|
125
|
+
例如: [{'box_2d': [y1, x1, y2, x2], 'mask': 'data:image/png;base64,...'}, ...]
|
126
|
+
|
127
|
+
Returns:
|
128
|
+
list: 一个包含字典的列表,每个字典包含 'box' (坐标列表)
|
129
|
+
和 'mask_base64' (Base64 字符串)。
|
130
|
+
例如: [{'box': [y1, x1, y2, x2], 'mask_base64': '...'}, ...]
|
131
|
+
坐标系假定为 0-1000 范围。
|
132
|
+
"""
|
133
|
+
# 不再需要正则表达式
|
134
|
+
results = []
|
135
|
+
# 检查 parsed_data 是否为列表
|
136
|
+
if not isinstance(parsed_data, list):
|
137
|
+
print(f"Error: Input data is not a list. Received type: {type(parsed_data)}")
|
138
|
+
return results
|
139
|
+
|
140
|
+
for item in parsed_data:
|
141
|
+
if not isinstance(item, dict):
|
142
|
+
print(f"Skipping non-dictionary item in list: {item}")
|
143
|
+
continue
|
144
|
+
|
145
|
+
try:
|
146
|
+
box = item.get('box_2d')
|
147
|
+
mask_data_uri = item.get('mask')
|
148
|
+
|
149
|
+
# 检查 'box_2d' 和 'mask' 是否存在且不为 None
|
150
|
+
if box is None or mask_data_uri is None:
|
151
|
+
print(f"Skipping item due to missing 'box_2d' or 'mask': {item}")
|
152
|
+
continue
|
153
|
+
|
154
|
+
# 从 mask 数据 URI 中提取 Base64 部分
|
155
|
+
# 格式: "data:image/[^;]+;base64,..."
|
156
|
+
if isinstance(mask_data_uri, str) and mask_data_uri.startswith('data:image/') and ';base64,' in mask_data_uri:
|
157
|
+
mask_b64 = mask_data_uri.split(';base64,', 1)[1]
|
158
|
+
else:
|
159
|
+
print(f"Skipping item due to invalid mask format: {mask_data_uri}")
|
160
|
+
continue
|
161
|
+
|
162
|
+
# 验证 box 数据
|
163
|
+
if isinstance(box, list) and len(box) == 4 and all(isinstance(n, int) for n in box):
|
164
|
+
results.append({"box": box, "mask_base64": mask_b64})
|
165
|
+
else:
|
166
|
+
print(f"Skipping invalid box format: {box}")
|
167
|
+
|
168
|
+
# 捕捉可能的 KeyError 或其他在字典访问/处理中发生的错误
|
169
|
+
except Exception as e:
|
170
|
+
print(f"Error processing item: {item}. Error: {e}")
|
171
|
+
|
172
|
+
return results
|
173
|
+
|
174
|
+
def display_image_with_bounding_boxes_and_masks_py(
|
175
|
+
original_image_path,
|
176
|
+
box_and_mask_data,
|
177
|
+
output_overlay_path="overlay_image.png",
|
178
|
+
output_compare_dir="comparison_outputs"
|
179
|
+
):
|
180
|
+
"""
|
181
|
+
在原始图像上绘制边界框和掩码,并生成裁剪区域与掩码的对比图。
|
182
|
+
|
183
|
+
Args:
|
184
|
+
original_image_path (str): 原始图像的文件路径。
|
185
|
+
box_and_mask_data (list): extract_box_and_mask_py 的输出列表。
|
186
|
+
output_overlay_path (str): 保存带有叠加效果的图像的路径。
|
187
|
+
output_compare_dir (str): 保存对比图像的目录路径。
|
188
|
+
"""
|
189
|
+
try:
|
190
|
+
img_original = Image.open(original_image_path).convert("RGBA")
|
191
|
+
img_width, img_height = img_original.size
|
192
|
+
except FileNotFoundError:
|
193
|
+
print(f"Error: Original image not found at {original_image_path}")
|
194
|
+
return
|
195
|
+
except Exception as e:
|
196
|
+
print(f"Error opening original image: {e}")
|
197
|
+
return
|
198
|
+
|
199
|
+
# 创建一个副本用于绘制叠加效果
|
200
|
+
img_overlay = img_original.copy()
|
201
|
+
draw = ImageDraw.Draw(img_overlay, "RGBA") # 使用 RGBA 模式以支持透明度
|
202
|
+
|
203
|
+
# 定义颜色列表
|
204
|
+
colors_hex = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
|
205
|
+
# 将十六进制颜色转换为 RGBA 元组 (用于绘制)
|
206
|
+
colors_rgba = []
|
207
|
+
for hex_color in colors_hex:
|
208
|
+
h = hex_color.lstrip('#')
|
209
|
+
rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
|
210
|
+
colors_rgba.append(rgb + (255,)) # (R, G, B, Alpha) - 边框完全不透明
|
211
|
+
|
212
|
+
# 创建输出目录(如果不存在)
|
213
|
+
import os
|
214
|
+
os.makedirs(output_compare_dir, exist_ok=True)
|
215
|
+
|
216
|
+
print(f"Found {len(box_and_mask_data)} box/mask pairs to process.")
|
217
|
+
|
218
|
+
for i, data in enumerate(box_and_mask_data):
|
219
|
+
box_0_1000 = data['box'] # [ymin, xmin, ymax, xmax] in 0-1000 range
|
220
|
+
mask_b64 = data['mask_base64']
|
221
|
+
color_index = i % len(colors_rgba)
|
222
|
+
outline_color = colors_rgba[color_index]
|
223
|
+
# 叠加掩码时使用半透明颜色
|
224
|
+
mask_fill_color = outline_color[:3] + (int(255 * 0.7),) # 70% Alpha
|
225
|
+
|
226
|
+
# --- 1. 坐标转换与验证 ---
|
227
|
+
# 将 0-1000 坐标转换为图像像素坐标 (left, top, right, bottom)
|
228
|
+
# 假设 box 是 [ymin, xmin, ymax, xmax]
|
229
|
+
try:
|
230
|
+
ymin_norm, xmin_norm, ymax_norm, xmax_norm = [c / 1000.0 for c in box_0_1000]
|
231
|
+
|
232
|
+
left = int(xmin_norm * img_width)
|
233
|
+
top = int(ymin_norm * img_height)
|
234
|
+
right = int(xmax_norm * img_width)
|
235
|
+
bottom = int(ymax_norm * img_height)
|
236
|
+
|
237
|
+
# 确保坐标在图像范围内且有效
|
238
|
+
left = max(0, left)
|
239
|
+
top = max(0, top)
|
240
|
+
right = min(img_width, right)
|
241
|
+
bottom = min(img_height, bottom)
|
242
|
+
|
243
|
+
box_width_px = right - left
|
244
|
+
box_height_px = bottom - top
|
245
|
+
|
246
|
+
if box_width_px <= 0 or box_height_px <= 0:
|
247
|
+
print(f"Skipping box {i+1} due to zero or negative dimensions after conversion.")
|
248
|
+
continue
|
249
|
+
|
250
|
+
except Exception as e:
|
251
|
+
print(f"Error processing coordinates for box {i+1}: {box_0_1000}. Error: {e}")
|
252
|
+
continue
|
253
|
+
|
254
|
+
print(f"Processing Box {i+1}: Pixels(L,T,R,B)=({left},{top},{right},{bottom}) Color={colors_hex[color_index]}")
|
255
|
+
|
256
|
+
# --- 2. 在叠加图像上绘制边界框 ---
|
257
|
+
try:
|
258
|
+
draw.rectangle([left, top, right, bottom], outline=outline_color, width=5)
|
259
|
+
except Exception as e:
|
260
|
+
print(f"Error drawing rectangle for box {i+1}: {e}")
|
261
|
+
continue
|
262
|
+
|
263
|
+
# --- 3. 处理并绘制掩码 ---
|
264
|
+
try:
|
265
|
+
# 解码 Base64 掩码数据
|
266
|
+
mask_bytes = base64.b64decode(mask_b64)
|
267
|
+
mask_img_raw = Image.open(io.BytesIO(mask_bytes)).convert("RGBA")
|
268
|
+
|
269
|
+
# 将掩码图像缩放到边界框的像素尺寸
|
270
|
+
mask_img_resized = mask_img_raw.resize((box_width_px, box_height_px), Image.Resampling.NEAREST)
|
271
|
+
|
272
|
+
# 创建一个纯色块,应用掩码的 alpha 通道
|
273
|
+
color_block = Image.new('RGBA', mask_img_resized.size, mask_fill_color)
|
274
|
+
|
275
|
+
# 将带有透明度的颜色块粘贴到叠加图像上,使用掩码的 alpha 通道作为粘贴蒙版
|
276
|
+
# mask_img_resized.split()[-1] 提取 alpha 通道
|
277
|
+
img_overlay.paste(color_block, (left, top), mask=mask_img_resized.split()[-1])
|
278
|
+
|
279
|
+
except base64.binascii.Error:
|
280
|
+
print(f"Error: Invalid Base64 data for mask {i+1}.")
|
281
|
+
continue
|
282
|
+
except Exception as e:
|
283
|
+
print(f"Error processing or drawing mask for box {i+1}: {e}")
|
284
|
+
continue
|
285
|
+
|
286
|
+
# --- 4. 生成对比图 ---
|
287
|
+
try:
|
288
|
+
# 从原始图像中裁剪出边界框区域
|
289
|
+
img_crop = img_original.crop((left, top, right, bottom))
|
290
|
+
|
291
|
+
# 准备掩码预览图(使用原始解码后的掩码,调整大小以匹配裁剪区域)
|
292
|
+
# 这里直接使用缩放后的 mask_img_resized 的 RGB 部分可能更直观
|
293
|
+
mask_preview = mask_img_resized.convert("RGB") # 转换为 RGB 以便保存为常见格式
|
294
|
+
|
295
|
+
# 保存裁剪图和掩码预览图
|
296
|
+
crop_filename = os.path.join(output_compare_dir, f"compare_{i+1}_crop.png")
|
297
|
+
mask_filename = os.path.join(output_compare_dir, f"compare_{i+1}_mask.png")
|
298
|
+
img_crop.save(crop_filename)
|
299
|
+
mask_preview.save(mask_filename)
|
300
|
+
print(f" - Saved comparison: {crop_filename}, {mask_filename}")
|
301
|
+
|
302
|
+
except Exception as e:
|
303
|
+
print(f"Error creating or saving comparison images for box {i+1}: {e}")
|
304
|
+
|
305
|
+
# --- 5. 保存最终的叠加图像 ---
|
306
|
+
try:
|
307
|
+
img_overlay.save(output_overlay_path)
|
308
|
+
print(f"\nOverlay image saved to: {output_overlay_path}")
|
309
|
+
print(f"Comparison images saved in: {output_compare_dir}")
|
310
|
+
except Exception as e:
|
311
|
+
print(f"Error saving the final overlay image: {e}")
|
312
|
+
|
313
|
+
|
314
|
+
extracted_data = extract_box_and_mask_py(parsed_data)
|
315
|
+
|
316
|
+
if extracted_data:
|
317
|
+
# 确保原始图像存在
|
318
|
+
import os
|
319
|
+
if os.path.exists(IMAGE_PATH):
|
320
|
+
display_image_with_bounding_boxes_and_masks_py(
|
321
|
+
IMAGE_PATH,
|
322
|
+
extracted_data,
|
323
|
+
output_overlay_path="python_overlay_output.png", # 输出带叠加效果的图片名
|
324
|
+
output_compare_dir="python_comparison_outputs" # 输出对比图的文件夹名
|
325
|
+
)
|
326
|
+
else:
|
327
|
+
print(f"Error: Cannot proceed with visualization, image file not found: {IMAGE_PATH}")
|
328
|
+
print("Please update the 'IMAGE_PATH' variable in the script.")
|
329
|
+
else:
|
330
|
+
print("No valid box and mask data found in the response text.")
|
@@ -70,7 +70,7 @@ def get_engine(provider, endpoint=None, original_model=""):
|
|
70
70
|
engine = "gemini"
|
71
71
|
elif parsed_url.netloc.rstrip('/').endswith('aiplatform.googleapis.com') or (parsed_url.netloc.rstrip('/').endswith('gateway.ai.cloudflare.com') and "google-vertex-ai" in parsed_url.path):
|
72
72
|
engine = "vertex"
|
73
|
-
elif parsed_url.netloc.rstrip('/').endswith('
|
73
|
+
elif parsed_url.netloc.rstrip('/').endswith('azure.com'):
|
74
74
|
engine = "azure"
|
75
75
|
elif parsed_url.netloc == 'api.cloudflare.com':
|
76
76
|
engine = "cloudflare"
|
@@ -127,6 +127,9 @@ def get_engine(provider, endpoint=None, original_model=""):
|
|
127
127
|
engine = "tts"
|
128
128
|
stream = False
|
129
129
|
|
130
|
+
if "stream" in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}):
|
131
|
+
stream = safe_get(provider, "preferences", "post_body_parameter_overrides", "stream")
|
132
|
+
|
130
133
|
return engine, stream
|
131
134
|
|
132
135
|
from httpx_socks import AsyncProxyTransport
|
@@ -484,9 +487,17 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
484
487
|
|
485
488
|
return sse_response
|
486
489
|
|
487
|
-
async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
|
490
|
+
async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0, reasoning_content=None):
|
488
491
|
random.seed(timestamp)
|
489
492
|
random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
|
493
|
+
message = {
|
494
|
+
"role": role,
|
495
|
+
"content": content,
|
496
|
+
"refusal": None
|
497
|
+
}
|
498
|
+
if reasoning_content:
|
499
|
+
message["reasoning_content"] = reasoning_content
|
500
|
+
|
490
501
|
sample_data = {
|
491
502
|
"id": f"chatcmpl-{random_str}",
|
492
503
|
"object": "chat.completion",
|
@@ -495,11 +506,7 @@ async def generate_no_stream_response(timestamp, model, content=None, tools_id=N
|
|
495
506
|
"choices": [
|
496
507
|
{
|
497
508
|
"index": 0,
|
498
|
-
"message":
|
499
|
-
"role": role,
|
500
|
-
"content": content,
|
501
|
-
"refusal": None
|
502
|
-
},
|
509
|
+
"message": message,
|
503
510
|
"logprobs": None,
|
504
511
|
"finish_reason": "stop"
|
505
512
|
}
|
@@ -19,35 +19,104 @@ def unescape_html(input_string: str) -> str:
|
|
19
19
|
@register_tool()
|
20
20
|
def excute_command(command):
|
21
21
|
"""
|
22
|
-
执行命令并返回输出结果
|
22
|
+
执行命令并返回输出结果 (标准输出会实时打印到控制台)
|
23
23
|
禁止用于查看pdf,禁止使用 pdftotext 命令
|
24
24
|
|
25
25
|
参数:
|
26
26
|
command: 要执行的命令,可以克隆仓库,安装依赖,运行代码等
|
27
27
|
|
28
28
|
返回:
|
29
|
-
|
29
|
+
命令执行的最终状态和收集到的输出/错误信息
|
30
30
|
"""
|
31
31
|
try:
|
32
|
-
#
|
33
|
-
|
34
|
-
|
35
|
-
#
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
32
|
+
command = unescape_html(command) # 保留 HTML 解码
|
33
|
+
|
34
|
+
# 使用 Popen 以便实时处理输出
|
35
|
+
# bufsize=1 表示行缓冲, universal_newlines=True 与 text=True 效果类似,用于文本模式
|
36
|
+
process = subprocess.Popen(
|
37
|
+
command,
|
38
|
+
shell=True,
|
39
|
+
stdout=subprocess.PIPE,
|
40
|
+
stderr=subprocess.PIPE,
|
41
|
+
text=True,
|
42
|
+
bufsize=1,
|
43
|
+
universal_newlines=True
|
44
|
+
)
|
45
|
+
|
46
|
+
stdout_lines = []
|
47
|
+
|
48
|
+
# 实时打印 stdout
|
49
|
+
# print(f"--- 开始执行命令: {command} ---")
|
50
|
+
if process.stdout:
|
51
|
+
for line in iter(process.stdout.readline, ''):
|
52
|
+
# 对 pip install 命令的输出进行过滤,去除进度条相关的行
|
53
|
+
if "pip install" in command and '━━' in line:
|
54
|
+
continue
|
55
|
+
print(line, end='', flush=True) # 实时打印到控制台,并刷新缓冲区
|
56
|
+
stdout_lines.append(line) # 收集行以供后续返回
|
57
|
+
process.stdout.close()
|
58
|
+
# print(f"\n--- 命令实时输出结束 ---")
|
59
|
+
|
60
|
+
# 等待命令完成
|
61
|
+
process.wait()
|
62
|
+
|
63
|
+
# 获取 stderr (命令完成后一次性读取)
|
64
|
+
stderr_output = ""
|
65
|
+
if process.stderr:
|
66
|
+
stderr_output = process.stderr.read()
|
67
|
+
process.stderr.close()
|
68
|
+
|
69
|
+
# 组合最终的 stdout 日志 (已经过 pip install 过滤)
|
70
|
+
final_stdout_log = "".join(stdout_lines)
|
71
|
+
|
72
|
+
if process.returncode == 0:
|
73
|
+
return f"执行命令成功:\n{final_stdout_log}"
|
44
74
|
else:
|
45
|
-
|
46
|
-
|
47
|
-
|
75
|
+
return f"执行命令失败 (退出码 {process.returncode}):\n错误: {stderr_output}\n输出: {final_stdout_log}"
|
76
|
+
|
77
|
+
except FileNotFoundError:
|
78
|
+
# 当 shell=True 时,命令未找到通常由 shell 处理,并返回非零退出码。
|
79
|
+
# 此处捕获 FileNotFoundError 主要用于 Popen 自身无法启动命令的场景 (例如 shell 本身未找到)。
|
80
|
+
return f"执行命令失败: 命令或程序未找到 ({command})"
|
48
81
|
except Exception as e:
|
82
|
+
# 其他未知异常
|
49
83
|
return f"执行命令时发生异常: {e}"
|
50
84
|
|
51
85
|
if __name__ == "__main__":
|
52
|
-
print(excute_command("ls -l && echo 'Hello, World!'"))
|
53
|
-
print(excute_command("ls -l && echo 'Hello, World!'"))
|
86
|
+
# print(excute_command("ls -l && echo 'Hello, World!'"))
|
87
|
+
# print(excute_command("ls -l && echo 'Hello, World!'"))
|
88
|
+
|
89
|
+
# tqdm_script = """
|
90
|
+
# import time
|
91
|
+
# from tqdm import tqdm
|
92
|
+
|
93
|
+
# for i in range(10):
|
94
|
+
# print(f"TQDM 进度条测试: {i}")
|
95
|
+
# time.sleep(1)
|
96
|
+
# print('\\n-------TQDM 任务完成.')
|
97
|
+
# """
|
98
|
+
# processed_tqdm_script = tqdm_script.replace('"', '\\"')
|
99
|
+
# tqdm_command = f"python -u -u -c \"{processed_tqdm_script}\""
|
100
|
+
# # print(f"执行: {tqdm_command}")
|
101
|
+
# print(excute_command(tqdm_command))
|
102
|
+
|
103
|
+
|
104
|
+
# long_running_command_unix = "echo '开始长时间任务...' && for i in 1 2 3; do echo \"正在处理步骤 $i/3...\"; sleep 1; done && echo '长时间任务完成!'"
|
105
|
+
# print(f"执行: {long_running_command_unix}")
|
106
|
+
# print(excute_command(long_running_command_unix))
|
107
|
+
|
108
|
+
|
109
|
+
long_running_command_unix = "pip install torch"
|
110
|
+
print(f"执行: {long_running_command_unix}")
|
111
|
+
print(excute_command(long_running_command_unix))
|
112
|
+
|
113
|
+
|
114
|
+
# python_long_task_command = """
|
115
|
+
# python -c "import time; print('Python 长时间任务启动...'); [print(f'Python 任务进度: {i+1}/3', flush=True) or time.sleep(1) for i in range(3)]; print('Python 长时间任务完成.')"
|
116
|
+
# """
|
117
|
+
# python_long_task_command = python_long_task_command.strip() # 移除可能的前后空白
|
118
|
+
# print(f"执行: {python_long_task_command}")
|
119
|
+
# print(excute_command(python_long_task_command))
|
120
|
+
|
121
|
+
|
122
|
+
# python -m beswarm.aient.src.aient.plugins.excute_command
|