aient 1.0.46__py3-none-any.whl → 1.0.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/models/chatgpt.py CHANGED
@@ -12,7 +12,7 @@ from pathlib import Path
12
12
 
13
13
  from .base import BaseLLM
14
14
  from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
15
- from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json
15
+ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json, convert_functions_to_xml
16
16
  from ..core.request import prepare_request_payload
17
17
  from ..core.response import fetch_response_stream
18
18
 
@@ -148,7 +148,7 @@ class chatgpt(BaseLLM):
148
148
  })
149
149
  self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
150
150
  else:
151
- self.conversation[convo_id].append({"role": "assistant", "content": "I will use tool: " + function_arguments + ". I will get the tool call result in the next user response."})
151
+ self.conversation[convo_id].append({"role": "assistant", "content": convert_functions_to_xml(function_arguments)})
152
152
  self.conversation[convo_id].append({"role": "user", "content": message})
153
153
 
154
154
  else:
@@ -418,6 +418,7 @@ class chatgpt(BaseLLM):
418
418
  # 处理函数调用
419
419
  if need_function_call:
420
420
  if self.print_log:
421
+ print("function_parameter", function_parameter)
421
422
  print("function_full_response", function_full_response)
422
423
 
423
424
  function_response = ""
aient/utils/scripts.py CHANGED
@@ -460,6 +460,7 @@ class XmlMatcher(Generic[R]):
460
460
  def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
461
461
  """
462
462
  解析XML格式的函数调用信息,转换为字典数组格式
463
+ 只解析倒数两层XML标签,忽略更高层级的XML标签
463
464
 
464
465
  参数:
465
466
  xml_content: 包含一个或多个函数调用的XML字符串
@@ -469,6 +470,7 @@ def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
469
470
  """
470
471
  result_functions = []
471
472
 
473
+ # 第一步:识别XML中的顶层标签(可能是函数调用)
472
474
  position = 0
473
475
  while position < len(xml_content):
474
476
  # 寻找下一个开始标签
@@ -482,22 +484,23 @@ def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
482
484
  position = tag_start + 1
483
485
  continue
484
486
 
487
+ # 找到标签的结束位置
485
488
  tag_end = xml_content.find(">", tag_start)
486
489
  if tag_end == -1:
487
490
  break # 标签未正确关闭
488
491
 
489
- # 提取标签名(函数名)
492
+ # 提取标签名
490
493
  tag_content = xml_content[tag_start+1:tag_end].strip()
491
494
  # 处理可能有属性的情况
492
- function_name = tag_content.split()[0] if " " in tag_content else tag_content
495
+ tag_name = tag_content.split()[0] if " " in tag_content else tag_content
493
496
 
494
- if not function_name:
497
+ if not tag_name:
495
498
  position = tag_end + 1
496
499
  continue # 空标签名,跳过
497
500
 
498
- # 查找整个函数调用的起止范围
499
- full_start_tag = f"<{function_name}"
500
- full_end_tag = f"</{function_name}>"
501
+ # 查找整个标签的起止范围
502
+ full_start_tag = f"<{tag_name}"
503
+ full_end_tag = f"</{tag_name}>"
501
504
 
502
505
  # 从当前位置找到开始标签
503
506
  start_pos = xml_content.find(full_start_tag, position)
@@ -512,78 +515,67 @@ def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
512
515
  position = tag_end + 1
513
516
  continue
514
517
 
515
- # 计算整个函数标签内容,包括开始和结束标签
516
- end_pos_complete = end_pos + len(full_end_tag)
517
- full_tag_content = xml_content[start_pos:end_pos_complete]
518
-
519
- # 使用XmlMatcher提取该函数标签内的内容
520
- content_matcher = XmlMatcher[XmlMatcherResult](function_name)
521
- match_results = content_matcher.final(full_tag_content)
522
-
523
- function_content = ""
524
- for result in match_results:
525
- if result.matched:
526
- function_content = result.data
527
- break
528
-
529
- # 解析参数
530
- parameters = {}
531
- if function_content:
532
- lines = function_content.strip().split('\n')
533
- current_param = None
534
- current_value = []
535
-
536
- for line in lines:
537
- line = line.strip()
538
- if line.startswith('<') and '>' in line and not line.startswith('</'):
539
- # 新参数开始
540
- if current_param and current_value:
541
- # 保存之前的参数
542
- parameters[current_param] = '\n'.join(current_value).strip()
543
- current_value = []
544
-
545
- # 提取参数名
546
- param_start = line.find('<') + 1
547
- param_end = line.find('>', param_start)
548
- if param_end != -1:
549
- param = line[param_start:param_end]
550
- # 检查是否是闭合标签
551
- if not param.startswith('/'):
552
- current_param = param
553
- # 检查是否在同一行有值
554
- rest = line[param_end+1:]
555
- if rest and not rest.startswith('</'):
556
- current_value.append(rest)
557
- elif line.startswith('</') and '>' in line:
558
- # 参数结束
559
- if current_param and current_value:
560
- param_end_tag = f"</{current_param}>"
561
- if line.strip() == param_end_tag:
562
- parameters[current_param] = '\n'.join(current_value).strip()
563
- current_param = None
564
- current_value = []
565
- elif current_param:
566
- # 继续收集当前参数的值
567
- current_value.append(line)
568
-
569
- # 处理最后一个参数
570
- if current_param and current_value:
571
- parameters[current_param] = '\n'.join(current_value).strip()
572
-
573
- # 清理参数值中可能的结束标签
574
- for param, value in parameters.items():
575
- end_tag = f'</{param}>'
576
- if value.endswith(end_tag):
577
- parameters[param] = value[:-len(end_tag)].strip()
578
-
579
- # 将解析的函数添加到结果数组
580
- result_functions.append({
581
- 'function_name': function_name,
582
- 'parameter': parameters
583
- })
584
-
585
- # 更新位置到当前标签之后,继续查找下一个函数
586
- position = end_pos_complete
518
+ # 标签的内容(不包括开始和结束标签)
519
+ tag_inner_content = xml_content[tag_end+1:end_pos]
520
+
521
+ # 如果是普通辅助标签(如tool_call),则在其内部寻找函数调用
522
+ if tag_name in ["tool_call", "function_call", "tool", "function"]:
523
+ # 递归处理内部内容
524
+ nested_functions = parse_function_xml(tag_inner_content)
525
+ result_functions.extend(nested_functions)
526
+ else:
527
+ # 将当前标签作为函数名,解析其内部标签作为参数
528
+ parameters = {}
529
+
530
+ # 解析内部标签作为参数
531
+ param_position = 0
532
+ while param_position < len(tag_inner_content):
533
+ param_tag_start = tag_inner_content.find("<", param_position)
534
+ if param_tag_start == -1:
535
+ break
536
+
537
+ # 跳过闭合标签
538
+ if param_tag_start + 1 < len(tag_inner_content) and tag_inner_content[param_tag_start + 1] == '/':
539
+ param_position = param_tag_start + 1
540
+ continue
541
+
542
+ param_tag_end = tag_inner_content.find(">", param_tag_start)
543
+ if param_tag_end == -1:
544
+ break
545
+
546
+ # 提取参数名
547
+ param_name = tag_inner_content[param_tag_start+1:param_tag_end].strip()
548
+ if " " in param_name: # 处理有属性的情况
549
+ param_name = param_name.split()[0]
550
+
551
+ if not param_name:
552
+ param_position = param_tag_end + 1
553
+ continue
554
+
555
+ # 查找参数标签的结束位置
556
+ param_end_tag = f"</{param_name}>"
557
+ param_end_pos = tag_inner_content.find(param_end_tag, param_tag_end)
558
+
559
+ if param_end_pos == -1:
560
+ # 参数标签未闭合
561
+ param_position = param_tag_end + 1
562
+ continue
563
+
564
+ # 提取参数值
565
+ param_value = tag_inner_content[param_tag_end+1:param_end_pos].strip()
566
+ parameters[param_name] = param_value
567
+
568
+ # 更新位置到当前参数标签之后
569
+ param_position = param_end_pos + len(param_end_tag)
570
+
571
+ # 添加解析结果
572
+ result_functions.append({
573
+ 'function_name': tag_name,
574
+ 'parameter': parameters
575
+ })
576
+
577
+ # 更新位置到当前标签之后
578
+ position = end_pos + len(full_end_tag)
587
579
 
588
580
  return result_functions
589
581
 
@@ -657,5 +649,71 @@ def parse_continuous_json(json_str: str, function_name: str = "") -> List[Dict[s
657
649
 
658
650
  return result
659
651
 
652
+ def convert_functions_to_xml(functions_list):
653
+ """
654
+ 将函数调用列表转换为XML格式的字符串
655
+
656
+ 参数:
657
+ functions_list: 函数调用列表,每个元素是包含function_name和parameter的字典
658
+
659
+ 返回:
660
+ XML格式的字符串
661
+ """
662
+ xml_result = ""
663
+
664
+ if isinstance(functions_list, str):
665
+ try:
666
+ # 提取并解析JSON字符串
667
+ functions_list = json.loads(functions_list)
668
+ # 确保解析结果是列表
669
+ if not isinstance(functions_list, list):
670
+ print(f"提取的工具调用不是列表格式: {functions_list}")
671
+ except json.JSONDecodeError as e:
672
+ print(f"从文本中提取的工具调用JSON解析失败: {e}")
673
+
674
+ for func in functions_list:
675
+ # 获取函数名和参数
676
+ function_name = func.get('function_name', '')
677
+ parameters = func.get('parameter', {})
678
+
679
+ # 开始函数标签
680
+ xml_result += f"<{function_name}>\n"
681
+
682
+ # 添加所有参数
683
+ for param_name, param_value in parameters.items():
684
+ xml_result += f"<{param_name}>{param_value}</{param_name}>\n"
685
+
686
+ # 结束函数标签
687
+ xml_result += f"</{function_name}>\n"
688
+
689
+ return xml_result
690
+
660
691
  if __name__ == "__main__":
661
- os.system("clear")
692
+
693
+ # 运行本文件:python -m aient.utils.scripts
694
+ os.system("clear")
695
+ test_xml = """
696
+ ✅ 好的,我现在读取 `README.md` 文件。
697
+ <tool_call>
698
+ <read_file>
699
+ <file_path>/Users/yanyuming/Downloads/GitHub/llama3_interpretability_sae/README.md</file_path>
700
+ </read_file>
701
+ </tool_call>好的,我现在读取 `README.md` 文件。
702
+ """
703
+ test_xml = """
704
+ ✅ 好的,我现在读取 `README.md` 文件。
705
+ <read_file>
706
+ <file_path>README.md</file_path>
707
+ </read_file>
708
+ <read_file>
709
+ <file_path>README.md</file_path>
710
+ </read_file>
711
+
712
+ <tool_call>
713
+ <read_file>
714
+ <file_path>README.md</file_path>
715
+ </read_file>
716
+ </tool_call>
717
+ 好的,我现在读取 `README.md` 文件。
718
+ """
719
+ print(parse_function_xml(test_xml))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aient
3
- Version: 1.0.46
3
+ Version: 1.0.47
4
4
  Summary: Aient: The Awakening of Agent.
5
5
  Description-Content-Type: text/markdown
6
6
  License-File: LICENSE
@@ -12,7 +12,7 @@ aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhF
12
12
  aient/models/__init__.py,sha256=ouNDNvoBBpIFrLsk09Q_sq23HR0GbLAKfGLIFmfEuXE,219
13
13
  aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
14
14
  aient/models/base.py,sha256=Loyt2F2WrDMBbK-sdmTtgkLVtdUXxK5tg4qoI6nc0Xo,7527
15
- aient/models/chatgpt.py,sha256=QGMx2szrYlK-uqe18Vbem3ou37nrQFhS7vonpLxHrUo,42173
15
+ aient/models/chatgpt.py,sha256=r56r0Q7sTSU1WQcrFDfVfb3LOtHuHLJpoqhy-pAE1TA,42202
16
16
  aient/models/claude.py,sha256=thK9P8qkaaoUN3OOJ9Shw4KDs-pAGKPoX4FOPGFXva8,28597
17
17
  aient/models/duckduckgo.py,sha256=1l7vYCs9SG5SWPCbcl7q6pCcB5AUF_r-a4l9frz3Ogo,8115
18
18
  aient/models/gemini.py,sha256=chGLc-8G_DAOxr10HPoOhvVFW1RvMgHd6mt--VyAW98,14730
@@ -28,9 +28,9 @@ aient/plugins/today.py,sha256=btnXJNqWorJDKPvH9PBTdHaExpVI1YPuSAeRrq-fg9A,667
28
28
  aient/plugins/websearch.py,sha256=yiBzqXK5X220ibR-zko3VDsn4QOnLu1k6E2YOygCeTQ,15185
29
29
  aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
31
- aient/utils/scripts.py,sha256=obrf5oxzFQPCu1A5MYDDiZv_LM6l9C1QSkgWIqcu28k,25690
32
- aient-1.0.46.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
33
- aient-1.0.46.dist-info/METADATA,sha256=nYfiefitlFshZCNddR3PTfypDm1mrCtJhjboAJmoNOQ,4986
34
- aient-1.0.46.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
35
- aient-1.0.46.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
36
- aient-1.0.46.dist-info/RECORD,,
31
+ aient/utils/scripts.py,sha256=XCXMRdpWRJb34Znk4t9JkFnvzDzGHVA5Vv5WpUgP2_0,27152
32
+ aient-1.0.47.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
33
+ aient-1.0.47.dist-info/METADATA,sha256=2qqnAF8-z1wWVwel2w3v7OL31j69wpY_fr224GFnJwQ,4986
34
+ aient-1.0.47.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
35
+ aient-1.0.47.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
36
+ aient-1.0.47.dist-info/RECORD,,
File without changes