aient 1.0.91__py3-none-any.whl → 1.0.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/core/request.py CHANGED
@@ -31,7 +31,10 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
     model_dict = get_model_dict(provider)
     original_model = model_dict[request.model]
 
-    gemini_stream = "streamGenerateContent"
+    if request.stream:
+        gemini_stream = "streamGenerateContent"
+    else:
+        gemini_stream = "generateContent"
     url = provider['base_url']
     parsed_url = urllib.parse.urlparse(url)
     if "/v1beta" in parsed_url.path:
@@ -295,7 +298,10 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
     if provider.get("project_id"):
         project_id = provider.get("project_id")
 
-    gemini_stream = "streamGenerateContent"
+    if request.stream:
+        gemini_stream = "streamGenerateContent"
+    else:
+        gemini_stream = "generateContent"
     model_dict = get_model_dict(provider)
     original_model = model_dict[request.model]
     search_tool = None
@@ -941,10 +947,6 @@ async def get_gpt_payload(request, engine, provider, api_key=None):
         payload.pop("tools", None)
         payload.pop("tool_choice", None)
 
-    # if "models.inference.ai.azure.com" in url:
-    #     payload["stream"] = False
-    #     payload.pop("stream_options", None)
-
     if "api.x.ai" in url:
         payload.pop("stream_options", None)
         payload.pop("presence_penalty", None)
aient/core/response.py CHANGED
@@ -35,6 +35,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
         promptTokenCount = 0
         candidatesTokenCount = 0
         totalTokenCount = 0
+        parts_json = ""
         # line_index = 0
         # last_text_line = 0
         # if "thinking" in model:
@@ -63,24 +64,25 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
                     json_data = parse_json_safely( "{" + line + "}")
                     totalTokenCount = json_data.get('totalTokenCount', 0)
 
-                # print(line)
-                if line and '\"text\": \"' in line and is_finish == False:
+                if (line and '"parts": [' in line or parts_json != "") and is_finish == False:
+                    parts_json += line
+                if parts_json != "" and line and '],' == line.strip():
+                    parts_json = "{" + parts_json.strip().rstrip(",} ]}") + "}]}"
                     try:
-                        json_data = json.loads( "{" + line + "}")
-                        content = json_data.get('text', '')
-                        # content = content.replace("\n", "\n\n")
-                        # if last_text_line == 0 and is_thinking:
-                        #     content = "> " + content.lstrip()
-                        # if is_thinking:
-                        #     content = content.replace("\n", "\n> ")
-                        # if last_text_line == line_index - 3:
-                        #     is_thinking = False
-                        #     content = "\n\n\n" + content.lstrip()
-                        sse_string = await generate_sse_response(timestamp, model, content=content)
-                        yield sse_string
+                        json_data = json.loads(parts_json)
+
+                        content = safe_get(json_data, "parts", 0, "text", default="")
+
+                        is_thinking = safe_get(json_data, "parts", 0, "thought", default=False)
+                        if is_thinking:
+                            sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
+                            yield sse_string
+                        else:
+                            sse_string = await generate_sse_response(timestamp, model, content=content)
+                            yield sse_string
                     except json.JSONDecodeError:
-                        logger.error(f"无法解析JSON: {line}")
-                        # last_text_line = line_index
+                        logger.error(f"无法解析JSON: {parts_json}")
+                        parts_json = ""
 
                 if line and ('\"functionCall\": {' in line or revicing_function_call):
                     revicing_function_call = True
@@ -142,7 +144,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
 
                 if line and '\"text\": \"' in line and is_finish == False:
                     try:
-                        json_data = json.loads( "{" + line + "}")
+                        json_data = json.loads( "{" + line.strip().rstrip(",") + "}")
                        content = json_data.get('text', '')
                        sse_string = await generate_sse_response(timestamp, model, content=content)
                        yield sse_string
@@ -525,21 +527,28 @@ async def fetch_response(client, url, headers, payload, engine, model):
             parsed_data = ast.literal_eval(str(response_json))
         elif isinstance(response_json, list):
             parsed_data = response_json
+        elif isinstance(response_json, dict):
+            parsed_data = [response_json]
         else:
             logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
             parsed_data = response_json
         # print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
         content = ""
+        reasoning_content = ""
         for item in parsed_data:
             chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+            is_think = safe_get(item, "candidates", 0, "content", "parts", 0, "thought", default=False)
             # logger.info(f"chunk: {repr(chunk)}")
             if chunk:
-                content += chunk
+                if is_think:
+                    reasoning_content += chunk
+                else:
+                    content += chunk
 
         usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
-        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
-        candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
-        total_tokens = usage_metadata.get("totalTokenCount", 0)
+        prompt_tokens = safe_get(usage_metadata, "promptTokenCount", default=0)
+        candidates_tokens = safe_get(usage_metadata, "candidatesTokenCount", default=0)
+        total_tokens = safe_get(usage_metadata, "totalTokenCount", default=0)
 
         role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
         if role == "model":
@@ -552,7 +561,7 @@ async def fetch_response(client, url, headers, payload, engine, model):
         function_call_content = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "args", default=None)
 
         timestamp = int(datetime.timestamp(datetime.now()))
-        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
+        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens, reasoning_content=reasoning_content)
 
     elif engine == "claude":
         response_json = response.json()
aient/core/test/test_geminimask.py ADDED
@@ -0,0 +1,330 @@
+import httpx
+import base64
+import os  # used to read the API key
+import asyncio
+import json
+import re
+# python -m core.test.test_geminimask
+from ..utils import get_image_message, safe_get
+
+# --- Replace these with your actual values ---
+MODEL_NAME = "gemini-2.5-flash-preview-04-17"
+API_KEY = os.environ.get("GEMINI_API_KEY")  # read the key from an environment variable
+IMAGE_PATH = os.environ.get("IMAGE_PATH")
+# --------------------------
+
+# Check that the API key exists
+if not API_KEY:
+    raise ValueError("请设置 GOOGLE_API_KEY 环境变量或在代码中提供 API 密钥。")
+
+# Determine the image MIME type
+# You can guess it from the file extension, or use a more reliable library such as python-magic
+if IMAGE_PATH.lower().endswith(".png"):
+    IMAGE_MIME_TYPE = "image/png"
+elif IMAGE_PATH.lower().endswith(".jpg") or IMAGE_PATH.lower().endswith(".jpeg"):
+    IMAGE_MIME_TYPE = "image/jpeg"
+# Add any other image types you need to support
+else:
+    raise ValueError(f"不支持的图片格式: {IMAGE_PATH}")
+
+# Read the image file and Base64-encode it
+try:
+    with open(IMAGE_PATH, "rb") as image_file:
+        image_data = image_file.read()
+        base64_encoded_image = base64.b64encode(image_data).decode("utf-8")
+        # print(base64_encoded_image)
+except FileNotFoundError:
+    print(f"错误:找不到图片文件 '{IMAGE_PATH}'")
+    exit()
+
+# image_message = get_image_message(base64_encoded_image, "gemini")
+image_message = asyncio.run(get_image_message(f"data:{IMAGE_MIME_TYPE};base64," + base64_encoded_image, "gemini"))
+
+# Build the request URL
+url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_NAME}:generateContent"
+
+prompt = "Give the segmentation masks for the Search box. Output a JSON list of segmentation masks where each entry contains the 2D bounding box in \"box_2d\" and the mask in \"mask\"."
+# Define the query parameters (API key)
+params = {"key": API_KEY}
+
+# Define the request headers
+headers = {"Content-Type": "application/json"}
+
+# Define the request body (JSON payload)
+payload = {
+    "contents": [
+        {
+            "parts": [
+                {
+                    "text": prompt
+                },
+                image_message
+            ]
+        }
+    ],
+    "generationConfig": {"thinkingConfig": {"thinkingBudget": 0}},
+}
+
+# Send the POST request
+try:
+    with httpx.Client() as client:
+        response = client.post(url, params=params, headers=headers, json=payload, timeout=60.0)  # longer timeout
+        response.raise_for_status()  # raise if the status code is not 2xx
+
+        # You can add code here to parse response.json() and extract the segmentation masks
+        text = safe_get(response.json(), "candidates", 0, "content", "parts", 0, "text")
+        # print(text)
+        # e.g.: segmentation_masks = response.json().get('candidates', [{}])[0].get('content', {}).get('parts', [{}])[0].get('text')
+
+except httpx.HTTPStatusError as exc:
+    print(f"HTTP 错误发生: {exc.response.status_code} - {exc.response.text}")
+except httpx.RequestError as exc:
+    print(f"发送请求时出错: {exc}")
+except Exception as e:
+    print(f"发生意外错误: {e}")
+
+
+regex_pattern = r'(\[\s*\{.*?\}\s*\])'  # match an array that contains at least one object
+
+# Use re.search to find the first match; re.DOTALL lets the dot match newlines
+match = re.search(regex_pattern, text, re.DOTALL)
+
+if match:
+    # Extract the whole matched JSON array string (group 1, because the pattern has parentheses)
+    json_string = match.group(1)
+
+    try:
+        # Parse the string with json.loads()
+        parsed_data = json.loads(json_string)
+        # Pretty-print the result with json.dumps
+        print(json.dumps(parsed_data, indent=2, ensure_ascii=False))
+
+        # For example, get the label of the first element
+        if isinstance(parsed_data, list) and len(parsed_data) > 0:
+            first_item = parsed_data[0]
+            if isinstance(first_item, dict):
+                label = first_item.get('label')
+                print(f"\n第一个元素的 label 是: {label}")
+
+    except json.JSONDecodeError as e:
+        print(f"JSON 解析错误: {e}")
+        print(f"出错的字符串是: {json_string}")
+else:
+    print("在文本中未找到匹配的 JSON 数组。")
+
+
+import io
+from PIL import Image, ImageDraw, ImageFont  # pip install Pillow
+
+def extract_box_and_mask_py(parsed_data):
+    """
+    Extract the bounding boxes (box_2d) and the Base64-encoded masks (mask) from parsed JSON data.
+
+    Args:
+        parsed_data (list): A list of dicts, each containing at least the 'box_2d' and 'mask' keys.
+                            e.g.: [{'box_2d': [y1, x1, y2, x2], 'mask': 'data:image/png;base64,...'}, ...]
+
+    Returns:
+        list: A list of dicts, each containing 'box' (a list of coordinates)
+              and 'mask_base64' (a Base64 string).
+              e.g.: [{'box': [y1, x1, y2, x2], 'mask_base64': '...'}, ...]
+              Coordinates are assumed to be in the 0-1000 range.
+    """
+    # No regular expression needed any more
+    results = []
+    # Check that parsed_data is a list
+    if not isinstance(parsed_data, list):
+        print(f"Error: Input data is not a list. Received type: {type(parsed_data)}")
+        return results
+
+    for item in parsed_data:
+        if not isinstance(item, dict):
+            print(f"Skipping non-dictionary item in list: {item}")
+            continue
+
+        try:
+            box = item.get('box_2d')
+            mask_data_uri = item.get('mask')
+
+            # Check that 'box_2d' and 'mask' exist and are not None
+            if box is None or mask_data_uri is None:
+                print(f"Skipping item due to missing 'box_2d' or 'mask': {item}")
+                continue
+
+            # Extract the Base64 part from the mask data URI
+            # Format: "data:image/[^;]+;base64,..."
+            if isinstance(mask_data_uri, str) and mask_data_uri.startswith('data:image/') and ';base64,' in mask_data_uri:
+                mask_b64 = mask_data_uri.split(';base64,', 1)[1]
+            else:
+                print(f"Skipping item due to invalid mask format: {mask_data_uri}")
+                continue
+
+            # Validate the box data
+            if isinstance(box, list) and len(box) == 4 and all(isinstance(n, int) for n in box):
+                results.append({"box": box, "mask_base64": mask_b64})
+            else:
+                print(f"Skipping invalid box format: {box}")
+
+        # Catch possible KeyError or other errors raised while accessing/processing the dict
+        except Exception as e:
+            print(f"Error processing item: {item}. Error: {e}")
+
+    return results
+
+def display_image_with_bounding_boxes_and_masks_py(
+    original_image_path,
+    box_and_mask_data,
+    output_overlay_path="overlay_image.png",
+    output_compare_dir="comparison_outputs"
+):
+    """
+    Draw the bounding boxes and masks on the original image, and generate comparison images of each cropped region and its mask.
+
+    Args:
+        original_image_path (str): File path of the original image.
+        box_and_mask_data (list): The output list from extract_box_and_mask_py.
+        output_overlay_path (str): Path to save the image with the overlays.
+        output_compare_dir (str): Directory in which to save the comparison images.
+    """
+    try:
+        img_original = Image.open(original_image_path).convert("RGBA")
+        img_width, img_height = img_original.size
+    except FileNotFoundError:
+        print(f"Error: Original image not found at {original_image_path}")
+        return
+    except Exception as e:
+        print(f"Error opening original image: {e}")
+        return
+
+    # Make a copy to draw the overlays on
+    img_overlay = img_original.copy()
+    draw = ImageDraw.Draw(img_overlay, "RGBA")  # use RGBA mode to support transparency
+
+    # Define the color list
+    colors_hex = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
+    # Convert the hex colors to RGBA tuples (for drawing)
+    colors_rgba = []
+    for hex_color in colors_hex:
+        h = hex_color.lstrip('#')
+        rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
+        colors_rgba.append(rgb + (255,))  # (R, G, B, Alpha) - fully opaque outline
+
+    # Create the output directory (if it does not exist)
+    import os
+    os.makedirs(output_compare_dir, exist_ok=True)
+
+    print(f"Found {len(box_and_mask_data)} box/mask pairs to process.")
+
+    for i, data in enumerate(box_and_mask_data):
+        box_0_1000 = data['box']  # [ymin, xmin, ymax, xmax] in 0-1000 range
+        mask_b64 = data['mask_base64']
+        color_index = i % len(colors_rgba)
+        outline_color = colors_rgba[color_index]
+        # Use a semi-transparent color when overlaying the mask
+        mask_fill_color = outline_color[:3] + (int(255 * 0.7),)  # 70% Alpha
+
+        # --- 1. Coordinate conversion and validation ---
+        # Convert 0-1000 coordinates to image pixel coordinates (left, top, right, bottom)
+        # Assume box is [ymin, xmin, ymax, xmax]
+        try:
+            ymin_norm, xmin_norm, ymax_norm, xmax_norm = [c / 1000.0 for c in box_0_1000]
+
+            left = int(xmin_norm * img_width)
+            top = int(ymin_norm * img_height)
+            right = int(xmax_norm * img_width)
+            bottom = int(ymax_norm * img_height)
+
+            # Make sure the coordinates are inside the image and valid
+            left = max(0, left)
+            top = max(0, top)
+            right = min(img_width, right)
+            bottom = min(img_height, bottom)
+
+            box_width_px = right - left
+            box_height_px = bottom - top
+
+            if box_width_px <= 0 or box_height_px <= 0:
+                print(f"Skipping box {i+1} due to zero or negative dimensions after conversion.")
+                continue
+
+        except Exception as e:
+            print(f"Error processing coordinates for box {i+1}: {box_0_1000}. Error: {e}")
+            continue
+
+        print(f"Processing Box {i+1}: Pixels(L,T,R,B)=({left},{top},{right},{bottom}) Color={colors_hex[color_index]}")
+
+        # --- 2. Draw the bounding box on the overlay image ---
+        try:
+            draw.rectangle([left, top, right, bottom], outline=outline_color, width=5)
+        except Exception as e:
+            print(f"Error drawing rectangle for box {i+1}: {e}")
+            continue
+
+        # --- 3. Process and draw the mask ---
+        try:
+            # Decode the Base64 mask data
+            mask_bytes = base64.b64decode(mask_b64)
+            mask_img_raw = Image.open(io.BytesIO(mask_bytes)).convert("RGBA")
+
+            # Resize the mask image to the pixel size of the bounding box
+            mask_img_resized = mask_img_raw.resize((box_width_px, box_height_px), Image.Resampling.NEAREST)
+
+            # Create a solid color block to which the mask's alpha channel will be applied
+            color_block = Image.new('RGBA', mask_img_resized.size, mask_fill_color)
+
+            # Paste the semi-transparent color block onto the overlay image, using the mask's alpha channel as the paste mask
+            # mask_img_resized.split()[-1] extracts the alpha channel
+            img_overlay.paste(color_block, (left, top), mask=mask_img_resized.split()[-1])
+
+        except base64.binascii.Error:
+            print(f"Error: Invalid Base64 data for mask {i+1}.")
+            continue
+        except Exception as e:
+            print(f"Error processing or drawing mask for box {i+1}: {e}")
+            continue
+
+        # --- 4. Generate the comparison images ---
+        try:
+            # Crop the bounding-box region out of the original image
+            img_crop = img_original.crop((left, top, right, bottom))
+
+            # Prepare a mask preview (the decoded mask, resized to match the cropped region)
+            # Using the RGB part of the resized mask_img_resized directly is probably more intuitive here
+            mask_preview = mask_img_resized.convert("RGB")  # convert to RGB so it can be saved in a common format
+
+            # Save the crop and the mask preview
+            crop_filename = os.path.join(output_compare_dir, f"compare_{i+1}_crop.png")
+            mask_filename = os.path.join(output_compare_dir, f"compare_{i+1}_mask.png")
+            img_crop.save(crop_filename)
+            mask_preview.save(mask_filename)
+            print(f" - Saved comparison: {crop_filename}, {mask_filename}")
+
+        except Exception as e:
+            print(f"Error creating or saving comparison images for box {i+1}: {e}")
+
+    # --- 5. Save the final overlay image ---
+    try:
+        img_overlay.save(output_overlay_path)
+        print(f"\nOverlay image saved to: {output_overlay_path}")
+        print(f"Comparison images saved in: {output_compare_dir}")
+    except Exception as e:
+        print(f"Error saving the final overlay image: {e}")
+
+
+extracted_data = extract_box_and_mask_py(parsed_data)
+
+if extracted_data:
+    # Make sure the original image exists
+    import os
+    if os.path.exists(IMAGE_PATH):
+        display_image_with_bounding_boxes_and_masks_py(
+            IMAGE_PATH,
+            extracted_data,
+            output_overlay_path="python_overlay_output.png",  # file name for the overlay output image
+            output_compare_dir="python_comparison_outputs"  # folder name for the comparison images
+        )
+    else:
+        print(f"Error: Cannot proceed with visualization, image file not found: {IMAGE_PATH}")
+        print("Please update the 'IMAGE_PATH' variable in the script.")
+else:
+    print("No valid box and mask data found in the response text.")
aient/core/utils.py CHANGED
@@ -70,7 +70,7 @@ def get_engine(provider, endpoint=None, original_model=""):
         engine = "gemini"
     elif parsed_url.netloc.rstrip('/').endswith('aiplatform.googleapis.com') or (parsed_url.netloc.rstrip('/').endswith('gateway.ai.cloudflare.com') and "google-vertex-ai" in parsed_url.path):
         engine = "vertex"
-    elif parsed_url.netloc.rstrip('/').endswith('openai.azure.com') or parsed_url.netloc.rstrip('/').endswith('services.ai.azure.com'):
+    elif parsed_url.netloc.rstrip('/').endswith('azure.com'):
         engine = "azure"
     elif parsed_url.netloc == 'api.cloudflare.com':
         engine = "cloudflare"
@@ -127,6 +127,9 @@ def get_engine(provider, endpoint=None, original_model=""):
         engine = "tts"
         stream = False
 
+    if "stream" in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}):
+        stream = safe_get(provider, "preferences", "post_body_parameter_overrides", "stream")
+
     return engine, stream
 
 from httpx_socks import AsyncProxyTransport
@@ -484,9 +487,17 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
 
     return sse_response
 
-async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
+async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0, reasoning_content=None):
     random.seed(timestamp)
     random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
+    message = {
+        "role": role,
+        "content": content,
+        "refusal": None
+    }
+    if reasoning_content:
+        message["reasoning_content"] = reasoning_content
+
     sample_data = {
         "id": f"chatcmpl-{random_str}",
         "object": "chat.completion",
@@ -495,11 +506,7 @@ async def generate_no_stream_response(timestamp, model, content=None, tools_id=N
         "choices": [
             {
                 "index": 0,
-                "message": {
-                    "role": role,
-                    "content": content,
-                    "refusal": None
-                },
+                "message": message,
                 "logprobs": None,
                 "finish_reason": "stop"
             }
aient/plugins/excute_command.py CHANGED
@@ -1,6 +1,7 @@
 import subprocess
 from .registry import register_tool
 
+import re
 import html
 
 def unescape_html(input_string: str) -> str:
@@ -15,39 +16,137 @@ def unescape_html(input_string: str) -> str:
     """
     return html.unescape(input_string)
 
+def get_python_executable(command: str) -> str:
+    """
+    Get the path of the Python executable.
+
+    Returns:
+        str: The path of the Python executable.
+    """
+    cmd_parts = command.split(None, 1)
+    if cmd_parts:
+        executable = cmd_parts[0]
+        args_str = cmd_parts[1] if len(cmd_parts) > 1 else ""
+
+        # Check whether this is a python executable (e.g. python, python3, pythonX.Y)
+        is_python_exe = False
+        if executable == "python" or re.match(r"^python[23]?(\.\d+)?$", executable):
+            is_python_exe = True
+
+        if is_python_exe:
+            # Check whether the arguments already contain the -u option
+            args_list = args_str.split()
+            has_u_option = "-u" in args_list
+            if not has_u_option:
+                if args_str:
+                    command = f"{executable} -u {args_str}"
+    return command
+
 # Execute a command
 @register_tool()
 def excute_command(command):
     """
-    Execute a command and return its output
+    Execute a command and return its output (stdout is printed to the console in real time)
     Must not be used to view PDFs; the pdftotext command is forbidden
 
     Parameters:
         command: the command to execute; it can clone repositories, install dependencies, run code, and so on
 
     Returns:
-        The command's output or an error message
+        The final status of the command and the collected output/error information
     """
     try:
-        # Capture the command output with subprocess.run
-        command = unescape_html(command)
-        result = subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
-        # Return the command's standard output
-        if "pip install" in command:
-            stdout_log = "\n".join([x for x in result.stdout.split('\n') if '━━' not in x])
-        else:
-            stdout_log = result.stdout
-        return f"执行命令成功:\n{stdout_log}"
-    except subprocess.CalledProcessError as e:
-        if "pip install" in command:
-            stdout_log = "\n".join([x for x in e.stdout.split('\n') if '━━' not in x])
+        command = unescape_html(command)  # keep the HTML unescaping
+
+        command = get_python_executable(command)
+
+
+        # Use Popen so the output can be processed in real time
+        # bufsize=1 means line buffering; universal_newlines=True behaves like text=True and enables text mode
+        process = subprocess.Popen(
+            command,
+            shell=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            bufsize=1,
+            universal_newlines=True
+        )
+
+        stdout_lines = []
+
+        # Print stdout in real time
+        # print(f"--- 开始执行命令: {command} ---")
+        if process.stdout:
+            for line in iter(process.stdout.readline, ''):
+                # Filter the output of pip install commands to drop progress-bar lines
+                if "pip install" in command and '━━' in line:
+                    continue
+                print(line, end='', flush=True)  # print to the console in real time and flush the buffer
+                stdout_lines.append(line)  # collect the lines for the final return value
+            process.stdout.close()
+        # print(f"\n--- 命令实时输出结束 ---")
+
+        # Wait for the command to finish
+        process.wait()
+
+        # Read stderr (all at once, after the command has finished)
+        stderr_output = ""
+        if process.stderr:
+            stderr_output = process.stderr.read()
+            process.stderr.close()
+
+        # Assemble the final stdout log (already filtered for pip install)
+        final_stdout_log = "".join(stdout_lines)
+
+        if process.returncode == 0:
+            return f"执行命令成功:\n{final_stdout_log}"
         else:
-            stdout_log = e.stdout
-        # If the command fails, return the error message and error output
-        return f"执行命令失败 (退出码 {e.returncode}):\n错误: {e.stderr}\n输出: {stdout_log}"
+            return f"执行命令失败 (退出码 {process.returncode}):\n错误: {stderr_output}\n输出: {final_stdout_log}"
+
+    except FileNotFoundError:
+        # With shell=True, a command that is not found is usually handled by the shell and reported as a non-zero exit code.
+        # Catching FileNotFoundError here mainly covers the case where Popen itself cannot start the command (e.g. the shell is not found).
+        return f"执行命令失败: 命令或程序未找到 ({command})"
     except Exception as e:
+        # Any other unexpected exception
         return f"执行命令时发生异常: {e}"
 
 if __name__ == "__main__":
-    print(excute_command("ls -l && echo 'Hello, World!'"))
-    print(excute_command("ls -l &amp;&amp; echo 'Hello, World!'"))
+    # print(excute_command("ls -l && echo 'Hello, World!'"))
+    # print(excute_command("ls -l &amp;&amp; echo 'Hello, World!'"))
+
+    # tqdm_script = """
+    # import time
+    # from tqdm import tqdm
+
+    # for i in range(10):
+    #     print(f"TQDM 进度条测试: {i}")
+    #     time.sleep(1)
+    # print('\\n-------TQDM 任务完成.')
+    # """
+    # processed_tqdm_script = tqdm_script.replace('"', '\\"')
+    # tqdm_command = f"python -u -u -c \"{processed_tqdm_script}\""
+    # # print(f"执行: {tqdm_command}")
+    # print(excute_command(tqdm_command))
+
+
+    # long_running_command_unix = "echo '开始长时间任务...' && for i in 1 2 3; do echo \"正在处理步骤 $i/3...\"; sleep 1; done && echo '长时间任务完成!'"
+    # print(f"执行: {long_running_command_unix}")
+    # print(excute_command(long_running_command_unix))
+
+
+    # long_running_command_unix = "pip install torch"
+    # print(f"执行: {long_running_command_unix}")
+    # print(excute_command(long_running_command_unix))
+
+
+    # python_long_task_command = """
+    # python -c "import time; print('Python 长时间任务启动...'); [print(f'Python 任务进度: {i+1}/3', flush=True) or time.sleep(1) for i in range(3)]; print('Python 长时间任务完成.')"
+    # """
+    # python_long_task_command = python_long_task_command.strip()  # strip possible leading/trailing whitespace
+    # print(f"执行: {python_long_task_command}")
+    # print(excute_command(python_long_task_command))
+
+    print(get_python_executable("python -c 'print(123)'"))
+    # python -m beswarm.aient.src.aient.plugins.excute_command
aient/prompt/agent.py CHANGED
@@ -131,6 +131,8 @@ instruction_system_prompt = """
 你的工作目录为:{workspace_path},请在指令中使用绝对路径。所有操作必须基于工作目录。
 禁止在工作目录之外进行任何操作。你当前运行目录不一定就是工作目录。禁止默认你当前就在工作目录。
 
+当前时间:{current_time}
+
 你的输出必须符合以下步骤:
 
 1. 首先分析当前对话历史。其中user就是你发送给工作智能体的指令。assistant就是工作智能体的回复。
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: aient
-Version: 1.0.91
+Version: 1.0.93
 Summary: Aient: The Awakening of Agent.
 Description-Content-Type: text/markdown
 License-File: LICENSE
@@ -4,10 +4,11 @@ aient/core/.gitignore,sha256=5JRRlYYsqt_yt6iFvvzhbqh2FTUQMqwo6WwIuFzlGR8,13
 aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
 aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
 aient/core/models.py,sha256=_1wYZg_n9kb2A3C8xCboyqleH2iHc9scwOvtx9DPeok,7582
-aient/core/request.py,sha256=U0SDf_dOE5EEhzyJA14XMQCiCFUtcivYcISveLBsK64,61405
-aient/core/response.py,sha256=6fo3GKvTKio8nf4cZdizDIYYq7SnBnFeQ6ROvdAIW9k,30959
-aient/core/utils.py,sha256=W-PDhwoJIbmlt4xfyrV9GXHu9TwRk4pivBOcDVXjgsc,26163
+aient/core/request.py,sha256=RChzDuH49gaJE-o5g65h3nCh-OsuHPwLkq8yuyYEcbo,61431
+aient/core/response.py,sha256=8bS1nAoP6QOMDeDvJvZDVAt34kZ1DpWBI3PUGyza0ZU,31447
+aient/core/utils.py,sha256=CAFqWzICaKVysH9GLHBcp-VeOShisLjWGhEsh6-beWo,26365
 aient/core/test/test_base_api.py,sha256=pWnycRJbuPSXKKU9AQjWrMAX1wiLC_014Qc9hh5C2Pw,524
+aient/core/test/test_geminimask.py,sha256=HFX8jDbNg_FjjgPNxfYaR-0-roUrOO-ND-FVsuxSoiw,13254
 aient/core/test/test_image.py,sha256=_T4peNGdXKBHHxyQNx12u-NTyFE8TlYI6NvvagsG2LE,319
 aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhFkw,2755
 aient/models/__init__.py,sha256=ouNDNvoBBpIFrLsk09Q_sq23HR0GbLAKfGLIFmfEuXE,219
@@ -22,7 +23,7 @@ aient/models/vertex.py,sha256=qVD5l1Q538xXUPulxG4nmDjXE1VoV4yuAkTCpIeJVw0,16795
 aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
 aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
 aient/plugins/config.py,sha256=KnZ5xtb5o41FI2_qvxTEQhssdd3WJc7lIAFNR85INQw,7817
-aient/plugins/excute_command.py,sha256=9bCKFSQCRO2OYEYT-C-Vqmme0TRO7tc7bFjOxlT03Fk,1784
+aient/plugins/excute_command.py,sha256=u-JOZ21dDcDx1j3O0KVIHAsa6MNuOxHFBdV3iCnTih0,5413
 aient/plugins/get_time.py,sha256=Ih5XIW5SDAIhrZ9W4Qe5Hs1k4ieKPUc_LAd6ySNyqZk,654
 aient/plugins/image.py,sha256=ZElCIaZznE06TN9xW3DrSukS7U3A5_cjk1Jge4NzPxw,2072
 aient/plugins/list_directory.py,sha256=5ubm-mfrj-tanGSDp4M_Tmb6vQb3dx2-XVfQ2yL2G8A,1394
@@ -32,12 +33,12 @@ aient/plugins/run_python.py,sha256=dgcUwBunMuDkaSKR5bToudVzSdrXVewktDDFUz_iIOQ,4
 aient/plugins/websearch.py,sha256=yiBzqXK5X220ibR-zko3VDsn4QOnLu1k6E2YOygCeTQ,15185
 aient/plugins/write_file.py,sha256=qmT6iQ3mDyVAa9Sld1jfJq0KPZj0w2kRIHq0JyjpGeA,1853
 aient/prompt/__init__.py,sha256=GBtn6-JDT8KHFCcuPpfSNE_aGddg5p4FEyMCy4BfwGs,20
-aient/prompt/agent.py,sha256=WLs0KiNZs29FJ5w7N3kQZYLVEeS7K8vxIOUw06LcXVE,23825
+aient/prompt/agent.py,sha256=3VycHGnUq9OdR5pd_RM0AeLESlpAgBcmzrsesfq82X0,23856
 aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
 aient/utils/scripts.py,sha256=PPwaJEigPkpciJHUXOag483iq1GjvaLReHDHjkinv6c,26780
-aient-1.0.91.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
-aient-1.0.91.dist-info/METADATA,sha256=6QuaVz9_62x2b2ynwHT1dQULl-FJCTBMdBYOOh-3MZw,5000
-aient-1.0.91.dist-info/WHEEL,sha256=ooBFpIzZCPdw3uqIQsOo4qqbA4ZRPxHnOH7peeONza0,91
-aient-1.0.91.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
-aient-1.0.91.dist-info/RECORD,,
+aient-1.0.93.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
+aient-1.0.93.dist-info/METADATA,sha256=gY5o1t1r59AE53NnRXuDVGE7hNlSgi9vlWIMfgJceV0,5000
+aient-1.0.93.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+aient-1.0.93.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
+aient-1.0.93.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.0.1)
+Generator: setuptools (80.3.1)
 Root-Is-Purelib: true
 Tag: py3-none-any
 