aient 1.0.45__py3-none-any.whl → 1.0.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/core/request.py CHANGED
@@ -96,7 +96,8 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
96
96
  content[0]["text"] = re.sub(r"_+", "_", content[0]["text"])
97
97
  systemInstruction = {"parts": content}
98
98
 
99
- if "gemini-2.0-flash-exp" in original_model or "gemini-1.5" in original_model:
99
+ off_models = ["gemini-2.0-flash-exp", "gemini-1.5", "gemini-2.5-pro"]
100
+ if any(off_model in original_model for off_model in off_models):
100
101
  safety_settings = "OFF"
101
102
  else:
102
103
  safety_settings = "BLOCK_NONE"
@@ -119,7 +120,11 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
119
120
  {
120
121
  "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
121
122
  "threshold": safety_settings
122
- }
123
+ },
124
+ {
125
+ "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
126
+ "threshold": "BLOCK_NONE"
127
+ },
123
128
  ]
124
129
  }
125
130
 
aient/core/utils.py CHANGED
@@ -155,7 +155,8 @@ def update_initial_model(provider):
155
155
  proxy = safe_get(provider, "preferences", "proxy", default=None)
156
156
  client_config = get_proxy(proxy)
157
157
  if engine == "gemini":
158
- url = "https://generativelanguage.googleapis.com/v1beta/models"
158
+ before_v1 = api_url.split("/v1beta")[0]
159
+ url = before_v1 + "/v1beta/models"
159
160
  params = {"key": api}
160
161
  with httpx.Client(**client_config) as client:
161
162
  response = client.get(url, params=params)
@@ -288,7 +289,7 @@ class ThreadSafeCircularList:
288
289
  # self.requests[item] = []
289
290
  logger.warning(f"API key {item} 已进入冷却状态,冷却时间 {cooling_time} 秒")
290
291
 
291
- async def is_rate_limited(self, item, model: str = None) -> bool:
292
+ async def is_rate_limited(self, item, model: str = None, is_check: bool = False) -> bool:
292
293
  now = time()
293
294
  # 检查是否在冷却中
294
295
  if now < self.cooling_until[item]:
@@ -321,7 +322,8 @@ class ThreadSafeCircularList:
321
322
  # 使用特定模型的请求记录进行计算
322
323
  recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
323
324
  if recent_requests >= limit_count:
324
- logger.warning(f"API key {item} 对模型 {model_key} 已达到速率限制 ({limit_count}/{limit_period}秒)")
325
+ if not is_check:
326
+ logger.warning(f"API key {item}: model: {model_key} has been rate limited ({limit_count}/{limit_period} seconds)")
325
327
  return True
326
328
 
327
329
  # 清理太旧的请求记录
@@ -329,7 +331,9 @@ class ThreadSafeCircularList:
329
331
  self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]
330
332
 
331
333
  # 记录新的请求
332
- self.requests[item][model_key].append(now)
334
+ if not is_check:
335
+ self.requests[item][model_key].append(now)
336
+
333
337
  return False
334
338
 
335
339
  async def next(self, model: str = None):
@@ -349,6 +353,30 @@ class ThreadSafeCircularList:
349
353
  logger.warning(f"All API keys are rate limited!")
350
354
  raise HTTPException(status_code=429, detail="Too many requests")
351
355
 
356
+ async def is_all_rate_limited(self, model: str = None) -> bool:
357
+ """检查是否所有的items都被速率限制
358
+
359
+ 与next方法不同,此方法不会改变任何内部状态(如self.index),
360
+ 仅返回一个布尔值表示是否所有的key都被限制。
361
+
362
+ Args:
363
+ model: 要检查的模型名称,默认为None
364
+
365
+ Returns:
366
+ bool: 如果所有items都被速率限制返回True,否则返回False
367
+ """
368
+ if len(self.items) == 0:
369
+ return False
370
+
371
+ async with self.lock:
372
+ for item in self.items:
373
+ if not await self.is_rate_limited(item, model, is_check=True):
374
+ return False
375
+
376
+ # 如果遍历完所有items都被限制,返回True
377
+ # logger.debug(f"Check result: all items are rate limited!")
378
+ return True
379
+
352
380
  async def after_next_current(self):
353
381
  # 返回当前取出的 API,因为已经调用了 next,所以当前API应该是上一个
354
382
  if len(self.items) == 0:
aient/models/chatgpt.py CHANGED
@@ -12,7 +12,7 @@ from pathlib import Path
12
12
 
13
13
  from .base import BaseLLM
14
14
  from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
15
- from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json
15
+ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json, convert_functions_to_xml
16
16
  from ..core.request import prepare_request_payload
17
17
  from ..core.response import fetch_response_stream
18
18
 
@@ -148,7 +148,7 @@ class chatgpt(BaseLLM):
148
148
  })
149
149
  self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
150
150
  else:
151
- self.conversation[convo_id].append({"role": "assistant", "content": "I will use tool: " + function_arguments + ". I will get the tool call result in the next user response."})
151
+ self.conversation[convo_id].append({"role": "assistant", "content": convert_functions_to_xml(function_arguments)})
152
152
  self.conversation[convo_id].append({"role": "user", "content": message})
153
153
 
154
154
  else:
@@ -159,7 +159,9 @@ class chatgpt(BaseLLM):
159
159
 
160
160
  conversation_len = len(self.conversation[convo_id]) - 1
161
161
  message_index = 0
162
- # print(json.dumps(self.conversation[convo_id], indent=4, ensure_ascii=False))
162
+ # if self.print_log:
163
+ # replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(self.conversation[convo_id])))
164
+ # print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
163
165
  while message_index < conversation_len:
164
166
  if self.conversation[convo_id][message_index]["role"] == self.conversation[convo_id][message_index + 1]["role"]:
165
167
  if self.conversation[convo_id][message_index].get("content") and self.conversation[convo_id][message_index + 1].get("content"):
@@ -180,6 +182,9 @@ class chatgpt(BaseLLM):
180
182
  and type(self.conversation[convo_id][message_index + 1]["content"]) == dict:
181
183
  self.conversation[convo_id][message_index]["content"] = [self.conversation[convo_id][message_index]["content"]]
182
184
  self.conversation[convo_id][message_index + 1]["content"] = [self.conversation[convo_id][message_index + 1]["content"]]
185
+ if type(self.conversation[convo_id][message_index]["content"]) == list \
186
+ and type(self.conversation[convo_id][message_index + 1]["content"]) == dict:
187
+ self.conversation[convo_id][message_index + 1]["content"] = [self.conversation[convo_id][message_index + 1]["content"]]
183
188
  self.conversation[convo_id][message_index]["content"] += self.conversation[convo_id][message_index + 1]["content"]
184
189
  self.conversation[convo_id].pop(message_index + 1)
185
190
  conversation_len = conversation_len - 1
@@ -413,6 +418,7 @@ class chatgpt(BaseLLM):
413
418
  # 处理函数调用
414
419
  if need_function_call:
415
420
  if self.print_log:
421
+ print("function_parameter", function_parameter)
416
422
  print("function_full_response", function_full_response)
417
423
 
418
424
  function_response = ""
aient/utils/scripts.py CHANGED
@@ -460,6 +460,7 @@ class XmlMatcher(Generic[R]):
460
460
  def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
461
461
  """
462
462
  解析XML格式的函数调用信息,转换为字典数组格式
463
+ 只解析倒数两层XML标签,忽略更高层级的XML标签
463
464
 
464
465
  参数:
465
466
  xml_content: 包含一个或多个函数调用的XML字符串
@@ -469,6 +470,7 @@ def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
469
470
  """
470
471
  result_functions = []
471
472
 
473
+ # 第一步:识别XML中的顶层标签(可能是函数调用)
472
474
  position = 0
473
475
  while position < len(xml_content):
474
476
  # 寻找下一个开始标签
@@ -482,22 +484,23 @@ def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
482
484
  position = tag_start + 1
483
485
  continue
484
486
 
487
+ # 找到标签的结束位置
485
488
  tag_end = xml_content.find(">", tag_start)
486
489
  if tag_end == -1:
487
490
  break # 标签未正确关闭
488
491
 
489
- # 提取标签名(函数名)
492
+ # 提取标签名
490
493
  tag_content = xml_content[tag_start+1:tag_end].strip()
491
494
  # 处理可能有属性的情况
492
- function_name = tag_content.split()[0] if " " in tag_content else tag_content
495
+ tag_name = tag_content.split()[0] if " " in tag_content else tag_content
493
496
 
494
- if not function_name:
497
+ if not tag_name:
495
498
  position = tag_end + 1
496
499
  continue # 空标签名,跳过
497
500
 
498
- # 查找整个函数调用的起止范围
499
- full_start_tag = f"<{function_name}"
500
- full_end_tag = f"</{function_name}>"
501
+ # 查找整个标签的起止范围
502
+ full_start_tag = f"<{tag_name}"
503
+ full_end_tag = f"</{tag_name}>"
501
504
 
502
505
  # 从当前位置找到开始标签
503
506
  start_pos = xml_content.find(full_start_tag, position)
@@ -512,78 +515,67 @@ def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
512
515
  position = tag_end + 1
513
516
  continue
514
517
 
515
- # 计算整个函数标签内容,包括开始和结束标签
516
- end_pos_complete = end_pos + len(full_end_tag)
517
- full_tag_content = xml_content[start_pos:end_pos_complete]
518
-
519
- # 使用XmlMatcher提取该函数标签内的内容
520
- content_matcher = XmlMatcher[XmlMatcherResult](function_name)
521
- match_results = content_matcher.final(full_tag_content)
522
-
523
- function_content = ""
524
- for result in match_results:
525
- if result.matched:
526
- function_content = result.data
527
- break
528
-
529
- # 解析参数
530
- parameters = {}
531
- if function_content:
532
- lines = function_content.strip().split('\n')
533
- current_param = None
534
- current_value = []
535
-
536
- for line in lines:
537
- line = line.strip()
538
- if line.startswith('<') and '>' in line and not line.startswith('</'):
539
- # 新参数开始
540
- if current_param and current_value:
541
- # 保存之前的参数
542
- parameters[current_param] = '\n'.join(current_value).strip()
543
- current_value = []
544
-
545
- # 提取参数名
546
- param_start = line.find('<') + 1
547
- param_end = line.find('>', param_start)
548
- if param_end != -1:
549
- param = line[param_start:param_end]
550
- # 检查是否是闭合标签
551
- if not param.startswith('/'):
552
- current_param = param
553
- # 检查是否在同一行有值
554
- rest = line[param_end+1:]
555
- if rest and not rest.startswith('</'):
556
- current_value.append(rest)
557
- elif line.startswith('</') and '>' in line:
558
- # 参数结束
559
- if current_param and current_value:
560
- param_end_tag = f"</{current_param}>"
561
- if line.strip() == param_end_tag:
562
- parameters[current_param] = '\n'.join(current_value).strip()
563
- current_param = None
564
- current_value = []
565
- elif current_param:
566
- # 继续收集当前参数的值
567
- current_value.append(line)
568
-
569
- # 处理最后一个参数
570
- if current_param and current_value:
571
- parameters[current_param] = '\n'.join(current_value).strip()
572
-
573
- # 清理参数值中可能的结束标签
574
- for param, value in parameters.items():
575
- end_tag = f'</{param}>'
576
- if value.endswith(end_tag):
577
- parameters[param] = value[:-len(end_tag)].strip()
578
-
579
- # 将解析的函数添加到结果数组
580
- result_functions.append({
581
- 'function_name': function_name,
582
- 'parameter': parameters
583
- })
584
-
585
- # 更新位置到当前标签之后,继续查找下一个函数
586
- position = end_pos_complete
518
+ # 标签的内容(不包括开始和结束标签)
519
+ tag_inner_content = xml_content[tag_end+1:end_pos]
520
+
521
+ # 如果是普通辅助标签(如tool_call),则在其内部寻找函数调用
522
+ if tag_name in ["tool_call", "function_call", "tool", "function"]:
523
+ # 递归处理内部内容
524
+ nested_functions = parse_function_xml(tag_inner_content)
525
+ result_functions.extend(nested_functions)
526
+ else:
527
+ # 将当前标签作为函数名,解析其内部标签作为参数
528
+ parameters = {}
529
+
530
+ # 解析内部标签作为参数
531
+ param_position = 0
532
+ while param_position < len(tag_inner_content):
533
+ param_tag_start = tag_inner_content.find("<", param_position)
534
+ if param_tag_start == -1:
535
+ break
536
+
537
+ # 跳过闭合标签
538
+ if param_tag_start + 1 < len(tag_inner_content) and tag_inner_content[param_tag_start + 1] == '/':
539
+ param_position = param_tag_start + 1
540
+ continue
541
+
542
+ param_tag_end = tag_inner_content.find(">", param_tag_start)
543
+ if param_tag_end == -1:
544
+ break
545
+
546
+ # 提取参数名
547
+ param_name = tag_inner_content[param_tag_start+1:param_tag_end].strip()
548
+ if " " in param_name: # 处理有属性的情况
549
+ param_name = param_name.split()[0]
550
+
551
+ if not param_name:
552
+ param_position = param_tag_end + 1
553
+ continue
554
+
555
+ # 查找参数标签的结束位置
556
+ param_end_tag = f"</{param_name}>"
557
+ param_end_pos = tag_inner_content.find(param_end_tag, param_tag_end)
558
+
559
+ if param_end_pos == -1:
560
+ # 参数标签未闭合
561
+ param_position = param_tag_end + 1
562
+ continue
563
+
564
+ # 提取参数值
565
+ param_value = tag_inner_content[param_tag_end+1:param_end_pos].strip()
566
+ parameters[param_name] = param_value
567
+
568
+ # 更新位置到当前参数标签之后
569
+ param_position = param_end_pos + len(param_end_tag)
570
+
571
+ # 添加解析结果
572
+ result_functions.append({
573
+ 'function_name': tag_name,
574
+ 'parameter': parameters
575
+ })
576
+
577
+ # 更新位置到当前标签之后
578
+ position = end_pos + len(full_end_tag)
587
579
 
588
580
  return result_functions
589
581
 
@@ -657,5 +649,71 @@ def parse_continuous_json(json_str: str, function_name: str = "") -> List[Dict[s
657
649
 
658
650
  return result
659
651
 
652
+ def convert_functions_to_xml(functions_list):
653
+ """
654
+ 将函数调用列表转换为XML格式的字符串
655
+
656
+ 参数:
657
+ functions_list: 函数调用列表,每个元素是包含function_name和parameter的字典
658
+
659
+ 返回:
660
+ XML格式的字符串
661
+ """
662
+ xml_result = ""
663
+
664
+ if isinstance(functions_list, str):
665
+ try:
666
+ # 提取并解析JSON字符串
667
+ functions_list = json.loads(functions_list)
668
+ # 确保解析结果是列表
669
+ if not isinstance(functions_list, list):
670
+ print(f"提取的工具调用不是列表格式: {functions_list}")
671
+ except json.JSONDecodeError as e:
672
+ print(f"从文本中提取的工具调用JSON解析失败: {e}")
673
+
674
+ for func in functions_list:
675
+ # 获取函数名和参数
676
+ function_name = func.get('function_name', '')
677
+ parameters = func.get('parameter', {})
678
+
679
+ # 开始函数标签
680
+ xml_result += f"<{function_name}>\n"
681
+
682
+ # 添加所有参数
683
+ for param_name, param_value in parameters.items():
684
+ xml_result += f"<{param_name}>{param_value}</{param_name}>\n"
685
+
686
+ # 结束函数标签
687
+ xml_result += f"</{function_name}>\n"
688
+
689
+ return xml_result
690
+
660
691
  if __name__ == "__main__":
661
- os.system("clear")
692
+
693
+ # 运行本文件:python -m aient.utils.scripts
694
+ os.system("clear")
695
+ test_xml = """
696
+ ✅ 好的,我现在读取 `README.md` 文件。
697
+ <tool_call>
698
+ <read_file>
699
+ <file_path>/Users/yanyuming/Downloads/GitHub/llama3_interpretability_sae/README.md</file_path>
700
+ </read_file>
701
+ </tool_call>好的,我现在读取 `README.md` 文件。
702
+ """
703
+ test_xml = """
704
+ ✅ 好的,我现在读取 `README.md` 文件。
705
+ <read_file>
706
+ <file_path>README.md</file_path>
707
+ </read_file>
708
+ <read_file>
709
+ <file_path>README.md</file_path>
710
+ </read_file>
711
+
712
+ <tool_call>
713
+ <read_file>
714
+ <file_path>README.md</file_path>
715
+ </read_file>
716
+ </tool_call>
717
+ 好的,我现在读取 `README.md` 文件。
718
+ """
719
+ print(parse_function_xml(test_xml))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aient
3
- Version: 1.0.45
3
+ Version: 1.0.47
4
4
  Summary: Aient: The Awakening of Agent.
5
5
  Description-Content-Type: text/markdown
6
6
  License-File: LICENSE
@@ -3,16 +3,16 @@ aient/core/.git,sha256=lrAcW1SxzRBUcUiuKL5tS9ykDmmTXxyLP3YYU-Y-Q-I,45
3
3
  aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
4
4
  aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
5
5
  aient/core/models.py,sha256=H3_XuWA7aS25MWZPK1c-5RBiiuxWJbTfE3RAk0Pkc9A,7504
6
- aient/core/request.py,sha256=6c9drOddcvfeuLoUmDUWxP0gekW-ov839wiYETsNiZ0,48895
6
+ aient/core/request.py,sha256=OlMkjGMcFAH-ItA1PgPuf2HT-RbI-Ca4JXncWApc3gM,49088
7
7
  aient/core/response.py,sha256=7RVSFfGHisejv2SlsHvp0t-N_8OpTS4edQU_NOi5BGU,25822
8
- aient/core/utils.py,sha256=i9ZwyywBLIhRM0fNmFSD3jF3dBL5QqVMOtSlG_ddv-I,24101
8
+ aient/core/utils.py,sha256=I0u3WLWaMd4j1ShqKg_tz67m-1wr_uXlWgxGeUjIIiE,25098
9
9
  aient/core/test/test_base_api.py,sha256=CjfFzMG26r8C4xCPoVkKb3Ac6pp9gy5NUCbZJHoSSsM,393
10
10
  aient/core/test/test_image.py,sha256=_T4peNGdXKBHHxyQNx12u-NTyFE8TlYI6NvvagsG2LE,319
11
11
  aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhFkw,2755
12
12
  aient/models/__init__.py,sha256=ouNDNvoBBpIFrLsk09Q_sq23HR0GbLAKfGLIFmfEuXE,219
13
13
  aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
14
14
  aient/models/base.py,sha256=Loyt2F2WrDMBbK-sdmTtgkLVtdUXxK5tg4qoI6nc0Xo,7527
15
- aient/models/chatgpt.py,sha256=rF95RmO4C3h4PKRqE3Qk6fKoR0yIf-3zp8t7KBF_kjA,41685
15
+ aient/models/chatgpt.py,sha256=r56r0Q7sTSU1WQcrFDfVfb3LOtHuHLJpoqhy-pAE1TA,42202
16
16
  aient/models/claude.py,sha256=thK9P8qkaaoUN3OOJ9Shw4KDs-pAGKPoX4FOPGFXva8,28597
17
17
  aient/models/duckduckgo.py,sha256=1l7vYCs9SG5SWPCbcl7q6pCcB5AUF_r-a4l9frz3Ogo,8115
18
18
  aient/models/gemini.py,sha256=chGLc-8G_DAOxr10HPoOhvVFW1RvMgHd6mt--VyAW98,14730
@@ -28,9 +28,9 @@ aient/plugins/today.py,sha256=btnXJNqWorJDKPvH9PBTdHaExpVI1YPuSAeRrq-fg9A,667
28
28
  aient/plugins/websearch.py,sha256=yiBzqXK5X220ibR-zko3VDsn4QOnLu1k6E2YOygCeTQ,15185
29
29
  aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
31
- aient/utils/scripts.py,sha256=obrf5oxzFQPCu1A5MYDDiZv_LM6l9C1QSkgWIqcu28k,25690
32
- aient-1.0.45.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
33
- aient-1.0.45.dist-info/METADATA,sha256=Wt-dsD5uQjMdMfWGA052WOC7ITd8UQy84MiGPkABO8A,4986
34
- aient-1.0.45.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
35
- aient-1.0.45.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
36
- aient-1.0.45.dist-info/RECORD,,
31
+ aient/utils/scripts.py,sha256=XCXMRdpWRJb34Znk4t9JkFnvzDzGHVA5Vv5WpUgP2_0,27152
32
+ aient-1.0.47.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
33
+ aient-1.0.47.dist-info/METADATA,sha256=2qqnAF8-z1wWVwel2w3v7OL31j69wpY_fr224GFnJwQ,4986
34
+ aient-1.0.47.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
35
+ aient-1.0.47.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
36
+ aient-1.0.47.dist-info/RECORD,,
File without changes