autocoder-nano 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autocoder_nano/agent/agent_base.py +376 -63
- autocoder_nano/auto_coder_nano.py +147 -1842
- autocoder_nano/edit/__init__.py +20 -0
- autocoder_nano/edit/actions.py +136 -0
- autocoder_nano/edit/code/__init__.py +0 -0
- autocoder_nano/edit/code/generate_editblock.py +403 -0
- autocoder_nano/edit/code/merge_editblock.py +418 -0
- autocoder_nano/edit/code/modification_ranker.py +90 -0
- autocoder_nano/edit/text.py +38 -0
- autocoder_nano/index/__init__.py +0 -0
- autocoder_nano/index/entry.py +166 -0
- autocoder_nano/index/index_manager.py +410 -0
- autocoder_nano/index/symbols_utils.py +43 -0
- autocoder_nano/llm_types.py +12 -8
- autocoder_nano/version.py +1 -1
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/METADATA +1 -1
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/RECORD +21 -10
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/LICENSE +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/WHEEL +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/entry_points.txt +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/top_level.txt +0 -0
@@ -2,19 +2,18 @@ import argparse
|
|
2
2
|
import glob
|
3
3
|
import hashlib
|
4
4
|
import os
|
5
|
-
import re
|
6
5
|
import json
|
7
6
|
import shutil
|
8
7
|
import subprocess
|
9
|
-
import tempfile
|
10
8
|
import textwrap
|
11
9
|
import time
|
12
|
-
import traceback
|
13
10
|
import uuid
|
14
|
-
from difflib import SequenceMatcher
|
15
11
|
|
16
|
-
from autocoder_nano.
|
12
|
+
from autocoder_nano.edit import Dispacher
|
17
13
|
from autocoder_nano.helper import show_help
|
14
|
+
from autocoder_nano.index.entry import build_index_and_filter_files
|
15
|
+
from autocoder_nano.index.index_manager import IndexManager
|
16
|
+
from autocoder_nano.index.symbols_utils import extract_symbols
|
18
17
|
from autocoder_nano.llm_client import AutoLLM
|
19
18
|
from autocoder_nano.version import __version__
|
20
19
|
from autocoder_nano.llm_types import *
|
@@ -61,8 +60,10 @@ memory = {
|
|
61
60
|
"current_files": {"files": [], "groups": {}},
|
62
61
|
"conf": {
|
63
62
|
"auto_merge": "editblock",
|
64
|
-
"current_chat_model": "",
|
65
|
-
"current_code_model": ""
|
63
|
+
# "current_chat_model": "",
|
64
|
+
# "current_code_model": "",
|
65
|
+
"chat_model": "",
|
66
|
+
"code_model": "",
|
66
67
|
},
|
67
68
|
"exclude_dirs": [],
|
68
69
|
"mode": "normal", # 新增mode字段,默认为normal模式
|
@@ -73,29 +74,6 @@ memory = {
|
|
73
74
|
args: AutoCoderArgs = AutoCoderArgs()
|
74
75
|
|
75
76
|
|
76
|
-
def extract_symbols(text: str) -> SymbolsInfo:
|
77
|
-
patterns = {
|
78
|
-
"usage": r"用途:(.+)",
|
79
|
-
"functions": r"函数:(.+)",
|
80
|
-
"variables": r"变量:(.+)",
|
81
|
-
"classes": r"类:(.+)",
|
82
|
-
"import_statements": r"导入语句:(.+)",
|
83
|
-
}
|
84
|
-
|
85
|
-
info = SymbolsInfo()
|
86
|
-
for field, pattern in patterns.items():
|
87
|
-
match = re.search(pattern, text)
|
88
|
-
if match:
|
89
|
-
value = match.group(1).strip()
|
90
|
-
if field == "import_statements":
|
91
|
-
value = [v.strip() for v in value.split("^^")]
|
92
|
-
elif field == "functions" or field == "variables" or field == "classes":
|
93
|
-
value = [v.strip() for v in value.split(",")]
|
94
|
-
setattr(info, field, value)
|
95
|
-
|
96
|
-
return info
|
97
|
-
|
98
|
-
|
99
77
|
def get_all_file_names_in_project() -> List[str]:
|
100
78
|
file_names = []
|
101
79
|
final_exclude_dirs = default_exclude_dirs + memory.get("exclude_dirs", [])
|
@@ -794,449 +772,8 @@ def load_memory():
|
|
794
772
|
completer.update_current_files(memory["current_files"]["files"])
|
795
773
|
|
796
774
|
|
797
|
-
def symbols_info_to_str(info: SymbolsInfo, symbol_types: List[SymbolType]) -> str:
|
798
|
-
result = []
|
799
|
-
for symbol_type in symbol_types:
|
800
|
-
value = getattr(info, symbol_type.value)
|
801
|
-
if value:
|
802
|
-
if symbol_type == SymbolType.IMPORT_STATEMENTS:
|
803
|
-
value_str = "^^".join(value)
|
804
|
-
elif symbol_type in [SymbolType.FUNCTIONS, SymbolType.VARIABLES, SymbolType.CLASSES,]:
|
805
|
-
value_str = ",".join(value)
|
806
|
-
else:
|
807
|
-
value_str = value
|
808
|
-
result.append(f"{symbol_type.value}:{value_str}")
|
809
|
-
|
810
|
-
return "\n".join(result)
|
811
|
-
|
812
|
-
|
813
|
-
class IndexManager:
|
814
|
-
def __init__(self, source_codes: List[SourceCode], llm: AutoLLM = None):
|
815
|
-
self.args = args
|
816
|
-
self.sources = source_codes
|
817
|
-
self.source_dir = args.source_dir
|
818
|
-
self.index_dir = os.path.join(self.source_dir, ".auto-coder")
|
819
|
-
self.index_file = os.path.join(self.index_dir, "index.json")
|
820
|
-
self.llm = llm
|
821
|
-
self.llm.setup_default_model_name(memory["conf"]["current_chat_model"])
|
822
|
-
self.max_input_length = args.model_max_input_length # 模型输入最大长度
|
823
|
-
# 使用 time.sleep(self.anti_quota_limit) 防止超过 API 频率限制
|
824
|
-
self.anti_quota_limit = args.anti_quota_limit
|
825
|
-
# 如果索引目录不存在,则创建它
|
826
|
-
if not os.path.exists(self.index_dir):
|
827
|
-
os.makedirs(self.index_dir)
|
828
|
-
|
829
|
-
def build_index(self):
|
830
|
-
""" 构建或更新索引,使用多线程处理多个文件,并将更新后的索引数据写入文件 """
|
831
|
-
if os.path.exists(self.index_file):
|
832
|
-
with open(self.index_file, "r") as file: # 读缓存
|
833
|
-
index_data = json.load(file)
|
834
|
-
else: # 首次 build index
|
835
|
-
logger.info("首次生成索引.")
|
836
|
-
index_data = {}
|
837
|
-
|
838
|
-
@prompt()
|
839
|
-
def error_message(source_dir: str, file_path: str):
|
840
|
-
"""
|
841
|
-
The source_dir is different from the path in index file (e.g. file_path:{{ file_path }} source_dir:{{
|
842
|
-
source_dir }}). You may need to replace the prefix with the source_dir in the index file or Just delete
|
843
|
-
the index file to rebuild it.
|
844
|
-
"""
|
845
|
-
|
846
|
-
for item in index_data.keys():
|
847
|
-
if not item.startswith(self.source_dir):
|
848
|
-
logger.warning(error_message(source_dir=self.source_dir, file_path=item))
|
849
|
-
break
|
850
|
-
|
851
|
-
updated_sources = []
|
852
|
-
wait_to_build_files = []
|
853
|
-
for source in self.sources:
|
854
|
-
source_code = source.source_code
|
855
|
-
md5 = hashlib.md5(source_code.encode("utf-8")).hexdigest()
|
856
|
-
if source.module_name not in index_data or index_data[source.module_name]["md5"] != md5:
|
857
|
-
wait_to_build_files.append(source)
|
858
|
-
counter = 0
|
859
|
-
num_files = len(wait_to_build_files)
|
860
|
-
total_files = len(self.sources)
|
861
|
-
logger.info(f"总文件数: {total_files}, 需要索引文件数: {num_files}")
|
862
|
-
|
863
|
-
for source in wait_to_build_files:
|
864
|
-
build_result = self.build_index_for_single_source(source)
|
865
|
-
if build_result is not None:
|
866
|
-
counter += 1
|
867
|
-
logger.info(f"正在构建索引:{counter}/{num_files}...")
|
868
|
-
module_name = build_result["module_name"]
|
869
|
-
index_data[module_name] = build_result
|
870
|
-
updated_sources.append(module_name)
|
871
|
-
if updated_sources:
|
872
|
-
with open(self.index_file, "w") as fp:
|
873
|
-
json_str = json.dumps(index_data, indent=2, ensure_ascii=False)
|
874
|
-
fp.write(json_str)
|
875
|
-
return index_data
|
876
|
-
|
877
|
-
def split_text_into_chunks(self, text):
|
878
|
-
""" 文本分块,将大文本分割成适合 LLM 处理的小块 """
|
879
|
-
lines = text.split("\n")
|
880
|
-
chunks = []
|
881
|
-
current_chunk = []
|
882
|
-
current_length = 0
|
883
|
-
for line in lines:
|
884
|
-
if current_length + len(line) + 1 <= self.max_input_length:
|
885
|
-
current_chunk.append(line)
|
886
|
-
current_length += len(line) + 1
|
887
|
-
else:
|
888
|
-
chunks.append("\n".join(current_chunk))
|
889
|
-
current_chunk = [line]
|
890
|
-
current_length = len(line) + 1
|
891
|
-
if current_chunk:
|
892
|
-
chunks.append("\n".join(current_chunk))
|
893
|
-
return chunks
|
894
|
-
|
895
|
-
@prompt()
|
896
|
-
def get_all_file_symbols(self, path: str, code: str) -> str:
|
897
|
-
"""
|
898
|
-
你的目标是从给定的代码中获取代码里的符号,需要获取的符号类型包括:
|
899
|
-
|
900
|
-
1. 函数
|
901
|
-
2. 类
|
902
|
-
3. 变量
|
903
|
-
4. 所有导入语句
|
904
|
-
|
905
|
-
如果没有任何符号,返回空字符串就行。
|
906
|
-
如果有符号,按如下格式返回:
|
907
|
-
|
908
|
-
```
|
909
|
-
{符号类型}: {符号名称}, {符号名称}, ...
|
910
|
-
```
|
911
|
-
|
912
|
-
注意:
|
913
|
-
1. 直接输出结果,不要尝试使用任何代码
|
914
|
-
2. 不要分析代码的内容和目的
|
915
|
-
3. 用途的长度不能超过100字符
|
916
|
-
4. 导入语句的分隔符为^^
|
917
|
-
|
918
|
-
下面是一段示例:
|
919
|
-
|
920
|
-
## 输入
|
921
|
-
下列是文件 /test.py 的源码:
|
922
|
-
|
923
|
-
import os
|
924
|
-
import time
|
925
|
-
from loguru import logger
|
926
|
-
import byzerllm
|
927
|
-
|
928
|
-
a = ""
|
929
|
-
|
930
|
-
@byzerllm.prompt(render="jinja")
|
931
|
-
def auto_implement_function_template(instruction:str, content:str)->str:
|
932
|
-
|
933
|
-
## 输出
|
934
|
-
用途:主要用于提供自动实现函数模板的功能。
|
935
|
-
函数:auto_implement_function_template
|
936
|
-
变量:a
|
937
|
-
类:
|
938
|
-
导入语句:import os^^import time^^from loguru import logger^^import byzerllm
|
939
|
-
|
940
|
-
现在,让我们开始一个新的任务:
|
941
|
-
|
942
|
-
## 输入
|
943
|
-
下列是文件 {{ path }} 的源码:
|
944
|
-
|
945
|
-
{{ code }}
|
946
|
-
|
947
|
-
## 输出
|
948
|
-
"""
|
949
|
-
|
950
|
-
def build_index_for_single_source(self, source: SourceCode):
|
951
|
-
""" 处理单个源文件,提取符号信息并存储元数据 """
|
952
|
-
file_path = source.module_name
|
953
|
-
if not os.path.exists(file_path): # 过滤不存在的文件
|
954
|
-
return None
|
955
|
-
|
956
|
-
ext = os.path.splitext(file_path)[1].lower()
|
957
|
-
if ext in [".md", ".html", ".txt", ".doc", ".pdf"]: # 过滤文档文件
|
958
|
-
return None
|
959
|
-
|
960
|
-
if source.source_code.strip() == "":
|
961
|
-
return None
|
962
|
-
|
963
|
-
md5 = hashlib.md5(source.source_code.encode("utf-8")).hexdigest()
|
964
|
-
|
965
|
-
try:
|
966
|
-
start_time = time.monotonic()
|
967
|
-
source_code = source.source_code
|
968
|
-
if len(source.source_code) > self.max_input_length:
|
969
|
-
logger.warning(
|
970
|
-
f"警告[构建索引]: 源代码({source.module_name})长度过长 "
|
971
|
-
f"({len(source.source_code)}) > 模型最大输入长度({self.max_input_length}),"
|
972
|
-
f"正在分割为多个块..."
|
973
|
-
)
|
974
|
-
chunks = self.split_text_into_chunks(source_code)
|
975
|
-
symbols_list = []
|
976
|
-
for chunk in chunks:
|
977
|
-
chunk_symbols = self.get_all_file_symbols.with_llm(self.llm).run(source.module_name, chunk)
|
978
|
-
time.sleep(self.anti_quota_limit)
|
979
|
-
symbols_list.append(chunk_symbols.output)
|
980
|
-
symbols = "\n".join(symbols_list)
|
981
|
-
else:
|
982
|
-
single_symbols = self.get_all_file_symbols.with_llm(self.llm).run(source.module_name, source_code)
|
983
|
-
symbols = single_symbols.output
|
984
|
-
time.sleep(self.anti_quota_limit)
|
985
|
-
|
986
|
-
logger.info(f"解析并更新索引:文件 {file_path}(MD5: {md5}),耗时 {time.monotonic() - start_time:.2f} 秒")
|
987
|
-
except Exception as e:
|
988
|
-
logger.warning(f"源文件 {file_path} 处理失败: {e}")
|
989
|
-
return None
|
990
|
-
|
991
|
-
return {
|
992
|
-
"module_name": source.module_name,
|
993
|
-
"symbols": symbols,
|
994
|
-
"last_modified": os.path.getmtime(file_path),
|
995
|
-
"md5": md5,
|
996
|
-
}
|
997
|
-
|
998
|
-
@prompt()
|
999
|
-
def _get_target_files_by_query(self, indices: str, query: str) -> str:
|
1000
|
-
"""
|
1001
|
-
下面是已知文件以及对应的符号信息:
|
1002
|
-
|
1003
|
-
{{ indices }}
|
1004
|
-
|
1005
|
-
用户的问题是:
|
1006
|
-
|
1007
|
-
{{ query }}
|
1008
|
-
|
1009
|
-
现在,请根据用户的问题以及前面的文件和符号信息,寻找相关文件路径。返回结果按如下格式:
|
1010
|
-
|
1011
|
-
```json
|
1012
|
-
{
|
1013
|
-
"file_list": [
|
1014
|
-
{
|
1015
|
-
"file_path": "path/to/file.py",
|
1016
|
-
"reason": "The reason why the file is the target file"
|
1017
|
-
},
|
1018
|
-
{
|
1019
|
-
"file_path": "path/to/file.py",
|
1020
|
-
"reason": "The reason why the file is the target file"
|
1021
|
-
}
|
1022
|
-
]
|
1023
|
-
}
|
1024
|
-
```
|
1025
|
-
|
1026
|
-
如果没有找到,返回如下 json 即可:
|
1027
|
-
|
1028
|
-
```json
|
1029
|
-
{"file_list": []}
|
1030
|
-
```
|
1031
|
-
|
1032
|
-
请严格遵循以下步骤:
|
1033
|
-
|
1034
|
-
1. 识别特殊标记:
|
1035
|
-
- 查找query中的 `@` 符号,它后面的内容是用户关注的文件路径。
|
1036
|
-
- 查找query中的 `@@` 符号,它后面的内容是用户关注的符号(如函数名、类名、变量名)。
|
1037
|
-
|
1038
|
-
2. 匹配文件路径:
|
1039
|
-
- 对于 `@` 标记,在indices中查找包含该路径的所有文件。
|
1040
|
-
- 路径匹配应该是部分匹配,因为用户可能只提供了路径的一部分。
|
1041
|
-
|
1042
|
-
3. 匹配符号:
|
1043
|
-
- 对于 `@@` 标记,在indices中所有文件的符号信息中查找该符号。
|
1044
|
-
- 检查函数、类、变量等所有符号类型。
|
1045
|
-
|
1046
|
-
4. 分析依赖关系:
|
1047
|
-
- 利用 "导入语句" 信息确定文件间的依赖关系。
|
1048
|
-
- 如果找到了相关文件,也包括与之直接相关的依赖文件。
|
1049
|
-
|
1050
|
-
5. 考虑文件用途:
|
1051
|
-
- 使用每个文件的 "用途" 信息来判断其与查询的相关性。
|
1052
|
-
|
1053
|
-
6. 请严格按格式要求返回结果,无需额外的说明
|
1054
|
-
|
1055
|
-
请确保结果的准确性和完整性,包括所有可能相关的文件。
|
1056
|
-
"""
|
1057
|
-
|
1058
|
-
def read_index(self) -> List[IndexItem]:
|
1059
|
-
""" 读取并解析索引文件,将其转换为 IndexItem 对象列表 """
|
1060
|
-
if not os.path.exists(self.index_file):
|
1061
|
-
return []
|
1062
|
-
|
1063
|
-
with open(self.index_file, "r") as file:
|
1064
|
-
index_data = json.load(file)
|
1065
|
-
|
1066
|
-
index_items = []
|
1067
|
-
for module_name, data in index_data.items():
|
1068
|
-
index_item = IndexItem(
|
1069
|
-
module_name=module_name,
|
1070
|
-
symbols=data["symbols"],
|
1071
|
-
last_modified=data["last_modified"],
|
1072
|
-
md5=data["md5"]
|
1073
|
-
)
|
1074
|
-
index_items.append(index_item)
|
1075
|
-
|
1076
|
-
return index_items
|
1077
|
-
|
1078
|
-
def _get_meta_str(self, includes: Optional[List[SymbolType]] = None):
|
1079
|
-
index_items = self.read_index()
|
1080
|
-
current_chunk = []
|
1081
|
-
for item in index_items:
|
1082
|
-
symbols_str = item.symbols
|
1083
|
-
if includes:
|
1084
|
-
symbol_info = extract_symbols(symbols_str)
|
1085
|
-
symbols_str = symbols_info_to_str(symbol_info, includes)
|
1086
|
-
|
1087
|
-
item_str = f"##{item.module_name}\n{symbols_str}\n\n"
|
1088
|
-
if len(current_chunk) > self.args.filter_batch_size:
|
1089
|
-
yield "".join(current_chunk)
|
1090
|
-
current_chunk = [item_str]
|
1091
|
-
else:
|
1092
|
-
current_chunk.append(item_str)
|
1093
|
-
if current_chunk:
|
1094
|
-
yield "".join(current_chunk)
|
1095
|
-
|
1096
|
-
def get_target_files_by_query(self, query: str):
|
1097
|
-
""" 根据查询条件查找相关文件,考虑不同过滤级别 """
|
1098
|
-
all_results = []
|
1099
|
-
completed = 0
|
1100
|
-
total = 0
|
1101
|
-
|
1102
|
-
includes = None
|
1103
|
-
if self.args.index_filter_level == 0:
|
1104
|
-
includes = [SymbolType.USAGE]
|
1105
|
-
if self.args.index_filter_level >= 1:
|
1106
|
-
includes = None
|
1107
|
-
|
1108
|
-
for chunk in self._get_meta_str(includes=includes):
|
1109
|
-
result = self._get_target_files_by_query.with_llm(self.llm).with_return_type(FileList).run(chunk, query)
|
1110
|
-
if result is not None:
|
1111
|
-
all_results.extend(result.file_list)
|
1112
|
-
completed += 1
|
1113
|
-
else:
|
1114
|
-
logger.warning(f"无法找到分块的目标文件。原因可能是模型响应未返回 JSON 格式数据,或返回的 JSON 为空。")
|
1115
|
-
total += 1
|
1116
|
-
time.sleep(self.anti_quota_limit)
|
1117
|
-
|
1118
|
-
logger.info(f"已完成 {completed}/{total} 个分块(基于查询条件)")
|
1119
|
-
all_results = list({file.file_path: file for file in all_results}.values())
|
1120
|
-
if self.args.index_filter_file_num > 0:
|
1121
|
-
limited_results = all_results[: self.args.index_filter_file_num]
|
1122
|
-
return FileList(file_list=limited_results)
|
1123
|
-
return FileList(file_list=all_results)
|
1124
|
-
|
1125
|
-
@prompt()
|
1126
|
-
def _get_related_files(self, indices: str, file_paths: str) -> str:
|
1127
|
-
"""
|
1128
|
-
下面是所有文件以及对应的符号信息:
|
1129
|
-
|
1130
|
-
{{ indices }}
|
1131
|
-
|
1132
|
-
请参考上面的信息,找到被下列文件使用或者引用到的文件列表:
|
1133
|
-
|
1134
|
-
{{ file_paths }}
|
1135
|
-
|
1136
|
-
请按如下格式进行输出:
|
1137
|
-
|
1138
|
-
```json
|
1139
|
-
{
|
1140
|
-
"file_list": [
|
1141
|
-
{
|
1142
|
-
"file_path": "path/to/file.py",
|
1143
|
-
"reason": "The reason why the file is the target file"
|
1144
|
-
},
|
1145
|
-
{
|
1146
|
-
"file_path": "path/to/file.py",
|
1147
|
-
"reason": "The reason why the file is the target file"
|
1148
|
-
}
|
1149
|
-
]
|
1150
|
-
}
|
1151
|
-
```
|
1152
|
-
|
1153
|
-
如果没有相关的文件,输出如下 json 即可:
|
1154
|
-
|
1155
|
-
```json
|
1156
|
-
{"file_list": []}
|
1157
|
-
```
|
1158
|
-
|
1159
|
-
注意,
|
1160
|
-
1. 找到的文件名必须出现在上面的文件列表中
|
1161
|
-
2. 原因控制在20字以内, 且使用中文
|
1162
|
-
3. 请严格按格式要求返回结果,无需额外的说明
|
1163
|
-
"""
|
1164
|
-
|
1165
|
-
def get_related_files(self, file_paths: List[str]):
|
1166
|
-
""" 根据文件路径查询相关文件 """
|
1167
|
-
all_results = []
|
1168
|
-
|
1169
|
-
completed = 0
|
1170
|
-
total = 0
|
1171
|
-
|
1172
|
-
for chunk in self._get_meta_str():
|
1173
|
-
result = self._get_related_files.with_llm(self.llm).with_return_type(
|
1174
|
-
FileList).run(chunk, "\n".join(file_paths))
|
1175
|
-
if result is not None:
|
1176
|
-
all_results.extend(result.file_list)
|
1177
|
-
completed += 1
|
1178
|
-
else:
|
1179
|
-
logger.warning(f"无法找到与分块相关的文件。原因可能是模型限制或查询条件与文件不匹配。")
|
1180
|
-
total += 1
|
1181
|
-
time.sleep(self.anti_quota_limit)
|
1182
|
-
logger.info(f"已完成 {completed}/{total} 个分块(基于相关文件)")
|
1183
|
-
all_results = list({file.file_path: file for file in all_results}.values())
|
1184
|
-
return FileList(file_list=all_results)
|
1185
|
-
|
1186
|
-
@prompt()
|
1187
|
-
def verify_file_relevance(self, file_content: str, query: str) -> str:
|
1188
|
-
"""
|
1189
|
-
请验证下面的文件内容是否与用户问题相关:
|
1190
|
-
|
1191
|
-
文件内容:
|
1192
|
-
{{ file_content }}
|
1193
|
-
|
1194
|
-
用户问题:
|
1195
|
-
{{ query }}
|
1196
|
-
|
1197
|
-
相关是指,需要依赖这个文件提供上下文,或者需要修改这个文件才能解决用户的问题。
|
1198
|
-
请给出相应的可能性分数:0-10,并结合用户问题,理由控制在50字以内,并且使用中文。
|
1199
|
-
请严格按格式要求返回结果。
|
1200
|
-
格式如下:
|
1201
|
-
|
1202
|
-
```json
|
1203
|
-
{
|
1204
|
-
"relevant_score": 0-10,
|
1205
|
-
"reason": "这是相关的原因..."
|
1206
|
-
}
|
1207
|
-
```
|
1208
|
-
"""
|
1209
|
-
|
1210
|
-
|
1211
775
|
def index_command(llm):
|
1212
|
-
|
1213
|
-
# 默认 chat 配置
|
1214
|
-
yaml_config = {
|
1215
|
-
"include_file": ["./base/base.yml"],
|
1216
|
-
"include_project_structure": conf.get("include_project_structure", "true") in ["true", "True"],
|
1217
|
-
"human_as_model": conf.get("human_as_model", "false") == "true",
|
1218
|
-
"skip_build_index": conf.get("skip_build_index", "true") == "true",
|
1219
|
-
"skip_confirm": conf.get("skip_confirm", "true") == "true",
|
1220
|
-
"silence": conf.get("silence", "true") == "true",
|
1221
|
-
"query": ""
|
1222
|
-
}
|
1223
|
-
current_files = memory["current_files"]["files"] # get_llm_friendly_package_docs
|
1224
|
-
yaml_config["urls"] = current_files
|
1225
|
-
yaml_config["query"] = ""
|
1226
|
-
|
1227
|
-
# 如果 conf 中有设置, 则以 conf 配置为主
|
1228
|
-
for key, value in conf.items():
|
1229
|
-
converted_value = convert_config_value(key, value)
|
1230
|
-
if converted_value is not None:
|
1231
|
-
yaml_config[key] = converted_value
|
1232
|
-
|
1233
|
-
yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
|
1234
|
-
execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
|
1235
|
-
|
1236
|
-
with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
|
1237
|
-
f.write(yaml_content)
|
1238
|
-
|
1239
|
-
convert_yaml_to_config(execute_file) # 更新到args
|
776
|
+
update_config_to_args(query="", delete_execute_file=True)
|
1240
777
|
|
1241
778
|
source_dir = os.path.abspath(args.source_dir)
|
1242
779
|
logger.info(f"开始对目录 {source_dir} 中的源代码进行索引")
|
@@ -1246,7 +783,7 @@ def index_command(llm):
|
|
1246
783
|
pp = SuffixProject(llm=llm, args=args)
|
1247
784
|
pp.run()
|
1248
785
|
_sources = pp.sources
|
1249
|
-
index_manager = IndexManager(source_codes=_sources, llm=llm)
|
786
|
+
index_manager = IndexManager(args=args, source_codes=_sources, llm=llm)
|
1250
787
|
index_manager.build_index()
|
1251
788
|
|
1252
789
|
|
@@ -1332,34 +869,7 @@ def wrap_text_in_table(data, max_width=60):
|
|
1332
869
|
|
1333
870
|
|
1334
871
|
def index_query_command(query: str, llm: AutoLLM):
|
1335
|
-
|
1336
|
-
# 默认 chat 配置
|
1337
|
-
yaml_config = {
|
1338
|
-
"include_file": ["./base/base.yml"],
|
1339
|
-
"include_project_structure": conf.get("include_project_structure", "true") in ["true", "True"],
|
1340
|
-
"human_as_model": conf.get("human_as_model", "false") == "true",
|
1341
|
-
"skip_build_index": conf.get("skip_build_index", "true") == "true",
|
1342
|
-
"skip_confirm": conf.get("skip_confirm", "true") == "true",
|
1343
|
-
"silence": conf.get("silence", "true") == "true",
|
1344
|
-
"query": query
|
1345
|
-
}
|
1346
|
-
current_files = memory["current_files"]["files"] # get_llm_friendly_package_docs
|
1347
|
-
yaml_config["urls"] = current_files
|
1348
|
-
yaml_config["query"] = query
|
1349
|
-
|
1350
|
-
# 如果 conf 中有设置, 则以 conf 配置为主
|
1351
|
-
for key, value in conf.items():
|
1352
|
-
converted_value = convert_config_value(key, value)
|
1353
|
-
if converted_value is not None:
|
1354
|
-
yaml_config[key] = converted_value
|
1355
|
-
|
1356
|
-
yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
|
1357
|
-
execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
|
1358
|
-
|
1359
|
-
with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
|
1360
|
-
f.write(yaml_content)
|
1361
|
-
|
1362
|
-
convert_yaml_to_config(execute_file) # 更新到args
|
872
|
+
update_config_to_args(query=query, delete_execute_file=True)
|
1363
873
|
|
1364
874
|
# args.query = query
|
1365
875
|
if args.project_type == "py":
|
@@ -1370,7 +880,7 @@ def index_query_command(query: str, llm: AutoLLM):
|
|
1370
880
|
_sources = pp.sources
|
1371
881
|
|
1372
882
|
final_files = []
|
1373
|
-
index_manager = IndexManager(source_codes=_sources, llm=llm)
|
883
|
+
index_manager = IndexManager(args=args, source_codes=_sources, llm=llm)
|
1374
884
|
target_files = index_manager.get_target_files_by_query(query)
|
1375
885
|
|
1376
886
|
if target_files:
|
@@ -1397,159 +907,6 @@ def index_query_command(query: str, llm: AutoLLM):
|
|
1397
907
|
return
|
1398
908
|
|
1399
909
|
|
1400
|
-
def build_index_and_filter_files(llm, sources: List[SourceCode]) -> str:
|
1401
|
-
def get_file_path(_file_path):
|
1402
|
-
if _file_path.startswith("##"):
|
1403
|
-
return _file_path.strip()[2:]
|
1404
|
-
return _file_path
|
1405
|
-
|
1406
|
-
final_files: Dict[str, TargetFile] = {}
|
1407
|
-
logger.info("第一阶段:处理 REST/RAG/Search 资源...")
|
1408
|
-
for source in sources:
|
1409
|
-
if source.tag in ["REST", "RAG", "SEARCH"]:
|
1410
|
-
final_files[get_file_path(source.module_name)] = TargetFile(
|
1411
|
-
file_path=source.module_name, reason="Rest/Rag/Search"
|
1412
|
-
)
|
1413
|
-
|
1414
|
-
if not args.skip_build_index and llm:
|
1415
|
-
logger.info("第二阶段:为所有文件构建索引...")
|
1416
|
-
index_manager = IndexManager(llm=llm, source_codes=sources)
|
1417
|
-
index_data = index_manager.build_index()
|
1418
|
-
indexed_files_count = len(index_data) if index_data else 0
|
1419
|
-
logger.info(f"总索引文件数: {indexed_files_count}")
|
1420
|
-
|
1421
|
-
if not args.skip_filter_index and args.index_filter_level >= 1:
|
1422
|
-
logger.info("第三阶段:执行 Level 1 过滤(基于查询) ...")
|
1423
|
-
target_files = index_manager.get_target_files_by_query(args.query)
|
1424
|
-
if target_files:
|
1425
|
-
for file in target_files.file_list:
|
1426
|
-
file_path = file.file_path.strip()
|
1427
|
-
final_files[get_file_path(file_path)] = file
|
1428
|
-
|
1429
|
-
if target_files is not None and args.index_filter_level >= 2:
|
1430
|
-
logger.info("第四阶段:执行 Level 2 过滤(基于相关文件)...")
|
1431
|
-
related_files = index_manager.get_related_files(
|
1432
|
-
[file.file_path for file in target_files.file_list]
|
1433
|
-
)
|
1434
|
-
if related_files is not None:
|
1435
|
-
for file in related_files.file_list:
|
1436
|
-
file_path = file.file_path.strip()
|
1437
|
-
final_files[get_file_path(file_path)] = file
|
1438
|
-
|
1439
|
-
# 如果 Level 1 filtering 和 Level 2 filtering 都未获取路径,则使用全部文件
|
1440
|
-
if not final_files:
|
1441
|
-
logger.warning("Level 1, Level 2 过滤未找到相关文件, 将使用所有文件 ...")
|
1442
|
-
for source in sources:
|
1443
|
-
final_files[get_file_path(source.module_name)] = TargetFile(
|
1444
|
-
file_path=source.module_name,
|
1445
|
-
reason="No related files found, use all files",
|
1446
|
-
)
|
1447
|
-
|
1448
|
-
logger.info("第五阶段:执行相关性验证 ...")
|
1449
|
-
verified_files = {}
|
1450
|
-
temp_files = list(final_files.values())
|
1451
|
-
verification_results = []
|
1452
|
-
|
1453
|
-
def _print_verification_results(results):
|
1454
|
-
table = Table(title="文件相关性验证结果", expand=True, show_lines=True)
|
1455
|
-
table.add_column("文件路径", style="cyan", no_wrap=True)
|
1456
|
-
table.add_column("得分", justify="right", style="green")
|
1457
|
-
table.add_column("状态", style="yellow")
|
1458
|
-
table.add_column("原因/错误")
|
1459
|
-
if result:
|
1460
|
-
for _file_path, _score, _status, _reason in results:
|
1461
|
-
table.add_row(_file_path,
|
1462
|
-
str(_score) if _score is not None else "N/A", _status, _reason)
|
1463
|
-
console.print(table)
|
1464
|
-
|
1465
|
-
def _verify_single_file(single_file: TargetFile):
|
1466
|
-
for _source in sources:
|
1467
|
-
if _source.module_name == single_file.file_path:
|
1468
|
-
file_content = _source.source_code
|
1469
|
-
try:
|
1470
|
-
_result = index_manager.verify_file_relevance.with_llm(llm).with_return_type(
|
1471
|
-
VerifyFileRelevance).run(
|
1472
|
-
file_content=file_content,
|
1473
|
-
query=args.query
|
1474
|
-
)
|
1475
|
-
if _result.relevant_score >= args.verify_file_relevance_score:
|
1476
|
-
verified_files[single_file.file_path] = TargetFile(
|
1477
|
-
file_path=single_file.file_path,
|
1478
|
-
reason=f"Score:{_result.relevant_score}, {_result.reason}"
|
1479
|
-
)
|
1480
|
-
return single_file.file_path, _result.relevant_score, "PASS", _result.reason
|
1481
|
-
else:
|
1482
|
-
return single_file.file_path, _result.relevant_score, "FAIL", _result.reason
|
1483
|
-
except Exception as e:
|
1484
|
-
error_msg = str(e)
|
1485
|
-
verified_files[single_file.file_path] = TargetFile(
|
1486
|
-
file_path=single_file.file_path,
|
1487
|
-
reason=f"Verification failed: {error_msg}"
|
1488
|
-
)
|
1489
|
-
return single_file.file_path, None, "ERROR", error_msg
|
1490
|
-
return
|
1491
|
-
|
1492
|
-
for pending_verify_file in temp_files:
|
1493
|
-
result = _verify_single_file(pending_verify_file)
|
1494
|
-
if result:
|
1495
|
-
verification_results.append(result)
|
1496
|
-
time.sleep(args.anti_quota_limit)
|
1497
|
-
|
1498
|
-
_print_verification_results(verification_results)
|
1499
|
-
# Keep all files, not just verified ones
|
1500
|
-
final_files = verified_files
|
1501
|
-
|
1502
|
-
logger.info("第六阶段:筛选文件并应用限制条件 ...")
|
1503
|
-
if args.index_filter_file_num > 0:
|
1504
|
-
logger.info(f"从 {len(final_files)} 个文件中获取前 {args.index_filter_file_num} 个文件(Limit)")
|
1505
|
-
final_filenames = [file.file_path for file in final_files.values()]
|
1506
|
-
if not final_filenames:
|
1507
|
-
logger.warning("未找到目标文件,你可能需要重新编写查询并重试.")
|
1508
|
-
if args.index_filter_file_num > 0:
|
1509
|
-
final_filenames = final_filenames[: args.index_filter_file_num]
|
1510
|
-
|
1511
|
-
def _shorten_path(path: str, keep_levels: int = 3) -> str:
|
1512
|
-
"""
|
1513
|
-
优化长路径显示,保留最后指定层级
|
1514
|
-
示例:/a/b/c/d/e/f.py -> .../c/d/e/f.py
|
1515
|
-
"""
|
1516
|
-
parts = path.split(os.sep)
|
1517
|
-
if len(parts) > keep_levels:
|
1518
|
-
return ".../" + os.sep.join(parts[-keep_levels:])
|
1519
|
-
return path
|
1520
|
-
|
1521
|
-
def _print_selected(data):
|
1522
|
-
table = Table(title="代码上下文文件", expand=True, show_lines=True)
|
1523
|
-
table.add_column("文件路径", style="cyan")
|
1524
|
-
table.add_column("原因", style="cyan")
|
1525
|
-
for _file, _reason in data:
|
1526
|
-
# 路径截取优化:保留最后 3 级路径
|
1527
|
-
_processed_path = _shorten_path(_file, keep_levels=3)
|
1528
|
-
table.add_row(_processed_path, _reason)
|
1529
|
-
console.print(table)
|
1530
|
-
|
1531
|
-
logger.info("第七阶段:准备最终输出 ...")
|
1532
|
-
_print_selected(
|
1533
|
-
[
|
1534
|
-
(file.file_path, file.reason)
|
1535
|
-
for file in final_files.values()
|
1536
|
-
if file.file_path in final_filenames
|
1537
|
-
]
|
1538
|
-
)
|
1539
|
-
result_source_code = ""
|
1540
|
-
depulicated_sources = set()
|
1541
|
-
|
1542
|
-
for file in sources:
|
1543
|
-
if file.module_name in final_filenames:
|
1544
|
-
if file.module_name in depulicated_sources:
|
1545
|
-
continue
|
1546
|
-
depulicated_sources.add(file.module_name)
|
1547
|
-
result_source_code += f"##File: {file.module_name}\n"
|
1548
|
-
result_source_code += f"{file.source_code}\n\n"
|
1549
|
-
|
1550
|
-
return result_source_code
|
1551
|
-
|
1552
|
-
|
1553
910
|
def convert_yaml_config_to_str(yaml_config):
|
1554
911
|
yaml_content = yaml.safe_dump(
|
1555
912
|
yaml_config,
|
@@ -1598,6 +955,41 @@ def convert_config_value(key, value):
|
|
1598
955
|
return None
|
1599
956
|
|
1600
957
|
|
958
|
+
def update_config_to_args(query, delete_execute_file: bool = False):
|
959
|
+
conf = memory.get("conf", {})
|
960
|
+
|
961
|
+
# 默认 chat 配置
|
962
|
+
yaml_config = {
|
963
|
+
"include_file": ["./base/base.yml"],
|
964
|
+
"skip_build_index": conf.get("skip_build_index", "true") == "true",
|
965
|
+
"skip_confirm": conf.get("skip_confirm", "true") == "true",
|
966
|
+
"chat_model": conf.get("chat_model", ""),
|
967
|
+
"code_model": conf.get("code_model", ""),
|
968
|
+
"auto_merge": conf.get("auto_merge", "editblock")
|
969
|
+
}
|
970
|
+
current_files = memory["current_files"]["files"]
|
971
|
+
yaml_config["urls"] = current_files
|
972
|
+
yaml_config["query"] = query
|
973
|
+
|
974
|
+
# 如果 conf 中有设置, 则以 conf 配置为主
|
975
|
+
for key, value in conf.items():
|
976
|
+
converted_value = convert_config_value(key, value)
|
977
|
+
if converted_value is not None:
|
978
|
+
yaml_config[key] = converted_value
|
979
|
+
|
980
|
+
yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
|
981
|
+
execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
|
982
|
+
|
983
|
+
with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
|
984
|
+
f.write(yaml_content)
|
985
|
+
|
986
|
+
convert_yaml_to_config(execute_file) # 更新到args
|
987
|
+
|
988
|
+
if delete_execute_file:
|
989
|
+
if os.path.exists(execute_file):
|
990
|
+
os.remove(execute_file)
|
991
|
+
|
992
|
+
|
1601
993
|
def print_chat_history(history, max_entries=5):
|
1602
994
|
recent_history = history[-max_entries:]
|
1603
995
|
table = Table(show_header=False, padding=(0, 1), expand=True, show_lines=True)
|
@@ -1638,6 +1030,8 @@ def code_review(query: str) -> str:
|
|
1638
1030
|
|
1639
1031
|
|
1640
1032
|
def chat(query: str, llm: AutoLLM):
|
1033
|
+
update_config_to_args(query)
|
1034
|
+
|
1641
1035
|
is_history = query.strip().startswith("/history")
|
1642
1036
|
is_new = "/new" in query
|
1643
1037
|
if is_new:
|
@@ -1652,36 +1046,6 @@ def chat(query: str, llm: AutoLLM):
|
|
1652
1046
|
query = query.replace("/review", "", 1).strip()
|
1653
1047
|
query = code_review.prompt(query)
|
1654
1048
|
|
1655
|
-
conf = memory.get("conf", {})
|
1656
|
-
# 默认 chat 配置
|
1657
|
-
yaml_config = {
|
1658
|
-
"include_file": ["./base/base.yml"],
|
1659
|
-
"include_project_structure": conf.get("include_project_structure", "true") in ["true", "True"],
|
1660
|
-
"human_as_model": conf.get("human_as_model", "false") == "true",
|
1661
|
-
"skip_build_index": conf.get("skip_build_index", "true") == "true",
|
1662
|
-
"skip_confirm": conf.get("skip_confirm", "true") == "true",
|
1663
|
-
"silence": conf.get("silence", "true") == "true",
|
1664
|
-
"query": query
|
1665
|
-
}
|
1666
|
-
current_files = memory["current_files"]["files"] # get_llm_friendly_package_docs
|
1667
|
-
yaml_config["urls"] = current_files
|
1668
|
-
|
1669
|
-
yaml_config["query"] = query
|
1670
|
-
|
1671
|
-
# 如果 conf 中有设置, 则以 conf 配置为主
|
1672
|
-
for key, value in conf.items():
|
1673
|
-
converted_value = convert_config_value(key, value)
|
1674
|
-
if converted_value is not None:
|
1675
|
-
yaml_config[key] = converted_value
|
1676
|
-
|
1677
|
-
yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
|
1678
|
-
execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
|
1679
|
-
|
1680
|
-
with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
|
1681
|
-
f.write(yaml_content)
|
1682
|
-
|
1683
|
-
convert_yaml_to_config(execute_file) # 更新到args
|
1684
|
-
|
1685
1049
|
memory_dir = os.path.join(args.source_dir, ".auto-coder", "memory")
|
1686
1050
|
os.makedirs(memory_dir, exist_ok=True)
|
1687
1051
|
memory_file = os.path.join(memory_dir, "chat_history.json")
|
@@ -1745,7 +1109,7 @@ def chat(query: str, llm: AutoLLM):
|
|
1745
1109
|
pp = SuffixProject(llm=llm, args=args)
|
1746
1110
|
pp.run()
|
1747
1111
|
_sources = pp.sources
|
1748
|
-
s = build_index_and_filter_files(llm=llm, sources=_sources)
|
1112
|
+
s = build_index_and_filter_files(args=args, llm=llm, sources=_sources)
|
1749
1113
|
if s:
|
1750
1114
|
pre_conversations.append(
|
1751
1115
|
{
|
@@ -1760,7 +1124,7 @@ def chat(query: str, llm: AutoLLM):
|
|
1760
1124
|
|
1761
1125
|
loaded_conversations = pre_conversations + chat_history["ask_conversation"]
|
1762
1126
|
|
1763
|
-
v = chat_llm.stream_chat_ai(conversations=loaded_conversations, model=
|
1127
|
+
v = chat_llm.stream_chat_ai(conversations=loaded_conversations, model=args.chat_model)
|
1764
1128
|
|
1765
1129
|
MAX_HISTORY_LINES = 15 # 最大保留历史行数
|
1766
1130
|
lines_buffer = []
|
@@ -1814,30 +1178,6 @@ def chat(query: str, llm: AutoLLM):
|
|
1814
1178
|
return
|
1815
1179
|
|
1816
1180
|
|
1817
|
-
def git_print_commit_info(commit_result: CommitResult):
|
1818
|
-
table = Table(
|
1819
|
-
title="Commit Information (Use /revert to revert this commit)", show_header=True, header_style="bold magenta"
|
1820
|
-
)
|
1821
|
-
table.add_column("Attribute", style="cyan", no_wrap=True)
|
1822
|
-
table.add_column("Value", style="green")
|
1823
|
-
|
1824
|
-
table.add_row("Commit Hash", commit_result.commit_hash)
|
1825
|
-
table.add_row("Commit Message", commit_result.commit_message)
|
1826
|
-
table.add_row("Changed Files", "\n".join(commit_result.changed_files))
|
1827
|
-
|
1828
|
-
console.print(
|
1829
|
-
Panel(table, expand=False, border_style="green", title="Git Commit Summary")
|
1830
|
-
)
|
1831
|
-
|
1832
|
-
if commit_result.diffs:
|
1833
|
-
for file, diff in commit_result.diffs.items():
|
1834
|
-
console.print(f"\n[bold blue]File: {file}[/bold blue]")
|
1835
|
-
syntax = Syntax(diff, "diff", theme="monokai", line_numbers=True)
|
1836
|
-
console.print(
|
1837
|
-
Panel(syntax, expand=False, border_style="yellow", title="File Diff")
|
1838
|
-
)
|
1839
|
-
|
1840
|
-
|
1841
1181
|
def init_project():
|
1842
1182
|
if not args.project_type:
|
1843
1183
|
logger.error(
|
@@ -1910,1040 +1250,6 @@ def load_include_files(config, base_path, max_depth=10, current_depth=0):
|
|
1910
1250
|
return config
|
1911
1251
|
|
1912
1252
|
|
1913
|
-
class CodeAutoGenerateEditBlock:
|
1914
|
-
def __init__(self, llm: AutoLLM, action=None, fence_0: str = "```", fence_1: str = "```"):
|
1915
|
-
self.llm = llm
|
1916
|
-
self.llm.setup_default_model_name(memory["conf"]["current_code_model"])
|
1917
|
-
self.args = args
|
1918
|
-
self.action = action
|
1919
|
-
self.fence_0 = fence_0
|
1920
|
-
self.fence_1 = fence_1
|
1921
|
-
if not self.llm:
|
1922
|
-
raise ValueError("Please provide a valid model instance to use for code generation.")
|
1923
|
-
self.llms = [self.llm]
|
1924
|
-
|
1925
|
-
@prompt()
|
1926
|
-
def single_round_instruction(self, instruction: str, content: str, context: str = ""):
|
1927
|
-
"""
|
1928
|
-
如果你需要生成代码,对于每个需要更改的文件,你需要按 *SEARCH/REPLACE block* 的格式进行生成。
|
1929
|
-
|
1930
|
-
# *SEARCH/REPLACE block* Rules:
|
1931
|
-
|
1932
|
-
Every *SEARCH/REPLACE block* must use this format:
|
1933
|
-
1. The opening fence and code language, eg: {{ fence_0 }}python
|
1934
|
-
2. The file path alone on a line, starting with "##File:" and verbatim. No bold asterisks, no quotes around it,
|
1935
|
-
no escaping of characters, etc.
|
1936
|
-
3. The start of search block: <<<<<<< SEARCH
|
1937
|
-
4. A contiguous chunk of lines to search for in the existing source code
|
1938
|
-
5. The dividing line: =======
|
1939
|
-
6. The lines to replace into the source code
|
1940
|
-
7. The end of the replacement block: >>>>>>> REPLACE
|
1941
|
-
8. The closing fence: {{ fence_1 }}
|
1942
|
-
|
1943
|
-
Every *SEARCH* section must *EXACTLY MATCH* the existing source code, character for character,
|
1944
|
-
including all comments, docstrings, etc.
|
1945
|
-
|
1946
|
-
*SEARCH/REPLACE* blocks will replace *all* matching occurrences.
|
1947
|
-
Include enough lines to make the SEARCH blocks unique.
|
1948
|
-
|
1949
|
-
Include *ALL* the code being searched and replaced!
|
1950
|
-
|
1951
|
-
To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location,
|
1952
|
-
1 to insert it in the new location.
|
1953
|
-
|
1954
|
-
If you want to put code in a new file, use a *SEARCH/REPLACE block* with:
|
1955
|
-
- A new file path, including dir name if needed
|
1956
|
-
- An empty `SEARCH` section
|
1957
|
-
- The new file's contents in the `REPLACE` section
|
1958
|
-
|
1959
|
-
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
|
1960
|
-
|
1961
|
-
下面我们来看一个例子:
|
1962
|
-
|
1963
|
-
当前项目目录结构:
|
1964
|
-
1. 项目根目录: /tmp/projects/mathweb
|
1965
|
-
2. 项目子目录/文件列表(类似tree 命令输出)
|
1966
|
-
flask/
|
1967
|
-
app.py
|
1968
|
-
templates/
|
1969
|
-
index.html
|
1970
|
-
static/
|
1971
|
-
style.css
|
1972
|
-
|
1973
|
-
用户需求: Change get_factorial() to use math.factorial
|
1974
|
-
|
1975
|
-
回答: To make this change we need to modify `/tmp/projects/mathweb/flask/app.py` to:
|
1976
|
-
|
1977
|
-
1. Import the math package.
|
1978
|
-
2. Remove the existing factorial() function.
|
1979
|
-
3. Update get_factorial() to call math.factorial instead.
|
1980
|
-
|
1981
|
-
Here are the *SEARCH/REPLACE* blocks:
|
1982
|
-
|
1983
|
-
```python
|
1984
|
-
##File: /tmp/projects/mathweb/flask/app.py
|
1985
|
-
<<<<<<< SEARCH
|
1986
|
-
from flask import Flask
|
1987
|
-
=======
|
1988
|
-
import math
|
1989
|
-
from flask import Flask
|
1990
|
-
>>>>>>> REPLACE
|
1991
|
-
```
|
1992
|
-
|
1993
|
-
```python
|
1994
|
-
##File: /tmp/projects/mathweb/flask/app.py
|
1995
|
-
<<<<<<< SEARCH
|
1996
|
-
def factorial(n):
|
1997
|
-
"compute factorial"
|
1998
|
-
|
1999
|
-
if n == 0:
|
2000
|
-
return 1
|
2001
|
-
else:
|
2002
|
-
return n * factorial(n-1)
|
2003
|
-
|
2004
|
-
=======
|
2005
|
-
>>>>>>> REPLACE
|
2006
|
-
```
|
2007
|
-
|
2008
|
-
```python
|
2009
|
-
##File: /tmp/projects/mathweb/flask/app.py
|
2010
|
-
<<<<<<< SEARCH
|
2011
|
-
return str(factorial(n))
|
2012
|
-
=======
|
2013
|
-
return str(math.factorial(n))
|
2014
|
-
>>>>>>> REPLACE
|
2015
|
-
```
|
2016
|
-
|
2017
|
-
用户需求: Refactor hello() into its own file.
|
2018
|
-
|
2019
|
-
回答:To make this change we need to modify `main.py` and make a new file `hello.py`:
|
2020
|
-
|
2021
|
-
1. Make a new hello.py file with hello() in it.
|
2022
|
-
2. Remove hello() from main.py and replace it with an import.
|
2023
|
-
|
2024
|
-
Here are the *SEARCH/REPLACE* blocks:
|
2025
|
-
|
2026
|
-
```python
|
2027
|
-
##File: /tmp/projects/mathweb/hello.py
|
2028
|
-
<<<<<<< SEARCH
|
2029
|
-
=======
|
2030
|
-
def hello():
|
2031
|
-
"print a greeting"
|
2032
|
-
|
2033
|
-
print("hello")
|
2034
|
-
>>>>>>> REPLACE
|
2035
|
-
```
|
2036
|
-
|
2037
|
-
```python
|
2038
|
-
##File: /tmp/projects/mathweb/main.py
|
2039
|
-
<<<<<<< SEARCH
|
2040
|
-
def hello():
|
2041
|
-
"print a greeting"
|
2042
|
-
|
2043
|
-
print("hello")
|
2044
|
-
=======
|
2045
|
-
from hello import hello
|
2046
|
-
>>>>>>> REPLACE
|
2047
|
-
```
|
2048
|
-
|
2049
|
-
现在让我们开始一个新的任务:
|
2050
|
-
|
2051
|
-
{%- if structure %}
|
2052
|
-
{{ structure }}
|
2053
|
-
{%- endif %}
|
2054
|
-
|
2055
|
-
{%- if content %}
|
2056
|
-
下面是一些文件路径以及每个文件对应的源码:
|
2057
|
-
<files>
|
2058
|
-
{{ content }}
|
2059
|
-
</files>
|
2060
|
-
{%- endif %}
|
2061
|
-
|
2062
|
-
{%- if context %}
|
2063
|
-
<extra_context>
|
2064
|
-
{{ context }}
|
2065
|
-
</extra_context>
|
2066
|
-
{%- endif %}
|
2067
|
-
|
2068
|
-
下面是用户的需求:
|
2069
|
-
|
2070
|
-
{{ instruction }}
|
2071
|
-
|
2072
|
-
"""
|
2073
|
-
|
2074
|
-
@prompt()
|
2075
|
-
def auto_implement_function(self, instruction: str, content: str) -> str:
|
2076
|
-
"""
|
2077
|
-
下面是一些文件路径以及每个文件对应的源码:
|
2078
|
-
|
2079
|
-
{{ content }}
|
2080
|
-
|
2081
|
-
请参考上面的内容,重新实现所有文件下方法体标记了如下内容的方法:
|
2082
|
-
|
2083
|
-
```python
|
2084
|
-
raise NotImplementedError("This function should be implemented by the model.")
|
2085
|
-
```
|
2086
|
-
|
2087
|
-
{{ instruction }}
|
2088
|
-
|
2089
|
-
"""
|
2090
|
-
|
2091
|
-
def single_round_run(self, query: str, source_content: str) -> CodeGenerateResult:
|
2092
|
-
init_prompt = ''
|
2093
|
-
if self.args.template == "common":
|
2094
|
-
init_prompt = self.single_round_instruction.prompt(
|
2095
|
-
instruction=query, content=source_content, context=self.args.context
|
2096
|
-
)
|
2097
|
-
elif self.args.template == "auto_implement":
|
2098
|
-
init_prompt = self.auto_implement_function.prompt(
|
2099
|
-
instruction=query, content=source_content
|
2100
|
-
)
|
2101
|
-
|
2102
|
-
with open(self.args.target_file, "w") as file:
|
2103
|
-
file.write(init_prompt)
|
2104
|
-
|
2105
|
-
conversations = [{"role": "user", "content": init_prompt}]
|
2106
|
-
|
2107
|
-
conversations_list = []
|
2108
|
-
results = []
|
2109
|
-
|
2110
|
-
for llm in self.llms:
|
2111
|
-
v = llm.chat_ai(conversations=conversations)
|
2112
|
-
results.append(v.output)
|
2113
|
-
for result in results:
|
2114
|
-
conversations_list.append(conversations + [{"role": "assistant", "content": result}])
|
2115
|
-
|
2116
|
-
return CodeGenerateResult(contents=results, conversations=conversations_list)
|
2117
|
-
|
2118
|
-
@prompt()
|
2119
|
-
def multi_round_instruction(self, instruction: str, content: str, context: str = "") -> str:
|
2120
|
-
"""
|
2121
|
-
如果你需要生成代码,对于每个需要更改的文件,你需要按 *SEARCH/REPLACE block* 的格式进行生成。
|
2122
|
-
|
2123
|
-
# *SEARCH/REPLACE block* Rules:
|
2124
|
-
|
2125
|
-
Every *SEARCH/REPLACE block* must use this format:
|
2126
|
-
1. The opening fence and code language, eg: {{ fence_0 }}python
|
2127
|
-
2. The file path alone on a line, starting with "##File:" and verbatim. No bold asterisks, no quotes around it,
|
2128
|
-
no escaping of characters, etc.
|
2129
|
-
3. The start of search block: <<<<<<< SEARCH
|
2130
|
-
4. A contiguous chunk of lines to search for in the existing source code
|
2131
|
-
5. The dividing line: =======
|
2132
|
-
6. The lines to replace into the source code
|
2133
|
-
7. The end of the replacement block: >>>>>>> REPLACE
|
2134
|
-
8. The closing fence: {{ fence_1 }}
|
2135
|
-
|
2136
|
-
Every *SEARCH* section must *EXACTLY MATCH* the existing source code, character for character,
|
2137
|
-
including all comments, docstrings, etc.
|
2138
|
-
|
2139
|
-
*SEARCH/REPLACE* blocks will replace *all* matching occurrences.
|
2140
|
-
Include enough lines to make the SEARCH blocks unique.
|
2141
|
-
|
2142
|
-
Include *ALL* the code being searched and replaced!
|
2143
|
-
|
2144
|
-
To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location,
|
2145
|
-
1 to insert it in the new location.
|
2146
|
-
|
2147
|
-
If you want to put code in a new file, use a *SEARCH/REPLACE block* with:
|
2148
|
-
- A new file path, including dir name if needed
|
2149
|
-
- An empty `SEARCH` section
|
2150
|
-
- The new file's contents in the `REPLACE` section
|
2151
|
-
|
2152
|
-
ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
|
2153
|
-
|
2154
|
-
下面我们来看一个例子:
|
2155
|
-
|
2156
|
-
当前项目目录结构:
|
2157
|
-
1. 项目根目录: /tmp/projects/mathweb
|
2158
|
-
2. 项目子目录/文件列表(类似tree 命令输出)
|
2159
|
-
flask/
|
2160
|
-
app.py
|
2161
|
-
templates/
|
2162
|
-
index.html
|
2163
|
-
static/
|
2164
|
-
style.css
|
2165
|
-
|
2166
|
-
用户需求: Change get_factorial() to use math.factorial
|
2167
|
-
|
2168
|
-
回答: To make this change we need to modify `/tmp/projects/mathweb/flask/app.py` to:
|
2169
|
-
|
2170
|
-
1. Import the math package.
|
2171
|
-
2. Remove the existing factorial() function.
|
2172
|
-
3. Update get_factorial() to call math.factorial instead.
|
2173
|
-
|
2174
|
-
Here are the *SEARCH/REPLACE* blocks:
|
2175
|
-
|
2176
|
-
{{ fence_0 }}python
|
2177
|
-
##File: /tmp/projects/mathweb/flask/app.py
|
2178
|
-
<<<<<<< SEARCH
|
2179
|
-
from flask import Flask
|
2180
|
-
=======
|
2181
|
-
import math
|
2182
|
-
from flask import Flask
|
2183
|
-
>>>>>>> REPLACE
|
2184
|
-
{{ fence_1 }}
|
2185
|
-
|
2186
|
-
{{ fence_0 }}python
|
2187
|
-
##File: /tmp/projects/mathweb/flask/app.py
|
2188
|
-
<<<<<<< SEARCH
|
2189
|
-
def factorial(n):
|
2190
|
-
"compute factorial"
|
2191
|
-
|
2192
|
-
if n == 0:
|
2193
|
-
return 1
|
2194
|
-
else:
|
2195
|
-
return n * factorial(n-1)
|
2196
|
-
|
2197
|
-
=======
|
2198
|
-
>>>>>>> REPLACE
|
2199
|
-
{{ fence_1 }}
|
2200
|
-
|
2201
|
-
{{ fence_0 }}python
|
2202
|
-
##File: /tmp/projects/mathweb/flask/app.py
|
2203
|
-
<<<<<<< SEARCH
|
2204
|
-
return str(factorial(n))
|
2205
|
-
=======
|
2206
|
-
return str(math.factorial(n))
|
2207
|
-
>>>>>>> REPLACE
|
2208
|
-
{{ fence_1 }}
|
2209
|
-
|
2210
|
-
用户需求: Refactor hello() into its own file.
|
2211
|
-
|
2212
|
-
回答:To make this change we need to modify `main.py` and make a new file `hello.py`:
|
2213
|
-
|
2214
|
-
1. Make a new hello.py file with hello() in it.
|
2215
|
-
2. Remove hello() from main.py and replace it with an import.
|
2216
|
-
|
2217
|
-
Here are the *SEARCH/REPLACE* blocks:
|
2218
|
-
|
2219
|
-
|
2220
|
-
{{ fence_0 }}python
|
2221
|
-
##File: /tmp/projects/mathweb/hello.py
|
2222
|
-
<<<<<<< SEARCH
|
2223
|
-
=======
|
2224
|
-
def hello():
|
2225
|
-
"print a greeting"
|
2226
|
-
|
2227
|
-
print("hello")
|
2228
|
-
>>>>>>> REPLACE
|
2229
|
-
{{ fence_1 }}
|
2230
|
-
|
2231
|
-
{{ fence_0 }}python
|
2232
|
-
##File: /tmp/projects/mathweb/main.py
|
2233
|
-
<<<<<<< SEARCH
|
2234
|
-
def hello():
|
2235
|
-
"print a greeting"
|
2236
|
-
|
2237
|
-
print("hello")
|
2238
|
-
=======
|
2239
|
-
from hello import hello
|
2240
|
-
>>>>>>> REPLACE
|
2241
|
-
{{ fence_1 }}
|
2242
|
-
|
2243
|
-
现在让我们开始一个新的任务:
|
2244
|
-
|
2245
|
-
{%- if structure %}
|
2246
|
-
{{ structure }}
|
2247
|
-
{%- endif %}
|
2248
|
-
|
2249
|
-
{%- if content %}
|
2250
|
-
下面是一些文件路径以及每个文件对应的源码:
|
2251
|
-
<files>
|
2252
|
-
{{ content }}
|
2253
|
-
</files>
|
2254
|
-
{%- endif %}
|
2255
|
-
|
2256
|
-
{%- if context %}
|
2257
|
-
<extra_context>
|
2258
|
-
{{ context }}
|
2259
|
-
</extra_context>
|
2260
|
-
{%- endif %}
|
2261
|
-
|
2262
|
-
下面是用户的需求:
|
2263
|
-
|
2264
|
-
{{ instruction }}
|
2265
|
-
|
2266
|
-
每次生成一个文件的*SEARCH/REPLACE* blocks,然后询问我是否继续,当我回复继续,
|
2267
|
-
继续生成下一个文件的*SEARCH/REPLACE* blocks。当没有后续任务时,请回复 "__完成__" 或者 "__EOF__"。
|
2268
|
-
"""
|
2269
|
-
|
2270
|
-
def multi_round_run(self, query: str, source_content: str, max_steps: int = 3) -> CodeGenerateResult:
|
2271
|
-
init_prompt = ''
|
2272
|
-
if self.args.template == "common":
|
2273
|
-
init_prompt = self.multi_round_instruction.prompt(
|
2274
|
-
instruction=query, content=source_content, context=self.args.context
|
2275
|
-
)
|
2276
|
-
elif self.args.template == "auto_implement":
|
2277
|
-
init_prompt = self.auto_implement_function.prompt(
|
2278
|
-
instruction=query, content=source_content
|
2279
|
-
)
|
2280
|
-
|
2281
|
-
with open(self.args.target_file, "w") as file:
|
2282
|
-
file.write(init_prompt)
|
2283
|
-
|
2284
|
-
results = []
|
2285
|
-
conversations = [{"role": "user", "content": init_prompt}]
|
2286
|
-
|
2287
|
-
code_llm = self.llms[0]
|
2288
|
-
v = code_llm.chat_ai(conversations=conversations)
|
2289
|
-
results.append(v.output)
|
2290
|
-
|
2291
|
-
conversations.append({"role": "assistant", "content": v.output})
|
2292
|
-
|
2293
|
-
if "__完成__" in v.output or "/done" in v.output or "__EOF__" in v.output:
|
2294
|
-
return CodeGenerateResult(contents=["\n\n".join(results)], conversations=[conversations])
|
2295
|
-
|
2296
|
-
current_step = 0
|
2297
|
-
|
2298
|
-
while current_step < max_steps:
|
2299
|
-
conversations.append({"role": "user", "content": "继续"})
|
2300
|
-
|
2301
|
-
with open(self.args.target_file, "w") as file:
|
2302
|
-
file.write("继续")
|
2303
|
-
|
2304
|
-
t = code_llm.chat_ai(conversations=conversations)
|
2305
|
-
|
2306
|
-
results.append(t.output)
|
2307
|
-
conversations.append({"role": "assistant", "content": t.output})
|
2308
|
-
current_step += 1
|
2309
|
-
|
2310
|
-
if "__完成__" in t.output or "/done" in t.output or "__EOF__" in t.output:
|
2311
|
-
return CodeGenerateResult(contents=["\n\n".join(results)], conversations=[conversations])
|
2312
|
-
|
2313
|
-
return CodeGenerateResult(contents=["\n\n".join(results)], conversations=[conversations])
|
2314
|
-
|
2315
|
-
|
2316
|
-
class CodeModificationRanker:
|
2317
|
-
def __init__(self, llm: AutoLLM):
|
2318
|
-
self.llm = llm
|
2319
|
-
self.llm.setup_default_model_name(memory["conf"]["current_code_model"])
|
2320
|
-
self.args = args
|
2321
|
-
self.llms = [self.llm]
|
2322
|
-
|
2323
|
-
@prompt()
|
2324
|
-
def _rank_modifications(self, s: CodeGenerateResult) -> str:
|
2325
|
-
"""
|
2326
|
-
对一组代码修改进行质量评估并排序。
|
2327
|
-
|
2328
|
-
下面是修改需求:
|
2329
|
-
|
2330
|
-
<edit_requirement>
|
2331
|
-
{{ s.conversations[0][-2]["content"] }}
|
2332
|
-
</edit_requirement>
|
2333
|
-
|
2334
|
-
下面是相应的代码修改:
|
2335
|
-
{% for content in s.contents %}
|
2336
|
-
<edit_block id="{{ loop.index0 }}">
|
2337
|
-
{{content}}
|
2338
|
-
</edit_block>
|
2339
|
-
{% endfor %}
|
2340
|
-
|
2341
|
-
请输出如下格式的评估结果,只包含 JSON 数据:
|
2342
|
-
|
2343
|
-
```json
|
2344
|
-
{
|
2345
|
-
"rank_result": [id1, id2, id3] // id 为 edit_block 的 id,按质量从高到低排序
|
2346
|
-
}
|
2347
|
-
```
|
2348
|
-
|
2349
|
-
注意:
|
2350
|
-
1. 只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹
|
2351
|
-
"""
|
2352
|
-
|
2353
|
-
def rank_modifications(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
|
2354
|
-
import time
|
2355
|
-
from collections import defaultdict
|
2356
|
-
|
2357
|
-
start_time = time.time()
|
2358
|
-
logger.info(f"开始对 {len(generate_result.contents)} 个候选结果进行排序")
|
2359
|
-
|
2360
|
-
try:
|
2361
|
-
results = []
|
2362
|
-
for llm in self.llms:
|
2363
|
-
v = self._rank_modifications.with_llm(llm).with_return_type(RankResult).run(generate_result)
|
2364
|
-
results.append(v.rank_result)
|
2365
|
-
|
2366
|
-
if not results:
|
2367
|
-
raise Exception("All ranking requests failed")
|
2368
|
-
|
2369
|
-
# 计算每个候选人的分数
|
2370
|
-
candidate_scores = defaultdict(float)
|
2371
|
-
for rank_result in results:
|
2372
|
-
for idx, candidate_id in enumerate(rank_result):
|
2373
|
-
# Score is 1/(position + 1) since position starts from 0
|
2374
|
-
candidate_scores[candidate_id] += 1.0 / (idx + 1)
|
2375
|
-
# 按分数降序对候选人进行排序
|
2376
|
-
sorted_candidates = sorted(candidate_scores.keys(),
|
2377
|
-
key=lambda x: candidate_scores[x],
|
2378
|
-
reverse=True)
|
2379
|
-
|
2380
|
-
elapsed = time.time() - start_time
|
2381
|
-
score_details = ", ".join([f"candidate {i}: {candidate_scores[i]:.2f}" for i in sorted_candidates])
|
2382
|
-
logger.info(
|
2383
|
-
f"排序完成,耗时 {elapsed:.2f} 秒,最佳候选索引: {sorted_candidates[0]},评分详情: {score_details}"
|
2384
|
-
)
|
2385
|
-
|
2386
|
-
rerank_contents = [generate_result.contents[i] for i in sorted_candidates]
|
2387
|
-
rerank_conversations = [generate_result.conversations[i] for i in sorted_candidates]
|
2388
|
-
|
2389
|
-
return CodeGenerateResult(contents=rerank_contents, conversations=rerank_conversations)
|
2390
|
-
|
2391
|
-
except Exception as e:
|
2392
|
-
logger.error(f"排序过程失败: {str(e)}")
|
2393
|
-
logger.debug(traceback.format_exc())
|
2394
|
-
elapsed = time.time() - start_time
|
2395
|
-
logger.warning(f"排序失败,耗时 {elapsed:.2f} 秒,将使用原始顺序")
|
2396
|
-
return generate_result
|
2397
|
-
|
2398
|
-
|
2399
|
-
class TextSimilarity:
|
2400
|
-
"""
|
2401
|
-
找到 text_b 中与 text_a 最相似的部分(滑动窗口)
|
2402
|
-
返回相似度分数和最相似的文本片段
|
2403
|
-
"""
|
2404
|
-
|
2405
|
-
def __init__(self, text_a, text_b):
|
2406
|
-
self.text_a = text_a
|
2407
|
-
self.text_b = text_b
|
2408
|
-
self.lines_a = self._split_into_lines(text_a)
|
2409
|
-
self.lines_b = self._split_into_lines(text_b)
|
2410
|
-
self.m = len(self.lines_a)
|
2411
|
-
self.n = len(self.lines_b)
|
2412
|
-
|
2413
|
-
@staticmethod
|
2414
|
-
def _split_into_lines(text):
|
2415
|
-
return text.splitlines()
|
2416
|
-
|
2417
|
-
@staticmethod
|
2418
|
-
def _levenshtein_ratio(s1, s2):
|
2419
|
-
return SequenceMatcher(None, s1, s2).ratio()
|
2420
|
-
|
2421
|
-
def get_best_matching_window(self):
|
2422
|
-
best_similarity = 0
|
2423
|
-
best_window = []
|
2424
|
-
|
2425
|
-
for i in range(self.n - self.m + 1): # 滑动窗口
|
2426
|
-
window_b = self.lines_b[i:i + self.m]
|
2427
|
-
similarity = self._levenshtein_ratio("\n".join(self.lines_a), "\n".join(window_b))
|
2428
|
-
|
2429
|
-
if similarity > best_similarity:
|
2430
|
-
best_similarity = similarity
|
2431
|
-
best_window = window_b
|
2432
|
-
|
2433
|
-
return best_similarity, "\n".join(best_window)
|
2434
|
-
|
2435
|
-
|
2436
|
-
class CodeAutoMergeEditBlock:
|
2437
|
-
def __init__(self, llm: AutoLLM, fence_0: str = "```", fence_1: str = "```"):
|
2438
|
-
self.llm = llm
|
2439
|
-
self.llm.setup_default_model_name(memory["conf"]["current_code_model"])
|
2440
|
-
self.args = args
|
2441
|
-
self.fence_0 = fence_0
|
2442
|
-
self.fence_1 = fence_1
|
2443
|
-
|
2444
|
-
@staticmethod
|
2445
|
-
def run_pylint(code: str) -> tuple[bool, str]:
|
2446
|
-
"""
|
2447
|
-
--disable=all 禁用所有 Pylint 的检查规则
|
2448
|
-
--enable=E0001,W0311,W0312 启用指定的 Pylint 检查规则,
|
2449
|
-
E0001:语法错误(Syntax Error),
|
2450
|
-
W0311:代码缩进使用了 Tab 而不是空格(Bad indentation)
|
2451
|
-
W0312:代码缩进不一致(Mixed indentation)
|
2452
|
-
:param code:
|
2453
|
-
:return:
|
2454
|
-
"""
|
2455
|
-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_file:
|
2456
|
-
temp_file.write(code)
|
2457
|
-
temp_file_path = temp_file.name
|
2458
|
-
|
2459
|
-
try:
|
2460
|
-
result = subprocess.run(
|
2461
|
-
["pylint", "--disable=all", "--enable=E0001,W0311,W0312", temp_file_path,],
|
2462
|
-
capture_output=True,
|
2463
|
-
text=True,
|
2464
|
-
check=False,
|
2465
|
-
)
|
2466
|
-
os.unlink(temp_file_path)
|
2467
|
-
if result.returncode != 0:
|
2468
|
-
error_message = result.stdout.strip() or result.stderr.strip()
|
2469
|
-
logger.warning(f"Pylint 检查代码失败: {error_message}")
|
2470
|
-
return False, error_message
|
2471
|
-
return True, ""
|
2472
|
-
except subprocess.CalledProcessError as e:
|
2473
|
-
error_message = f"运行 Pylint 时发生错误: {str(e)}"
|
2474
|
-
logger.error(error_message)
|
2475
|
-
os.unlink(temp_file_path)
|
2476
|
-
return False, error_message
|
2477
|
-
|
2478
|
-
def parse_whole_text(self, text: str) -> List[PathAndCode]:
|
2479
|
-
"""
|
2480
|
-
从文本中抽取如下格式代码(two_line_mode):
|
2481
|
-
|
2482
|
-
```python
|
2483
|
-
##File: /project/path/src/autocoder/index/index.py
|
2484
|
-
<<<<<<< SEARCH
|
2485
|
-
=======
|
2486
|
-
>>>>>>> REPLACE
|
2487
|
-
```
|
2488
|
-
|
2489
|
-
或者 (one_line_mode)
|
2490
|
-
|
2491
|
-
```python:/project/path/src/autocoder/index/index.py
|
2492
|
-
<<<<<<< SEARCH
|
2493
|
-
=======
|
2494
|
-
>>>>>>> REPLACE
|
2495
|
-
```
|
2496
|
-
"""
|
2497
|
-
HEAD = "<<<<<<< SEARCH"
|
2498
|
-
DIVIDER = "======="
|
2499
|
-
UPDATED = ">>>>>>> REPLACE"
|
2500
|
-
lines = text.split("\n")
|
2501
|
-
lines_len = len(lines)
|
2502
|
-
start_marker_count = 0
|
2503
|
-
block = []
|
2504
|
-
path_and_code_list = []
|
2505
|
-
# two_line_mode or one_line_mode
|
2506
|
-
current_editblock_mode = "two_line_mode"
|
2507
|
-
current_editblock_path = None
|
2508
|
-
|
2509
|
-
def guard(_index):
|
2510
|
-
return _index + 1 < lines_len
|
2511
|
-
|
2512
|
-
def start_marker(_line, _index):
|
2513
|
-
nonlocal current_editblock_mode
|
2514
|
-
nonlocal current_editblock_path
|
2515
|
-
if _line.startswith(self.fence_0) and guard(_index) and ":" in _line and lines[_index + 1].startswith(HEAD):
|
2516
|
-
current_editblock_mode = "one_line_mode"
|
2517
|
-
current_editblock_path = _line.split(":", 1)[1].strip()
|
2518
|
-
return True
|
2519
|
-
if _line.startswith(self.fence_0) and guard(_index) and lines[_index + 1].startswith("##File:"):
|
2520
|
-
current_editblock_mode = "two_line_mode"
|
2521
|
-
current_editblock_path = None
|
2522
|
-
return True
|
2523
|
-
return False
|
2524
|
-
|
2525
|
-
def end_marker(_line, _index):
|
2526
|
-
return _line.startswith(self.fence_1) and UPDATED in lines[_index - 1]
|
2527
|
-
|
2528
|
-
for index, line in enumerate(lines):
|
2529
|
-
if start_marker(line, index) and start_marker_count == 0:
|
2530
|
-
start_marker_count += 1
|
2531
|
-
elif end_marker(line, index) and start_marker_count == 1:
|
2532
|
-
start_marker_count -= 1
|
2533
|
-
if block:
|
2534
|
-
if current_editblock_mode == "two_line_mode":
|
2535
|
-
path = block[0].split(":", 1)[1].strip()
|
2536
|
-
content = "\n".join(block[1:])
|
2537
|
-
else:
|
2538
|
-
path = current_editblock_path
|
2539
|
-
content = "\n".join(block)
|
2540
|
-
block = []
|
2541
|
-
path_and_code_list.append(PathAndCode(path=path, content=content))
|
2542
|
-
elif start_marker_count > 0:
|
2543
|
-
block.append(line)
|
2544
|
-
|
2545
|
-
return path_and_code_list
|
2546
|
-
|
2547
|
-
def get_edits(self, content: str):
|
2548
|
-
edits = self.parse_whole_text(content)
|
2549
|
-
HEAD = "<<<<<<< SEARCH"
|
2550
|
-
DIVIDER = "======="
|
2551
|
-
UPDATED = ">>>>>>> REPLACE"
|
2552
|
-
result = []
|
2553
|
-
for edit in edits:
|
2554
|
-
heads = []
|
2555
|
-
updates = []
|
2556
|
-
c = edit.content
|
2557
|
-
in_head = False
|
2558
|
-
in_updated = False
|
2559
|
-
for line in c.splitlines():
|
2560
|
-
if line.strip() == HEAD:
|
2561
|
-
in_head = True
|
2562
|
-
continue
|
2563
|
-
if line.strip() == DIVIDER:
|
2564
|
-
in_head = False
|
2565
|
-
in_updated = True
|
2566
|
-
continue
|
2567
|
-
if line.strip() == UPDATED:
|
2568
|
-
in_head = False
|
2569
|
-
in_updated = False
|
2570
|
-
continue
|
2571
|
-
if in_head:
|
2572
|
-
heads.append(line)
|
2573
|
-
if in_updated:
|
2574
|
-
updates.append(line)
|
2575
|
-
result.append((edit.path, "\n".join(heads), "\n".join(updates)))
|
2576
|
-
return result
|
2577
|
-
|
2578
|
-
@prompt()
|
2579
|
-
def git_require_msg(self, source_dir: str, error: str) -> str:
|
2580
|
-
"""
|
2581
|
-
auto_merge only works for git repositories.
|
2582
|
-
|
2583
|
-
Try to use git init in the source directory.
|
2584
|
-
|
2585
|
-
```shell
|
2586
|
-
cd {{ source_dir }}
|
2587
|
-
git init .
|
2588
|
-
```
|
2589
|
-
|
2590
|
-
Then try to run auto-coder again.
|
2591
|
-
Error: {{ error }}
|
2592
|
-
"""
|
2593
|
-
|
2594
|
-
def _merge_code_without_effect(self, content: str) -> MergeCodeWithoutEffect:
|
2595
|
-
"""
|
2596
|
-
合并代码时不会产生任何副作用,例如 Git 操作、代码检查或文件写入。
|
2597
|
-
返回一个元组,包含:
|
2598
|
-
- 成功合并的代码块的列表,每个元素是一个 (file_path, new_content) 元组,
|
2599
|
-
其中 file_path 是文件路径,new_content 是合并后的新内容。
|
2600
|
-
- 合并失败的代码块的列表,每个元素是一个 (file_path, head, update) 元组,
|
2601
|
-
其中:file_path 是文件路径,head 是原始内容,update 是尝试合并的内容。
|
2602
|
-
"""
|
2603
|
-
codes = self.get_edits(content)
|
2604
|
-
file_content_mapping = {}
|
2605
|
-
failed_blocks = []
|
2606
|
-
|
2607
|
-
for block in codes:
|
2608
|
-
file_path, head, update = block
|
2609
|
-
if not os.path.exists(file_path):
|
2610
|
-
file_content_mapping[file_path] = update
|
2611
|
-
else:
|
2612
|
-
if file_path not in file_content_mapping:
|
2613
|
-
with open(file_path, "r") as f:
|
2614
|
-
temp = f.read()
|
2615
|
-
file_content_mapping[file_path] = temp
|
2616
|
-
existing_content = file_content_mapping[file_path]
|
2617
|
-
|
2618
|
-
# First try exact match
|
2619
|
-
new_content = (
|
2620
|
-
existing_content.replace(head, update, 1)
|
2621
|
-
if head
|
2622
|
-
else existing_content + "\n" + update
|
2623
|
-
)
|
2624
|
-
|
2625
|
-
# If exact match fails, try similarity match
|
2626
|
-
if new_content == existing_content and head:
|
2627
|
-
similarity, best_window = TextSimilarity(
|
2628
|
-
head, existing_content
|
2629
|
-
).get_best_matching_window()
|
2630
|
-
if similarity > self.args.editblock_similarity:
|
2631
|
-
new_content = existing_content.replace(
|
2632
|
-
best_window, update, 1
|
2633
|
-
)
|
2634
|
-
|
2635
|
-
if new_content != existing_content:
|
2636
|
-
file_content_mapping[file_path] = new_content
|
2637
|
-
else:
|
2638
|
-
failed_blocks.append((file_path, head, update))
|
2639
|
-
return MergeCodeWithoutEffect(
|
2640
|
-
success_blocks=[(path, content) for path, content in file_content_mapping.items()],
|
2641
|
-
failed_blocks=failed_blocks
|
2642
|
-
)
|
2643
|
-
|
2644
|
-
def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
|
2645
|
-
""" 选择最佳代码 """
|
2646
|
-
if len(generate_result.contents) == 1: # 仅一份代码立即返回
|
2647
|
-
logger.info("仅有一个候选结果,跳过排序")
|
2648
|
-
return generate_result
|
2649
|
-
|
2650
|
-
ranker = CodeModificationRanker(self.llm)
|
2651
|
-
ranked_result = ranker.rank_modifications(generate_result)
|
2652
|
-
# 过滤掉包含失败块的内容
|
2653
|
-
for content, conversations in zip(ranked_result.contents, ranked_result.conversations):
|
2654
|
-
merge_result = self._merge_code_without_effect(content)
|
2655
|
-
if not merge_result.failed_blocks:
|
2656
|
-
return CodeGenerateResult(contents=[content], conversations=[conversations])
|
2657
|
-
# 如果所有内容都包含失败块,则返回第一个
|
2658
|
-
return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])
|
2659
|
-
|
2660
|
-
def _merge_code(self, content: str, force_skip_git: bool = False):
|
2661
|
-
file_content = open(self.args.file).read()
|
2662
|
-
md5 = hashlib.md5(file_content.encode("utf-8")).hexdigest()
|
2663
|
-
file_name = os.path.basename(self.args.file)
|
2664
|
-
|
2665
|
-
codes = self.get_edits(content)
|
2666
|
-
changes_to_make = []
|
2667
|
-
changes_made = False
|
2668
|
-
unmerged_blocks = []
|
2669
|
-
merged_blocks = []
|
2670
|
-
|
2671
|
-
# First, check if there are any changes to be made
|
2672
|
-
file_content_mapping = {}
|
2673
|
-
for block in codes:
|
2674
|
-
file_path, head, update = block
|
2675
|
-
if not os.path.exists(file_path):
|
2676
|
-
changes_to_make.append((file_path, None, update))
|
2677
|
-
file_content_mapping[file_path] = update
|
2678
|
-
merged_blocks.append((file_path, "", update, 1))
|
2679
|
-
changes_made = True
|
2680
|
-
else:
|
2681
|
-
if file_path not in file_content_mapping:
|
2682
|
-
with open(file_path, "r") as f:
|
2683
|
-
temp = f.read()
|
2684
|
-
file_content_mapping[file_path] = temp
|
2685
|
-
existing_content = file_content_mapping[file_path]
|
2686
|
-
new_content = (
|
2687
|
-
existing_content.replace(head, update, 1)
|
2688
|
-
if head
|
2689
|
-
else existing_content + "\n" + update
|
2690
|
-
)
|
2691
|
-
if new_content != existing_content:
|
2692
|
-
changes_to_make.append(
|
2693
|
-
(file_path, existing_content, new_content))
|
2694
|
-
file_content_mapping[file_path] = new_content
|
2695
|
-
merged_blocks.append((file_path, head, update, 1))
|
2696
|
-
changes_made = True
|
2697
|
-
else:
|
2698
|
-
# If the SEARCH BLOCK is not found exactly, then try to use
|
2699
|
-
# the similarity ratio to find the best matching block
|
2700
|
-
similarity, best_window = TextSimilarity(head, existing_content).get_best_matching_window()
|
2701
|
-
if similarity > self.args.editblock_similarity: # 相似性比较
|
2702
|
-
new_content = existing_content.replace(
|
2703
|
-
best_window, update, 1)
|
2704
|
-
if new_content != existing_content:
|
2705
|
-
changes_to_make.append(
|
2706
|
-
(file_path, existing_content, new_content)
|
2707
|
-
)
|
2708
|
-
file_content_mapping[file_path] = new_content
|
2709
|
-
merged_blocks.append(
|
2710
|
-
(file_path, head, update, similarity))
|
2711
|
-
changes_made = True
|
2712
|
-
else:
|
2713
|
-
unmerged_blocks.append((file_path, head, update, similarity))
|
2714
|
-
|
2715
|
-
if unmerged_blocks:
|
2716
|
-
if self.args.request_id and not self.args.skip_events:
|
2717
|
-
# collect unmerged blocks
|
2718
|
-
event_data = []
|
2719
|
-
for file_path, head, update, similarity in unmerged_blocks:
|
2720
|
-
event_data.append(
|
2721
|
-
{
|
2722
|
-
"file_path": file_path,
|
2723
|
-
"head": head,
|
2724
|
-
"update": update,
|
2725
|
-
"similarity": similarity,
|
2726
|
-
}
|
2727
|
-
)
|
2728
|
-
return
|
2729
|
-
logger.warning(f"发现 {len(unmerged_blocks)} 个未合并的代码块,更改将不会应用,请手动检查这些代码块后重试。")
|
2730
|
-
self._print_unmerged_blocks(unmerged_blocks)
|
2731
|
-
return
|
2732
|
-
|
2733
|
-
# lint check
|
2734
|
-
for file_path, new_content in file_content_mapping.items():
|
2735
|
-
if file_path.endswith(".py"):
|
2736
|
-
pylint_passed, error_message = self.run_pylint(new_content)
|
2737
|
-
if not pylint_passed:
|
2738
|
-
logger.warning(f"代码文件 {file_path} 的 Pylint 检查未通过,本次更改未应用。错误信息: {error_message}")
|
2739
|
-
|
2740
|
-
if changes_made and not force_skip_git and not self.args.skip_commit:
|
2741
|
-
try:
|
2742
|
-
commit_changes(self.args.source_dir, f"auto_coder_pre_{file_name}_{md5}")
|
2743
|
-
except Exception as e:
|
2744
|
-
logger.error(
|
2745
|
-
self.git_require_msg(
|
2746
|
-
source_dir=self.args.source_dir, error=str(e))
|
2747
|
-
)
|
2748
|
-
return
|
2749
|
-
# Now, apply the changes
|
2750
|
-
for file_path, new_content in file_content_mapping.items():
|
2751
|
-
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
2752
|
-
with open(file_path, "w") as f:
|
2753
|
-
f.write(new_content)
|
2754
|
-
|
2755
|
-
if self.args.request_id and not self.args.skip_events:
|
2756
|
-
# collect modified files
|
2757
|
-
event_data = []
|
2758
|
-
for code in merged_blocks:
|
2759
|
-
file_path, head, update, similarity = code
|
2760
|
-
event_data.append(
|
2761
|
-
{
|
2762
|
-
"file_path": file_path,
|
2763
|
-
"head": head,
|
2764
|
-
"update": update,
|
2765
|
-
"similarity": similarity,
|
2766
|
-
}
|
2767
|
-
)
|
2768
|
-
|
2769
|
-
if changes_made:
|
2770
|
-
if not force_skip_git and not self.args.skip_commit:
|
2771
|
-
try:
|
2772
|
-
commit_result = commit_changes(self.args.source_dir, f"auto_coder_{file_name}_{md5}")
|
2773
|
-
git_print_commit_info(commit_result=commit_result)
|
2774
|
-
except Exception as e:
|
2775
|
-
logger.error(
|
2776
|
-
self.git_require_msg(
|
2777
|
-
source_dir=self.args.source_dir, error=str(e)
|
2778
|
-
)
|
2779
|
-
)
|
2780
|
-
logger.info(
|
2781
|
-
f"已在 {len(file_content_mapping.keys())} 个文件中合并更改,"
|
2782
|
-
f"完成 {len(changes_to_make)}/{len(codes)} 个代码块。"
|
2783
|
-
)
|
2784
|
-
else:
|
2785
|
-
logger.warning("未对任何文件进行更改。")
|
2786
|
-
|
2787
|
-
def merge_code(self, generate_result: CodeGenerateResult, force_skip_git: bool = False):
|
2788
|
-
result = self.choose_best_choice(generate_result)
|
2789
|
-
self._merge_code(result.contents[0], force_skip_git)
|
2790
|
-
return result
|
2791
|
-
|
2792
|
-
@staticmethod
|
2793
|
-
def _print_unmerged_blocks(unmerged_blocks: List[tuple]):
|
2794
|
-
console.print(f"\n[bold red]未合并的代码块:[/bold red]")
|
2795
|
-
for file_path, head, update, similarity in unmerged_blocks:
|
2796
|
-
console.print(f"\n[bold blue]文件:[/bold blue] {file_path}")
|
2797
|
-
console.print(
|
2798
|
-
f"\n[bold green]搜索代码块(相似度:{similarity}):[/bold green]")
|
2799
|
-
syntax = Syntax(head, "python", theme="monokai", line_numbers=True)
|
2800
|
-
console.print(Panel(syntax, expand=False))
|
2801
|
-
console.print("\n[bold yellow]替换代码块:[/bold yellow]")
|
2802
|
-
syntax = Syntax(update, "python", theme="monokai",
|
2803
|
-
line_numbers=True)
|
2804
|
-
console.print(Panel(syntax, expand=False))
|
2805
|
-
console.print(f"\n[bold red]未合并的代码块总数: {len(unmerged_blocks)}[/bold red]")
|
2806
|
-
|
2807
|
-
|
2808
|
-
class BaseAction:
|
2809
|
-
@staticmethod
|
2810
|
-
def _get_content_length(content: str) -> int:
|
2811
|
-
return len(content)
|
2812
|
-
|
2813
|
-
|
2814
|
-
class ActionPyProject(BaseAction):
|
2815
|
-
def __init__(self, llm: Optional[AutoLLM] = None) -> None:
|
2816
|
-
self.args = args
|
2817
|
-
self.llm = llm
|
2818
|
-
self.pp = None
|
2819
|
-
|
2820
|
-
def run(self):
|
2821
|
-
if self.args.project_type != "py":
|
2822
|
-
return False
|
2823
|
-
pp = PyProject(llm=self.llm, args=args)
|
2824
|
-
self.pp = pp
|
2825
|
-
pp.run()
|
2826
|
-
source_code = pp.output()
|
2827
|
-
if self.llm:
|
2828
|
-
source_code = build_index_and_filter_files(llm=self.llm, sources=pp.sources)
|
2829
|
-
self.process_content(source_code)
|
2830
|
-
return True
|
2831
|
-
|
2832
|
-
def process_content(self, content: str):
|
2833
|
-
# args = self.args
|
2834
|
-
if self.args.execute and self.llm:
|
2835
|
-
content_length = self._get_content_length(content)
|
2836
|
-
if content_length > self.args.model_max_input_length:
|
2837
|
-
logger.warning(
|
2838
|
-
f"发送给模型的内容长度为 {content_length} 个 token(可能收集了过多文件),"
|
2839
|
-
f"已超过最大输入长度限制 {self.args.model_max_input_length}。"
|
2840
|
-
)
|
2841
|
-
|
2842
|
-
if args.execute:
|
2843
|
-
logger.info("正在自动生成代码...")
|
2844
|
-
start_time = time.time()
|
2845
|
-
# diff, strict_diff, editblock 是代码自动生成或合并的不同策略, 通常用于处理代码的变更或生成
|
2846
|
-
# diff 模式,基于差异生成代码,生成最小的变更集,适用于局部优化,代码重构
|
2847
|
-
# strict_diff 模式,严格验证差异,确保生成的代码符合规则,适用于代码审查,自动化测试
|
2848
|
-
# editblock 模式,基于编辑块生成代码,支持较大范围的修改,适用于代码重构,功能扩展
|
2849
|
-
if args.auto_merge == "editblock":
|
2850
|
-
generate = CodeAutoGenerateEditBlock(llm=self.llm, action=self)
|
2851
|
-
else:
|
2852
|
-
generate = None
|
2853
|
-
|
2854
|
-
if self.args.enable_multi_round_generate:
|
2855
|
-
generate_result = generate.multi_round_run(query=args.query, source_content=content)
|
2856
|
-
else:
|
2857
|
-
generate_result = generate.single_round_run(query=args.query, source_content=content)
|
2858
|
-
logger.info(f"代码生成完成,耗时 {time.time() - start_time:.2f} 秒")
|
2859
|
-
|
2860
|
-
if args.auto_merge:
|
2861
|
-
logger.info("正在自动合并代码...")
|
2862
|
-
if args.auto_merge == "editblock":
|
2863
|
-
code_merge = CodeAutoMergeEditBlock(llm=self.llm)
|
2864
|
-
merge_result = code_merge.merge_code(generate_result=generate_result)
|
2865
|
-
else:
|
2866
|
-
merge_result = None
|
2867
|
-
|
2868
|
-
content = merge_result.contents[0]
|
2869
|
-
else:
|
2870
|
-
content = generate_result.contents[0]
|
2871
|
-
with open(args.target_file, "w") as file:
|
2872
|
-
file.write(content)
|
2873
|
-
|
2874
|
-
|
2875
|
-
class ActionSuffixProject(BaseAction):
|
2876
|
-
def __init__(self, llm: Optional[AutoLLM] = None) -> None:
|
2877
|
-
self.args = args
|
2878
|
-
self.llm = llm
|
2879
|
-
self.pp = None
|
2880
|
-
|
2881
|
-
def run(self):
|
2882
|
-
pp = SuffixProject(llm=self.llm, args=args)
|
2883
|
-
self.pp = pp
|
2884
|
-
pp.run()
|
2885
|
-
source_code = pp.output()
|
2886
|
-
if self.llm:
|
2887
|
-
source_code = build_index_and_filter_files(llm=self.llm, sources=pp.sources)
|
2888
|
-
self.process_content(source_code)
|
2889
|
-
|
2890
|
-
def process_content(self, content: str):
|
2891
|
-
if self.args.execute and self.llm:
|
2892
|
-
content_length = self._get_content_length(content)
|
2893
|
-
if content_length > self.args.model_max_input_length:
|
2894
|
-
logger.warning(
|
2895
|
-
f"发送给模型的内容长度为 {content_length} 个 token(可能收集了过多文件),"
|
2896
|
-
f"已超过最大输入长度限制 {self.args.model_max_input_length}。"
|
2897
|
-
)
|
2898
|
-
|
2899
|
-
if args.execute:
|
2900
|
-
logger.info("正在自动生成代码...")
|
2901
|
-
start_time = time.time()
|
2902
|
-
# diff, strict_diff, editblock 是代码自动生成或合并的不同策略, 通常用于处理代码的变更或生成
|
2903
|
-
# diff 模式,基于差异生成代码,生成最小的变更集,适用于局部优化,代码重构
|
2904
|
-
# strict_diff 模式,严格验证差异,确保生成的代码符合规则,适用于代码审查,自动化测试
|
2905
|
-
# editblock 模式,基于编辑块生成代码,支持较大范围的修改,适用于代码重构,功能扩展
|
2906
|
-
if args.auto_merge == "editblock":
|
2907
|
-
generate = CodeAutoGenerateEditBlock(llm=self.llm, action=self)
|
2908
|
-
else:
|
2909
|
-
generate = None
|
2910
|
-
|
2911
|
-
if self.args.enable_multi_round_generate:
|
2912
|
-
generate_result = generate.multi_round_run(query=args.query, source_content=content)
|
2913
|
-
else:
|
2914
|
-
generate_result = generate.single_round_run(query=args.query, source_content=content)
|
2915
|
-
logger.info(f"代码生成完成,耗时 {time.time() - start_time:.2f} 秒")
|
2916
|
-
|
2917
|
-
if args.auto_merge:
|
2918
|
-
logger.info("正在自动合并代码...")
|
2919
|
-
if args.auto_merge == "editblock":
|
2920
|
-
code_merge = CodeAutoMergeEditBlock(llm=self.llm)
|
2921
|
-
merge_result = code_merge.merge_code(generate_result=generate_result)
|
2922
|
-
else:
|
2923
|
-
merge_result = None
|
2924
|
-
|
2925
|
-
content = merge_result.contents[0]
|
2926
|
-
else:
|
2927
|
-
content = generate_result.contents[0]
|
2928
|
-
with open(args.target_file, "w") as file:
|
2929
|
-
file.write(content)
|
2930
|
-
|
2931
|
-
|
2932
|
-
class Dispacher:
|
2933
|
-
def __init__(self, llm: Optional[AutoLLM] = None):
|
2934
|
-
self.args = args
|
2935
|
-
self.llm = llm
|
2936
|
-
|
2937
|
-
def dispach(self):
|
2938
|
-
actions = [
|
2939
|
-
ActionPyProject(llm=self.llm),
|
2940
|
-
ActionSuffixProject(llm=self.llm)
|
2941
|
-
]
|
2942
|
-
for action in actions:
|
2943
|
-
if action.run():
|
2944
|
-
return
|
2945
|
-
|
2946
|
-
|
2947
1253
|
def prepare_chat_yaml():
|
2948
1254
|
# auto_coder_main(["next", "chat_action"]) 准备聊天 yaml 文件
|
2949
1255
|
actions_dir = os.path.join(args.source_dir, "actions")
|
@@ -3003,12 +1309,11 @@ def coding(query: str, llm: AutoLLM):
|
|
3003
1309
|
if latest_yaml_file:
|
3004
1310
|
yaml_config = {
|
3005
1311
|
"include_file": ["./base/base.yml"],
|
3006
|
-
"auto_merge": conf.get("auto_merge", "editblock"),
|
3007
|
-
"human_as_model": conf.get("human_as_model", "false") == "true",
|
3008
1312
|
"skip_build_index": conf.get("skip_build_index", "true") == "true",
|
3009
1313
|
"skip_confirm": conf.get("skip_confirm", "true") == "true",
|
3010
|
-
"
|
3011
|
-
"
|
1314
|
+
"chat_model": conf.get("chat_model", ""),
|
1315
|
+
"code_model": conf.get("code_model", ""),
|
1316
|
+
"auto_merge": conf.get("auto_merge", "editblock"),
|
3012
1317
|
"context": ""
|
3013
1318
|
}
|
3014
1319
|
|
@@ -3071,7 +1376,7 @@ def coding(query: str, llm: AutoLLM):
|
|
3071
1376
|
f.write(yaml_content)
|
3072
1377
|
convert_yaml_to_config(execute_file)
|
3073
1378
|
|
3074
|
-
dispacher = Dispacher(llm)
|
1379
|
+
dispacher = Dispacher(args=args, llm=llm)
|
3075
1380
|
dispacher.dispach()
|
3076
1381
|
else:
|
3077
1382
|
logger.warning("创建新的 YAML 文件失败。")
|
@@ -3133,28 +1438,26 @@ def commit_info(query: str, llm: AutoLLM):
|
|
3133
1438
|
prepare_chat_yaml() # 复制上一个序号的 yaml 文件, 生成一个新的聊天 yaml 文件
|
3134
1439
|
|
3135
1440
|
latest_yaml_file = get_last_yaml_file(os.path.join(args.source_dir, "actions"))
|
3136
|
-
|
3137
|
-
conf = memory.get("conf", {})
|
3138
|
-
current_files = memory["current_files"]["files"]
|
3139
1441
|
execute_file = None
|
3140
1442
|
|
3141
1443
|
if latest_yaml_file:
|
3142
1444
|
try:
|
3143
1445
|
execute_file = os.path.join(args.source_dir, "actions", latest_yaml_file)
|
1446
|
+
conf = memory.get("conf", {})
|
3144
1447
|
yaml_config = {
|
3145
1448
|
"include_file": ["./base/base.yml"],
|
3146
|
-
"auto_merge": conf.get("auto_merge", "editblock"),
|
3147
|
-
"human_as_model": conf.get("human_as_model", "false") == "true",
|
3148
1449
|
"skip_build_index": conf.get("skip_build_index", "true") == "true",
|
3149
1450
|
"skip_confirm": conf.get("skip_confirm", "true") == "true",
|
3150
|
-
"
|
3151
|
-
"
|
1451
|
+
"chat_model": conf.get("chat_model", ""),
|
1452
|
+
"code_model": conf.get("code_model", ""),
|
1453
|
+
"auto_merge": conf.get("auto_merge", "editblock")
|
3152
1454
|
}
|
3153
1455
|
for key, value in conf.items():
|
3154
1456
|
converted_value = convert_config_value(key, value)
|
3155
1457
|
if converted_value is not None:
|
3156
1458
|
yaml_config[key] = converted_value
|
3157
1459
|
|
1460
|
+
current_files = memory["current_files"]["files"]
|
3158
1461
|
yaml_config["urls"] = current_files
|
3159
1462
|
|
3160
1463
|
# 临时保存yaml文件,然后读取yaml文件,更新args
|
@@ -3169,7 +1472,7 @@ def commit_info(query: str, llm: AutoLLM):
|
|
3169
1472
|
|
3170
1473
|
# commit_message = ""
|
3171
1474
|
commit_llm = llm
|
3172
|
-
commit_llm.setup_default_model_name(
|
1475
|
+
commit_llm.setup_default_model_name(args.chat_model)
|
3173
1476
|
console.print(f"Commit 信息生成中...", style="yellow")
|
3174
1477
|
|
3175
1478
|
try:
|
@@ -3239,20 +1542,7 @@ def _generate_shell_script(user_input: str) -> str:
|
|
3239
1542
|
|
3240
1543
|
|
3241
1544
|
def generate_shell_command(input_text: str, llm: AutoLLM) -> str | None:
|
3242
|
-
|
3243
|
-
yaml_config = {
|
3244
|
-
"include_file": ["./base/base.yml"],
|
3245
|
-
}
|
3246
|
-
if "model" in conf:
|
3247
|
-
yaml_config["model"] = conf["model"]
|
3248
|
-
yaml_config["query"] = input_text
|
3249
|
-
|
3250
|
-
yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
|
3251
|
-
|
3252
|
-
execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
|
3253
|
-
|
3254
|
-
with open(os.path.join(execute_file), "w") as f:
|
3255
|
-
f.write(yaml_content)
|
1545
|
+
update_config_to_args(query=input_text, delete_execute_file=True)
|
3256
1546
|
|
3257
1547
|
try:
|
3258
1548
|
console.print(
|
@@ -3262,7 +1552,7 @@ def generate_shell_command(input_text: str, llm: AutoLLM) -> str | None:
|
|
3262
1552
|
border_style="green",
|
3263
1553
|
)
|
3264
1554
|
)
|
3265
|
-
llm.setup_default_model_name(
|
1555
|
+
llm.setup_default_model_name(args.code_model)
|
3266
1556
|
result = _generate_shell_script.with_llm(llm).run(user_input=input_text)
|
3267
1557
|
shell_script = extract_code(result.output)[0][1]
|
3268
1558
|
console.print(
|
@@ -3274,7 +1564,8 @@ def generate_shell_command(input_text: str, llm: AutoLLM) -> str | None:
|
|
3274
1564
|
)
|
3275
1565
|
return shell_script
|
3276
1566
|
finally:
|
3277
|
-
|
1567
|
+
pass
|
1568
|
+
# os.remove(execute_file)
|
3278
1569
|
|
3279
1570
|
|
3280
1571
|
def execute_shell_command(command: str):
|
@@ -3884,10 +2175,10 @@ def manage_models(models_args, models_data, llm: AutoLLM):
|
|
3884
2175
|
logger.info(f"正在卸载 {remove_model_name} 模型")
|
3885
2176
|
if llm.get_sub_client(remove_model_name):
|
3886
2177
|
llm.remove_sub_client(remove_model_name)
|
3887
|
-
if remove_model_name == memory["conf"]["
|
3888
|
-
logger.warning(f"当前首选 Chat 模型 {remove_model_name} 已被删除, 请立即 /conf
|
3889
|
-
if remove_model_name == memory["conf"]["
|
3890
|
-
logger.warning(f"当前首选 Code 模型 {remove_model_name} 已被删除, 请立即 /conf
|
2178
|
+
if remove_model_name == memory["conf"]["chat_model"]:
|
2179
|
+
logger.warning(f"当前首选 Chat 模型 {remove_model_name} 已被删除, 请立即 /conf chat_model: 调整 !!!")
|
2180
|
+
if remove_model_name == memory["conf"]["code_model"]:
|
2181
|
+
logger.warning(f"当前首选 Code 模型 {remove_model_name} 已被删除, 请立即 /conf code_model: 调整 !!!")
|
3891
2182
|
|
3892
2183
|
|
3893
2184
|
def configure_project_model():
|
@@ -3958,58 +2249,71 @@ def configure_project_model():
|
|
3958
2249
|
)
|
3959
2250
|
|
3960
2251
|
|
3961
|
-
def new_project(query, llm):
|
3962
|
-
|
3963
|
-
|
3964
|
-
|
3965
|
-
|
3966
|
-
|
3967
|
-
|
3968
|
-
|
3969
|
-
|
3970
|
-
|
3971
|
-
|
3972
|
-
|
3973
|
-
|
3974
|
-
|
3975
|
-
|
3976
|
-
|
3977
|
-
|
3978
|
-
|
3979
|
-
|
3980
|
-
|
3981
|
-
|
3982
|
-
|
3983
|
-
|
3984
|
-
|
3985
|
-
|
3986
|
-
|
3987
|
-
|
3988
|
-
|
3989
|
-
|
3990
|
-
|
3991
|
-
|
3992
|
-
|
3993
|
-
|
3994
|
-
|
3995
|
-
|
3996
|
-
|
3997
|
-
|
3998
|
-
|
3999
|
-
|
4000
|
-
|
4001
|
-
|
4002
|
-
|
4003
|
-
|
4004
|
-
|
4005
|
-
|
4006
|
-
|
4007
|
-
|
4008
|
-
|
4009
|
-
|
4010
|
-
|
4011
|
-
|
4012
|
-
|
2252
|
+
# def new_project(query, llm):
|
2253
|
+
# console.print(f"正在基于你的需求 {query} 构建项目 ...", style="bold green")
|
2254
|
+
# env_info = detect_env()
|
2255
|
+
# project = BuildNewProject(args=args, llm=llm,
|
2256
|
+
# chat_model=memory["conf"]["chat_model"],
|
2257
|
+
# code_model=memory["conf"]["code_model"])
|
2258
|
+
#
|
2259
|
+
# console.print(f"正在完善项目需求 ...", style="bold green")
|
2260
|
+
#
|
2261
|
+
# information = project.build_project_information(query, env_info, args.project_type)
|
2262
|
+
# if not information:
|
2263
|
+
# raise Exception(f"项目需求未正常生成 .")
|
2264
|
+
#
|
2265
|
+
# table = Table(title=f"{query}")
|
2266
|
+
# table.add_column("需求说明", style="cyan")
|
2267
|
+
# table.add_row(f"{information[:50]}...")
|
2268
|
+
# console.print(table)
|
2269
|
+
#
|
2270
|
+
# console.print(f"正在完善项目架构 ...", style="bold green")
|
2271
|
+
# architecture = project.build_project_architecture(query, env_info, args.project_type, information)
|
2272
|
+
#
|
2273
|
+
# console.print(f"正在构建项目索引 ...", style="bold green")
|
2274
|
+
# index_file_list = project.build_project_index(query, env_info, args.project_type, information, architecture)
|
2275
|
+
#
|
2276
|
+
# table = Table(title=f"索引列表")
|
2277
|
+
# table.add_column("路径", style="cyan")
|
2278
|
+
# table.add_column("用途", style="cyan")
|
2279
|
+
# for index_file in index_file_list.file_list:
|
2280
|
+
# table.add_row(index_file.file_path, index_file.purpose)
|
2281
|
+
# console.print(table)
|
2282
|
+
#
|
2283
|
+
# for index_file in index_file_list.file_list:
|
2284
|
+
# full_path = os.path.join(args.source_dir, index_file.file_path)
|
2285
|
+
#
|
2286
|
+
# # 获取目录路径
|
2287
|
+
# full_dir_path = os.path.dirname(full_path)
|
2288
|
+
# if not os.path.exists(full_dir_path):
|
2289
|
+
# os.makedirs(full_dir_path)
|
2290
|
+
#
|
2291
|
+
# console.print(f"正在编码: {full_path} ...", style="bold green")
|
2292
|
+
# code = project.build_single_code(query, env_info, args.project_type, information, architecture, index_file)
|
2293
|
+
#
|
2294
|
+
# with open(full_path, "w") as fp:
|
2295
|
+
# fp.write(code)
|
2296
|
+
#
|
2297
|
+
# # 生成 readme
|
2298
|
+
# readme_context = information + architecture
|
2299
|
+
# readme_path = os.path.join(args.source_dir, "README.md")
|
2300
|
+
# with open(readme_path, "w") as fp:
|
2301
|
+
# fp.write(readme_context)
|
2302
|
+
#
|
2303
|
+
# console.print(f"项目构建完成", style="bold green")
|
2304
|
+
|
2305
|
+
|
2306
|
+
def is_old_version():
|
2307
|
+
"""
|
2308
|
+
__version__ = "0.1.26" 开始使用兼容 AutoCoder 的 chat_model, code_model 参数
|
2309
|
+
不再使用 current_chat_model 和 current_chat_model
|
2310
|
+
"""
|
2311
|
+
if 'current_chat_model' in memory['conf'] and 'current_code_model' in memory['conf']:
|
2312
|
+
logger.warning(f"您当前版本使用的版本偏低, 正在进行配置兼容性处理")
|
2313
|
+
memory['conf']['chat_model'] = memory['conf']['current_chat_model']
|
2314
|
+
memory['conf']['code_model'] = memory['conf']['current_code_model']
|
2315
|
+
del memory['conf']['current_chat_model']
|
2316
|
+
del memory['conf']['current_code_model']
|
4013
2317
|
|
4014
2318
|
|
4015
2319
|
def main():
|
@@ -4021,14 +2325,15 @@ def main():
|
|
4021
2325
|
initialize_system()
|
4022
2326
|
|
4023
2327
|
load_memory()
|
2328
|
+
is_old_version()
|
4024
2329
|
|
4025
2330
|
if len(memory["models"]) == 0:
|
4026
2331
|
_model_pass = input(f" 是否跳过模型配置(y/n): ").strip().lower()
|
4027
2332
|
if _model_pass == "n":
|
4028
2333
|
m1, m2, m3, m4 = configure_project_model()
|
4029
2334
|
print_status(f"正在更新缓存...", "warning")
|
4030
|
-
memory["conf"]["
|
4031
|
-
memory["conf"]["
|
2335
|
+
memory["conf"]["chat_model"] = m1
|
2336
|
+
memory["conf"]["code_model"] = m1
|
4032
2337
|
memory["models"][m1] = {"base_url": m3, "api_key": m4, "model": m2}
|
4033
2338
|
print_status(f"供应商配置已成功完成!后续你可以使用 /models 命令, 查看, 新增和修改所有模型", "success")
|
4034
2339
|
else:
|
@@ -4046,10 +2351,10 @@ def main():
|
|
4046
2351
|
|
4047
2352
|
print_status("初始化完成。", "success")
|
4048
2353
|
|
4049
|
-
if memory["conf"]["
|
4050
|
-
print_status("首选 Chat 模型与部署模型不一致, 请使用 /conf
|
4051
|
-
if memory["conf"]["
|
4052
|
-
print_status("首选 Code 模型与部署模型不一致, 请使用 /conf
|
2354
|
+
if memory["conf"]["chat_model"] not in memory["models"].keys():
|
2355
|
+
print_status("首选 Chat 模型与部署模型不一致, 请使用 /conf chat_model:xxx 设置", "error")
|
2356
|
+
if memory["conf"]["code_model"] not in memory["models"].keys():
|
2357
|
+
print_status("首选 Code 模型与部署模型不一致, 请使用 /conf code_model:xxx 设置", "error")
|
4053
2358
|
|
4054
2359
|
MODES = {
|
4055
2360
|
"normal": "正常模式",
|
@@ -4180,12 +2485,12 @@ def main():
|
|
4180
2485
|
print("\033[91mPlease enter your request.\033[0m")
|
4181
2486
|
continue
|
4182
2487
|
coding(query=query, llm=auto_llm)
|
4183
|
-
elif user_input.startswith("/new"):
|
4184
|
-
|
4185
|
-
|
4186
|
-
|
4187
|
-
|
4188
|
-
|
2488
|
+
# elif user_input.startswith("/new"):
|
2489
|
+
# query = user_input[len("/new"):].strip()
|
2490
|
+
# if not query:
|
2491
|
+
# print("\033[91mPlease enter your request.\033[0m")
|
2492
|
+
# continue
|
2493
|
+
# new_project(query=query, llm=auto_llm)
|
4189
2494
|
elif user_input.startswith("/chat"):
|
4190
2495
|
query = user_input[len("/chat"):].strip()
|
4191
2496
|
if not query:
|