autocoder-nano 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,19 +2,18 @@ import argparse
2
2
  import glob
3
3
  import hashlib
4
4
  import os
5
- import re
6
5
  import json
7
6
  import shutil
8
7
  import subprocess
9
- import tempfile
10
8
  import textwrap
11
9
  import time
12
- import traceback
13
10
  import uuid
14
- from difflib import SequenceMatcher
15
11
 
16
- from autocoder_nano.agent.new.auto_new_project import BuildNewProject
12
+ from autocoder_nano.edit import Dispacher
17
13
  from autocoder_nano.helper import show_help
14
+ from autocoder_nano.index.entry import build_index_and_filter_files
15
+ from autocoder_nano.index.index_manager import IndexManager
16
+ from autocoder_nano.index.symbols_utils import extract_symbols
18
17
  from autocoder_nano.llm_client import AutoLLM
19
18
  from autocoder_nano.version import __version__
20
19
  from autocoder_nano.llm_types import *
@@ -61,8 +60,10 @@ memory = {
61
60
  "current_files": {"files": [], "groups": {}},
62
61
  "conf": {
63
62
  "auto_merge": "editblock",
64
- "current_chat_model": "",
65
- "current_code_model": ""
63
+ # "current_chat_model": "",
64
+ # "current_code_model": "",
65
+ "chat_model": "",
66
+ "code_model": "",
66
67
  },
67
68
  "exclude_dirs": [],
68
69
  "mode": "normal", # 新增mode字段,默认为normal模式
@@ -73,29 +74,6 @@ memory = {
73
74
  args: AutoCoderArgs = AutoCoderArgs()
74
75
 
75
76
 
76
- def extract_symbols(text: str) -> SymbolsInfo:
77
- patterns = {
78
- "usage": r"用途:(.+)",
79
- "functions": r"函数:(.+)",
80
- "variables": r"变量:(.+)",
81
- "classes": r"类:(.+)",
82
- "import_statements": r"导入语句:(.+)",
83
- }
84
-
85
- info = SymbolsInfo()
86
- for field, pattern in patterns.items():
87
- match = re.search(pattern, text)
88
- if match:
89
- value = match.group(1).strip()
90
- if field == "import_statements":
91
- value = [v.strip() for v in value.split("^^")]
92
- elif field == "functions" or field == "variables" or field == "classes":
93
- value = [v.strip() for v in value.split(",")]
94
- setattr(info, field, value)
95
-
96
- return info
97
-
98
-
99
77
  def get_all_file_names_in_project() -> List[str]:
100
78
  file_names = []
101
79
  final_exclude_dirs = default_exclude_dirs + memory.get("exclude_dirs", [])
@@ -794,449 +772,8 @@ def load_memory():
794
772
  completer.update_current_files(memory["current_files"]["files"])
795
773
 
796
774
 
797
- def symbols_info_to_str(info: SymbolsInfo, symbol_types: List[SymbolType]) -> str:
798
- result = []
799
- for symbol_type in symbol_types:
800
- value = getattr(info, symbol_type.value)
801
- if value:
802
- if symbol_type == SymbolType.IMPORT_STATEMENTS:
803
- value_str = "^^".join(value)
804
- elif symbol_type in [SymbolType.FUNCTIONS, SymbolType.VARIABLES, SymbolType.CLASSES,]:
805
- value_str = ",".join(value)
806
- else:
807
- value_str = value
808
- result.append(f"{symbol_type.value}:{value_str}")
809
-
810
- return "\n".join(result)
811
-
812
-
813
- class IndexManager:
814
- def __init__(self, source_codes: List[SourceCode], llm: AutoLLM = None):
815
- self.args = args
816
- self.sources = source_codes
817
- self.source_dir = args.source_dir
818
- self.index_dir = os.path.join(self.source_dir, ".auto-coder")
819
- self.index_file = os.path.join(self.index_dir, "index.json")
820
- self.llm = llm
821
- self.llm.setup_default_model_name(memory["conf"]["current_chat_model"])
822
- self.max_input_length = args.model_max_input_length # 模型输入最大长度
823
- # 使用 time.sleep(self.anti_quota_limit) 防止超过 API 频率限制
824
- self.anti_quota_limit = args.anti_quota_limit
825
- # 如果索引目录不存在,则创建它
826
- if not os.path.exists(self.index_dir):
827
- os.makedirs(self.index_dir)
828
-
829
- def build_index(self):
830
- """ 构建或更新索引,使用多线程处理多个文件,并将更新后的索引数据写入文件 """
831
- if os.path.exists(self.index_file):
832
- with open(self.index_file, "r") as file: # 读缓存
833
- index_data = json.load(file)
834
- else: # 首次 build index
835
- logger.info("首次生成索引.")
836
- index_data = {}
837
-
838
- @prompt()
839
- def error_message(source_dir: str, file_path: str):
840
- """
841
- The source_dir is different from the path in index file (e.g. file_path:{{ file_path }} source_dir:{{
842
- source_dir }}). You may need to replace the prefix with the source_dir in the index file or Just delete
843
- the index file to rebuild it.
844
- """
845
-
846
- for item in index_data.keys():
847
- if not item.startswith(self.source_dir):
848
- logger.warning(error_message(source_dir=self.source_dir, file_path=item))
849
- break
850
-
851
- updated_sources = []
852
- wait_to_build_files = []
853
- for source in self.sources:
854
- source_code = source.source_code
855
- md5 = hashlib.md5(source_code.encode("utf-8")).hexdigest()
856
- if source.module_name not in index_data or index_data[source.module_name]["md5"] != md5:
857
- wait_to_build_files.append(source)
858
- counter = 0
859
- num_files = len(wait_to_build_files)
860
- total_files = len(self.sources)
861
- logger.info(f"总文件数: {total_files}, 需要索引文件数: {num_files}")
862
-
863
- for source in wait_to_build_files:
864
- build_result = self.build_index_for_single_source(source)
865
- if build_result is not None:
866
- counter += 1
867
- logger.info(f"正在构建索引:{counter}/{num_files}...")
868
- module_name = build_result["module_name"]
869
- index_data[module_name] = build_result
870
- updated_sources.append(module_name)
871
- if updated_sources:
872
- with open(self.index_file, "w") as fp:
873
- json_str = json.dumps(index_data, indent=2, ensure_ascii=False)
874
- fp.write(json_str)
875
- return index_data
876
-
877
- def split_text_into_chunks(self, text):
878
- """ 文本分块,将大文本分割成适合 LLM 处理的小块 """
879
- lines = text.split("\n")
880
- chunks = []
881
- current_chunk = []
882
- current_length = 0
883
- for line in lines:
884
- if current_length + len(line) + 1 <= self.max_input_length:
885
- current_chunk.append(line)
886
- current_length += len(line) + 1
887
- else:
888
- chunks.append("\n".join(current_chunk))
889
- current_chunk = [line]
890
- current_length = len(line) + 1
891
- if current_chunk:
892
- chunks.append("\n".join(current_chunk))
893
- return chunks
894
-
895
- @prompt()
896
- def get_all_file_symbols(self, path: str, code: str) -> str:
897
- """
898
- 你的目标是从给定的代码中获取代码里的符号,需要获取的符号类型包括:
899
-
900
- 1. 函数
901
- 2. 类
902
- 3. 变量
903
- 4. 所有导入语句
904
-
905
- 如果没有任何符号,返回空字符串就行。
906
- 如果有符号,按如下格式返回:
907
-
908
- ```
909
- {符号类型}: {符号名称}, {符号名称}, ...
910
- ```
911
-
912
- 注意:
913
- 1. 直接输出结果,不要尝试使用任何代码
914
- 2. 不要分析代码的内容和目的
915
- 3. 用途的长度不能超过100字符
916
- 4. 导入语句的分隔符为^^
917
-
918
- 下面是一段示例:
919
-
920
- ## 输入
921
- 下列是文件 /test.py 的源码:
922
-
923
- import os
924
- import time
925
- from loguru import logger
926
- import byzerllm
927
-
928
- a = ""
929
-
930
- @byzerllm.prompt(render="jinja")
931
- def auto_implement_function_template(instruction:str, content:str)->str:
932
-
933
- ## 输出
934
- 用途:主要用于提供自动实现函数模板的功能。
935
- 函数:auto_implement_function_template
936
- 变量:a
937
- 类:
938
- 导入语句:import os^^import time^^from loguru import logger^^import byzerllm
939
-
940
- 现在,让我们开始一个新的任务:
941
-
942
- ## 输入
943
- 下列是文件 {{ path }} 的源码:
944
-
945
- {{ code }}
946
-
947
- ## 输出
948
- """
949
-
950
- def build_index_for_single_source(self, source: SourceCode):
951
- """ 处理单个源文件,提取符号信息并存储元数据 """
952
- file_path = source.module_name
953
- if not os.path.exists(file_path): # 过滤不存在的文件
954
- return None
955
-
956
- ext = os.path.splitext(file_path)[1].lower()
957
- if ext in [".md", ".html", ".txt", ".doc", ".pdf"]: # 过滤文档文件
958
- return None
959
-
960
- if source.source_code.strip() == "":
961
- return None
962
-
963
- md5 = hashlib.md5(source.source_code.encode("utf-8")).hexdigest()
964
-
965
- try:
966
- start_time = time.monotonic()
967
- source_code = source.source_code
968
- if len(source.source_code) > self.max_input_length:
969
- logger.warning(
970
- f"警告[构建索引]: 源代码({source.module_name})长度过长 "
971
- f"({len(source.source_code)}) > 模型最大输入长度({self.max_input_length}),"
972
- f"正在分割为多个块..."
973
- )
974
- chunks = self.split_text_into_chunks(source_code)
975
- symbols_list = []
976
- for chunk in chunks:
977
- chunk_symbols = self.get_all_file_symbols.with_llm(self.llm).run(source.module_name, chunk)
978
- time.sleep(self.anti_quota_limit)
979
- symbols_list.append(chunk_symbols.output)
980
- symbols = "\n".join(symbols_list)
981
- else:
982
- single_symbols = self.get_all_file_symbols.with_llm(self.llm).run(source.module_name, source_code)
983
- symbols = single_symbols.output
984
- time.sleep(self.anti_quota_limit)
985
-
986
- logger.info(f"解析并更新索引:文件 {file_path}(MD5: {md5}),耗时 {time.monotonic() - start_time:.2f} 秒")
987
- except Exception as e:
988
- logger.warning(f"源文件 {file_path} 处理失败: {e}")
989
- return None
990
-
991
- return {
992
- "module_name": source.module_name,
993
- "symbols": symbols,
994
- "last_modified": os.path.getmtime(file_path),
995
- "md5": md5,
996
- }
997
-
998
- @prompt()
999
- def _get_target_files_by_query(self, indices: str, query: str) -> str:
1000
- """
1001
- 下面是已知文件以及对应的符号信息:
1002
-
1003
- {{ indices }}
1004
-
1005
- 用户的问题是:
1006
-
1007
- {{ query }}
1008
-
1009
- 现在,请根据用户的问题以及前面的文件和符号信息,寻找相关文件路径。返回结果按如下格式:
1010
-
1011
- ```json
1012
- {
1013
- "file_list": [
1014
- {
1015
- "file_path": "path/to/file.py",
1016
- "reason": "The reason why the file is the target file"
1017
- },
1018
- {
1019
- "file_path": "path/to/file.py",
1020
- "reason": "The reason why the file is the target file"
1021
- }
1022
- ]
1023
- }
1024
- ```
1025
-
1026
- 如果没有找到,返回如下 json 即可:
1027
-
1028
- ```json
1029
- {"file_list": []}
1030
- ```
1031
-
1032
- 请严格遵循以下步骤:
1033
-
1034
- 1. 识别特殊标记:
1035
- - 查找query中的 `@` 符号,它后面的内容是用户关注的文件路径。
1036
- - 查找query中的 `@@` 符号,它后面的内容是用户关注的符号(如函数名、类名、变量名)。
1037
-
1038
- 2. 匹配文件路径:
1039
- - 对于 `@` 标记,在indices中查找包含该路径的所有文件。
1040
- - 路径匹配应该是部分匹配,因为用户可能只提供了路径的一部分。
1041
-
1042
- 3. 匹配符号:
1043
- - 对于 `@@` 标记,在indices中所有文件的符号信息中查找该符号。
1044
- - 检查函数、类、变量等所有符号类型。
1045
-
1046
- 4. 分析依赖关系:
1047
- - 利用 "导入语句" 信息确定文件间的依赖关系。
1048
- - 如果找到了相关文件,也包括与之直接相关的依赖文件。
1049
-
1050
- 5. 考虑文件用途:
1051
- - 使用每个文件的 "用途" 信息来判断其与查询的相关性。
1052
-
1053
- 6. 请严格按格式要求返回结果,无需额外的说明
1054
-
1055
- 请确保结果的准确性和完整性,包括所有可能相关的文件。
1056
- """
1057
-
1058
- def read_index(self) -> List[IndexItem]:
1059
- """ 读取并解析索引文件,将其转换为 IndexItem 对象列表 """
1060
- if not os.path.exists(self.index_file):
1061
- return []
1062
-
1063
- with open(self.index_file, "r") as file:
1064
- index_data = json.load(file)
1065
-
1066
- index_items = []
1067
- for module_name, data in index_data.items():
1068
- index_item = IndexItem(
1069
- module_name=module_name,
1070
- symbols=data["symbols"],
1071
- last_modified=data["last_modified"],
1072
- md5=data["md5"]
1073
- )
1074
- index_items.append(index_item)
1075
-
1076
- return index_items
1077
-
1078
- def _get_meta_str(self, includes: Optional[List[SymbolType]] = None):
1079
- index_items = self.read_index()
1080
- current_chunk = []
1081
- for item in index_items:
1082
- symbols_str = item.symbols
1083
- if includes:
1084
- symbol_info = extract_symbols(symbols_str)
1085
- symbols_str = symbols_info_to_str(symbol_info, includes)
1086
-
1087
- item_str = f"##{item.module_name}\n{symbols_str}\n\n"
1088
- if len(current_chunk) > self.args.filter_batch_size:
1089
- yield "".join(current_chunk)
1090
- current_chunk = [item_str]
1091
- else:
1092
- current_chunk.append(item_str)
1093
- if current_chunk:
1094
- yield "".join(current_chunk)
1095
-
1096
- def get_target_files_by_query(self, query: str):
1097
- """ 根据查询条件查找相关文件,考虑不同过滤级别 """
1098
- all_results = []
1099
- completed = 0
1100
- total = 0
1101
-
1102
- includes = None
1103
- if self.args.index_filter_level == 0:
1104
- includes = [SymbolType.USAGE]
1105
- if self.args.index_filter_level >= 1:
1106
- includes = None
1107
-
1108
- for chunk in self._get_meta_str(includes=includes):
1109
- result = self._get_target_files_by_query.with_llm(self.llm).with_return_type(FileList).run(chunk, query)
1110
- if result is not None:
1111
- all_results.extend(result.file_list)
1112
- completed += 1
1113
- else:
1114
- logger.warning(f"无法找到分块的目标文件。原因可能是模型响应未返回 JSON 格式数据,或返回的 JSON 为空。")
1115
- total += 1
1116
- time.sleep(self.anti_quota_limit)
1117
-
1118
- logger.info(f"已完成 {completed}/{total} 个分块(基于查询条件)")
1119
- all_results = list({file.file_path: file for file in all_results}.values())
1120
- if self.args.index_filter_file_num > 0:
1121
- limited_results = all_results[: self.args.index_filter_file_num]
1122
- return FileList(file_list=limited_results)
1123
- return FileList(file_list=all_results)
1124
-
1125
- @prompt()
1126
- def _get_related_files(self, indices: str, file_paths: str) -> str:
1127
- """
1128
- 下面是所有文件以及对应的符号信息:
1129
-
1130
- {{ indices }}
1131
-
1132
- 请参考上面的信息,找到被下列文件使用或者引用到的文件列表:
1133
-
1134
- {{ file_paths }}
1135
-
1136
- 请按如下格式进行输出:
1137
-
1138
- ```json
1139
- {
1140
- "file_list": [
1141
- {
1142
- "file_path": "path/to/file.py",
1143
- "reason": "The reason why the file is the target file"
1144
- },
1145
- {
1146
- "file_path": "path/to/file.py",
1147
- "reason": "The reason why the file is the target file"
1148
- }
1149
- ]
1150
- }
1151
- ```
1152
-
1153
- 如果没有相关的文件,输出如下 json 即可:
1154
-
1155
- ```json
1156
- {"file_list": []}
1157
- ```
1158
-
1159
- 注意,
1160
- 1. 找到的文件名必须出现在上面的文件列表中
1161
- 2. 原因控制在20字以内, 且使用中文
1162
- 3. 请严格按格式要求返回结果,无需额外的说明
1163
- """
1164
-
1165
- def get_related_files(self, file_paths: List[str]):
1166
- """ 根据文件路径查询相关文件 """
1167
- all_results = []
1168
-
1169
- completed = 0
1170
- total = 0
1171
-
1172
- for chunk in self._get_meta_str():
1173
- result = self._get_related_files.with_llm(self.llm).with_return_type(
1174
- FileList).run(chunk, "\n".join(file_paths))
1175
- if result is not None:
1176
- all_results.extend(result.file_list)
1177
- completed += 1
1178
- else:
1179
- logger.warning(f"无法找到与分块相关的文件。原因可能是模型限制或查询条件与文件不匹配。")
1180
- total += 1
1181
- time.sleep(self.anti_quota_limit)
1182
- logger.info(f"已完成 {completed}/{total} 个分块(基于相关文件)")
1183
- all_results = list({file.file_path: file for file in all_results}.values())
1184
- return FileList(file_list=all_results)
1185
-
1186
- @prompt()
1187
- def verify_file_relevance(self, file_content: str, query: str) -> str:
1188
- """
1189
- 请验证下面的文件内容是否与用户问题相关:
1190
-
1191
- 文件内容:
1192
- {{ file_content }}
1193
-
1194
- 用户问题:
1195
- {{ query }}
1196
-
1197
- 相关是指,需要依赖这个文件提供上下文,或者需要修改这个文件才能解决用户的问题。
1198
- 请给出相应的可能性分数:0-10,并结合用户问题,理由控制在50字以内,并且使用中文。
1199
- 请严格按格式要求返回结果。
1200
- 格式如下:
1201
-
1202
- ```json
1203
- {
1204
- "relevant_score": 0-10,
1205
- "reason": "这是相关的原因..."
1206
- }
1207
- ```
1208
- """
1209
-
1210
-
1211
775
  def index_command(llm):
1212
- conf = memory.get("conf", {})
1213
- # 默认 chat 配置
1214
- yaml_config = {
1215
- "include_file": ["./base/base.yml"],
1216
- "include_project_structure": conf.get("include_project_structure", "true") in ["true", "True"],
1217
- "human_as_model": conf.get("human_as_model", "false") == "true",
1218
- "skip_build_index": conf.get("skip_build_index", "true") == "true",
1219
- "skip_confirm": conf.get("skip_confirm", "true") == "true",
1220
- "silence": conf.get("silence", "true") == "true",
1221
- "query": ""
1222
- }
1223
- current_files = memory["current_files"]["files"] # get_llm_friendly_package_docs
1224
- yaml_config["urls"] = current_files
1225
- yaml_config["query"] = ""
1226
-
1227
- # 如果 conf 中有设置, 则以 conf 配置为主
1228
- for key, value in conf.items():
1229
- converted_value = convert_config_value(key, value)
1230
- if converted_value is not None:
1231
- yaml_config[key] = converted_value
1232
-
1233
- yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
1234
- execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
1235
-
1236
- with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
1237
- f.write(yaml_content)
1238
-
1239
- convert_yaml_to_config(execute_file) # 更新到args
776
+ update_config_to_args(query="", delete_execute_file=True)
1240
777
 
1241
778
  source_dir = os.path.abspath(args.source_dir)
1242
779
  logger.info(f"开始对目录 {source_dir} 中的源代码进行索引")
@@ -1246,7 +783,7 @@ def index_command(llm):
1246
783
  pp = SuffixProject(llm=llm, args=args)
1247
784
  pp.run()
1248
785
  _sources = pp.sources
1249
- index_manager = IndexManager(source_codes=_sources, llm=llm)
786
+ index_manager = IndexManager(args=args, source_codes=_sources, llm=llm)
1250
787
  index_manager.build_index()
1251
788
 
1252
789
 
@@ -1332,34 +869,7 @@ def wrap_text_in_table(data, max_width=60):
1332
869
 
1333
870
 
1334
871
  def index_query_command(query: str, llm: AutoLLM):
1335
- conf = memory.get("conf", {})
1336
- # 默认 chat 配置
1337
- yaml_config = {
1338
- "include_file": ["./base/base.yml"],
1339
- "include_project_structure": conf.get("include_project_structure", "true") in ["true", "True"],
1340
- "human_as_model": conf.get("human_as_model", "false") == "true",
1341
- "skip_build_index": conf.get("skip_build_index", "true") == "true",
1342
- "skip_confirm": conf.get("skip_confirm", "true") == "true",
1343
- "silence": conf.get("silence", "true") == "true",
1344
- "query": query
1345
- }
1346
- current_files = memory["current_files"]["files"] # get_llm_friendly_package_docs
1347
- yaml_config["urls"] = current_files
1348
- yaml_config["query"] = query
1349
-
1350
- # 如果 conf 中有设置, 则以 conf 配置为主
1351
- for key, value in conf.items():
1352
- converted_value = convert_config_value(key, value)
1353
- if converted_value is not None:
1354
- yaml_config[key] = converted_value
1355
-
1356
- yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
1357
- execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
1358
-
1359
- with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
1360
- f.write(yaml_content)
1361
-
1362
- convert_yaml_to_config(execute_file) # 更新到args
872
+ update_config_to_args(query=query, delete_execute_file=True)
1363
873
 
1364
874
  # args.query = query
1365
875
  if args.project_type == "py":
@@ -1370,7 +880,7 @@ def index_query_command(query: str, llm: AutoLLM):
1370
880
  _sources = pp.sources
1371
881
 
1372
882
  final_files = []
1373
- index_manager = IndexManager(source_codes=_sources, llm=llm)
883
+ index_manager = IndexManager(args=args, source_codes=_sources, llm=llm)
1374
884
  target_files = index_manager.get_target_files_by_query(query)
1375
885
 
1376
886
  if target_files:
@@ -1397,159 +907,6 @@ def index_query_command(query: str, llm: AutoLLM):
1397
907
  return
1398
908
 
1399
909
 
1400
- def build_index_and_filter_files(llm, sources: List[SourceCode]) -> str:
1401
- def get_file_path(_file_path):
1402
- if _file_path.startswith("##"):
1403
- return _file_path.strip()[2:]
1404
- return _file_path
1405
-
1406
- final_files: Dict[str, TargetFile] = {}
1407
- logger.info("第一阶段:处理 REST/RAG/Search 资源...")
1408
- for source in sources:
1409
- if source.tag in ["REST", "RAG", "SEARCH"]:
1410
- final_files[get_file_path(source.module_name)] = TargetFile(
1411
- file_path=source.module_name, reason="Rest/Rag/Search"
1412
- )
1413
-
1414
- if not args.skip_build_index and llm:
1415
- logger.info("第二阶段:为所有文件构建索引...")
1416
- index_manager = IndexManager(llm=llm, source_codes=sources)
1417
- index_data = index_manager.build_index()
1418
- indexed_files_count = len(index_data) if index_data else 0
1419
- logger.info(f"总索引文件数: {indexed_files_count}")
1420
-
1421
- if not args.skip_filter_index and args.index_filter_level >= 1:
1422
- logger.info("第三阶段:执行 Level 1 过滤(基于查询) ...")
1423
- target_files = index_manager.get_target_files_by_query(args.query)
1424
- if target_files:
1425
- for file in target_files.file_list:
1426
- file_path = file.file_path.strip()
1427
- final_files[get_file_path(file_path)] = file
1428
-
1429
- if target_files is not None and args.index_filter_level >= 2:
1430
- logger.info("第四阶段:执行 Level 2 过滤(基于相关文件)...")
1431
- related_files = index_manager.get_related_files(
1432
- [file.file_path for file in target_files.file_list]
1433
- )
1434
- if related_files is not None:
1435
- for file in related_files.file_list:
1436
- file_path = file.file_path.strip()
1437
- final_files[get_file_path(file_path)] = file
1438
-
1439
- # 如果 Level 1 filtering 和 Level 2 filtering 都未获取路径,则使用全部文件
1440
- if not final_files:
1441
- logger.warning("Level 1, Level 2 过滤未找到相关文件, 将使用所有文件 ...")
1442
- for source in sources:
1443
- final_files[get_file_path(source.module_name)] = TargetFile(
1444
- file_path=source.module_name,
1445
- reason="No related files found, use all files",
1446
- )
1447
-
1448
- logger.info("第五阶段:执行相关性验证 ...")
1449
- verified_files = {}
1450
- temp_files = list(final_files.values())
1451
- verification_results = []
1452
-
1453
- def _print_verification_results(results):
1454
- table = Table(title="文件相关性验证结果", expand=True, show_lines=True)
1455
- table.add_column("文件路径", style="cyan", no_wrap=True)
1456
- table.add_column("得分", justify="right", style="green")
1457
- table.add_column("状态", style="yellow")
1458
- table.add_column("原因/错误")
1459
- if result:
1460
- for _file_path, _score, _status, _reason in results:
1461
- table.add_row(_file_path,
1462
- str(_score) if _score is not None else "N/A", _status, _reason)
1463
- console.print(table)
1464
-
1465
- def _verify_single_file(single_file: TargetFile):
1466
- for _source in sources:
1467
- if _source.module_name == single_file.file_path:
1468
- file_content = _source.source_code
1469
- try:
1470
- _result = index_manager.verify_file_relevance.with_llm(llm).with_return_type(
1471
- VerifyFileRelevance).run(
1472
- file_content=file_content,
1473
- query=args.query
1474
- )
1475
- if _result.relevant_score >= args.verify_file_relevance_score:
1476
- verified_files[single_file.file_path] = TargetFile(
1477
- file_path=single_file.file_path,
1478
- reason=f"Score:{_result.relevant_score}, {_result.reason}"
1479
- )
1480
- return single_file.file_path, _result.relevant_score, "PASS", _result.reason
1481
- else:
1482
- return single_file.file_path, _result.relevant_score, "FAIL", _result.reason
1483
- except Exception as e:
1484
- error_msg = str(e)
1485
- verified_files[single_file.file_path] = TargetFile(
1486
- file_path=single_file.file_path,
1487
- reason=f"Verification failed: {error_msg}"
1488
- )
1489
- return single_file.file_path, None, "ERROR", error_msg
1490
- return
1491
-
1492
- for pending_verify_file in temp_files:
1493
- result = _verify_single_file(pending_verify_file)
1494
- if result:
1495
- verification_results.append(result)
1496
- time.sleep(args.anti_quota_limit)
1497
-
1498
- _print_verification_results(verification_results)
1499
- # Keep all files, not just verified ones
1500
- final_files = verified_files
1501
-
1502
- logger.info("第六阶段:筛选文件并应用限制条件 ...")
1503
- if args.index_filter_file_num > 0:
1504
- logger.info(f"从 {len(final_files)} 个文件中获取前 {args.index_filter_file_num} 个文件(Limit)")
1505
- final_filenames = [file.file_path for file in final_files.values()]
1506
- if not final_filenames:
1507
- logger.warning("未找到目标文件,你可能需要重新编写查询并重试.")
1508
- if args.index_filter_file_num > 0:
1509
- final_filenames = final_filenames[: args.index_filter_file_num]
1510
-
1511
- def _shorten_path(path: str, keep_levels: int = 3) -> str:
1512
- """
1513
- 优化长路径显示,保留最后指定层级
1514
- 示例:/a/b/c/d/e/f.py -> .../c/d/e/f.py
1515
- """
1516
- parts = path.split(os.sep)
1517
- if len(parts) > keep_levels:
1518
- return ".../" + os.sep.join(parts[-keep_levels:])
1519
- return path
1520
-
1521
- def _print_selected(data):
1522
- table = Table(title="代码上下文文件", expand=True, show_lines=True)
1523
- table.add_column("文件路径", style="cyan")
1524
- table.add_column("原因", style="cyan")
1525
- for _file, _reason in data:
1526
- # 路径截取优化:保留最后 3 级路径
1527
- _processed_path = _shorten_path(_file, keep_levels=3)
1528
- table.add_row(_processed_path, _reason)
1529
- console.print(table)
1530
-
1531
- logger.info("第七阶段:准备最终输出 ...")
1532
- _print_selected(
1533
- [
1534
- (file.file_path, file.reason)
1535
- for file in final_files.values()
1536
- if file.file_path in final_filenames
1537
- ]
1538
- )
1539
- result_source_code = ""
1540
- depulicated_sources = set()
1541
-
1542
- for file in sources:
1543
- if file.module_name in final_filenames:
1544
- if file.module_name in depulicated_sources:
1545
- continue
1546
- depulicated_sources.add(file.module_name)
1547
- result_source_code += f"##File: {file.module_name}\n"
1548
- result_source_code += f"{file.source_code}\n\n"
1549
-
1550
- return result_source_code
1551
-
1552
-
1553
910
  def convert_yaml_config_to_str(yaml_config):
1554
911
  yaml_content = yaml.safe_dump(
1555
912
  yaml_config,
@@ -1598,6 +955,41 @@ def convert_config_value(key, value):
1598
955
  return None
1599
956
 
1600
957
 
958
+ def update_config_to_args(query, delete_execute_file: bool = False):
959
+ conf = memory.get("conf", {})
960
+
961
+ # 默认 chat 配置
962
+ yaml_config = {
963
+ "include_file": ["./base/base.yml"],
964
+ "skip_build_index": conf.get("skip_build_index", "true") == "true",
965
+ "skip_confirm": conf.get("skip_confirm", "true") == "true",
966
+ "chat_model": conf.get("chat_model", ""),
967
+ "code_model": conf.get("code_model", ""),
968
+ "auto_merge": conf.get("auto_merge", "editblock")
969
+ }
970
+ current_files = memory["current_files"]["files"]
971
+ yaml_config["urls"] = current_files
972
+ yaml_config["query"] = query
973
+
974
+ # 如果 conf 中有设置, 则以 conf 配置为主
975
+ for key, value in conf.items():
976
+ converted_value = convert_config_value(key, value)
977
+ if converted_value is not None:
978
+ yaml_config[key] = converted_value
979
+
980
+ yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
981
+ execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
982
+
983
+ with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
984
+ f.write(yaml_content)
985
+
986
+ convert_yaml_to_config(execute_file) # 更新到args
987
+
988
+ if delete_execute_file:
989
+ if os.path.exists(execute_file):
990
+ os.remove(execute_file)
991
+
992
+
1601
993
  def print_chat_history(history, max_entries=5):
1602
994
  recent_history = history[-max_entries:]
1603
995
  table = Table(show_header=False, padding=(0, 1), expand=True, show_lines=True)
@@ -1638,6 +1030,8 @@ def code_review(query: str) -> str:
1638
1030
 
1639
1031
 
1640
1032
  def chat(query: str, llm: AutoLLM):
1033
+ update_config_to_args(query)
1034
+
1641
1035
  is_history = query.strip().startswith("/history")
1642
1036
  is_new = "/new" in query
1643
1037
  if is_new:
@@ -1652,36 +1046,6 @@ def chat(query: str, llm: AutoLLM):
1652
1046
  query = query.replace("/review", "", 1).strip()
1653
1047
  query = code_review.prompt(query)
1654
1048
 
1655
- conf = memory.get("conf", {})
1656
- # 默认 chat 配置
1657
- yaml_config = {
1658
- "include_file": ["./base/base.yml"],
1659
- "include_project_structure": conf.get("include_project_structure", "true") in ["true", "True"],
1660
- "human_as_model": conf.get("human_as_model", "false") == "true",
1661
- "skip_build_index": conf.get("skip_build_index", "true") == "true",
1662
- "skip_confirm": conf.get("skip_confirm", "true") == "true",
1663
- "silence": conf.get("silence", "true") == "true",
1664
- "query": query
1665
- }
1666
- current_files = memory["current_files"]["files"] # get_llm_friendly_package_docs
1667
- yaml_config["urls"] = current_files
1668
-
1669
- yaml_config["query"] = query
1670
-
1671
- # 如果 conf 中有设置, 则以 conf 配置为主
1672
- for key, value in conf.items():
1673
- converted_value = convert_config_value(key, value)
1674
- if converted_value is not None:
1675
- yaml_config[key] = converted_value
1676
-
1677
- yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
1678
- execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
1679
-
1680
- with open(os.path.join(execute_file), "w") as f: # 保存此次查询的细节
1681
- f.write(yaml_content)
1682
-
1683
- convert_yaml_to_config(execute_file) # 更新到args
1684
-
1685
1049
  memory_dir = os.path.join(args.source_dir, ".auto-coder", "memory")
1686
1050
  os.makedirs(memory_dir, exist_ok=True)
1687
1051
  memory_file = os.path.join(memory_dir, "chat_history.json")
@@ -1745,7 +1109,7 @@ def chat(query: str, llm: AutoLLM):
1745
1109
  pp = SuffixProject(llm=llm, args=args)
1746
1110
  pp.run()
1747
1111
  _sources = pp.sources
1748
- s = build_index_and_filter_files(llm=llm, sources=_sources)
1112
+ s = build_index_and_filter_files(args=args, llm=llm, sources=_sources)
1749
1113
  if s:
1750
1114
  pre_conversations.append(
1751
1115
  {
@@ -1760,7 +1124,7 @@ def chat(query: str, llm: AutoLLM):
1760
1124
 
1761
1125
  loaded_conversations = pre_conversations + chat_history["ask_conversation"]
1762
1126
 
1763
- v = chat_llm.stream_chat_ai(conversations=loaded_conversations, model=memory["conf"]["current_chat_model"])
1127
+ v = chat_llm.stream_chat_ai(conversations=loaded_conversations, model=args.chat_model)
1764
1128
 
1765
1129
  MAX_HISTORY_LINES = 15 # 最大保留历史行数
1766
1130
  lines_buffer = []
@@ -1814,30 +1178,6 @@ def chat(query: str, llm: AutoLLM):
1814
1178
  return
1815
1179
 
1816
1180
 
1817
- def git_print_commit_info(commit_result: CommitResult):
1818
- table = Table(
1819
- title="Commit Information (Use /revert to revert this commit)", show_header=True, header_style="bold magenta"
1820
- )
1821
- table.add_column("Attribute", style="cyan", no_wrap=True)
1822
- table.add_column("Value", style="green")
1823
-
1824
- table.add_row("Commit Hash", commit_result.commit_hash)
1825
- table.add_row("Commit Message", commit_result.commit_message)
1826
- table.add_row("Changed Files", "\n".join(commit_result.changed_files))
1827
-
1828
- console.print(
1829
- Panel(table, expand=False, border_style="green", title="Git Commit Summary")
1830
- )
1831
-
1832
- if commit_result.diffs:
1833
- for file, diff in commit_result.diffs.items():
1834
- console.print(f"\n[bold blue]File: {file}[/bold blue]")
1835
- syntax = Syntax(diff, "diff", theme="monokai", line_numbers=True)
1836
- console.print(
1837
- Panel(syntax, expand=False, border_style="yellow", title="File Diff")
1838
- )
1839
-
1840
-
1841
1181
  def init_project():
1842
1182
  if not args.project_type:
1843
1183
  logger.error(
@@ -1910,1040 +1250,6 @@ def load_include_files(config, base_path, max_depth=10, current_depth=0):
1910
1250
  return config
1911
1251
 
1912
1252
 
1913
- class CodeAutoGenerateEditBlock:
1914
- def __init__(self, llm: AutoLLM, action=None, fence_0: str = "```", fence_1: str = "```"):
1915
- self.llm = llm
1916
- self.llm.setup_default_model_name(memory["conf"]["current_code_model"])
1917
- self.args = args
1918
- self.action = action
1919
- self.fence_0 = fence_0
1920
- self.fence_1 = fence_1
1921
- if not self.llm:
1922
- raise ValueError("Please provide a valid model instance to use for code generation.")
1923
- self.llms = [self.llm]
1924
-
1925
- @prompt()
1926
- def single_round_instruction(self, instruction: str, content: str, context: str = ""):
1927
- """
1928
- 如果你需要生成代码,对于每个需要更改的文件,你需要按 *SEARCH/REPLACE block* 的格式进行生成。
1929
-
1930
- # *SEARCH/REPLACE block* Rules:
1931
-
1932
- Every *SEARCH/REPLACE block* must use this format:
1933
- 1. The opening fence and code language, eg: {{ fence_0 }}python
1934
- 2. The file path alone on a line, starting with "##File:" and verbatim. No bold asterisks, no quotes around it,
1935
- no escaping of characters, etc.
1936
- 3. The start of search block: <<<<<<< SEARCH
1937
- 4. A contiguous chunk of lines to search for in the existing source code
1938
- 5. The dividing line: =======
1939
- 6. The lines to replace into the source code
1940
- 7. The end of the replacement block: >>>>>>> REPLACE
1941
- 8. The closing fence: {{ fence_1 }}
1942
-
1943
- Every *SEARCH* section must *EXACTLY MATCH* the existing source code, character for character,
1944
- including all comments, docstrings, etc.
1945
-
1946
- *SEARCH/REPLACE* blocks will replace *all* matching occurrences.
1947
- Include enough lines to make the SEARCH blocks unique.
1948
-
1949
- Include *ALL* the code being searched and replaced!
1950
-
1951
- To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location,
1952
- 1 to insert it in the new location.
1953
-
1954
- If you want to put code in a new file, use a *SEARCH/REPLACE block* with:
1955
- - A new file path, including dir name if needed
1956
- - An empty `SEARCH` section
1957
- - The new file's contents in the `REPLACE` section
1958
-
1959
- ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
1960
-
1961
- 下面我们来看一个例子:
1962
-
1963
- 当前项目目录结构:
1964
- 1. 项目根目录: /tmp/projects/mathweb
1965
- 2. 项目子目录/文件列表(类似tree 命令输出)
1966
- flask/
1967
- app.py
1968
- templates/
1969
- index.html
1970
- static/
1971
- style.css
1972
-
1973
- 用户需求: Change get_factorial() to use math.factorial
1974
-
1975
- 回答: To make this change we need to modify `/tmp/projects/mathweb/flask/app.py` to:
1976
-
1977
- 1. Import the math package.
1978
- 2. Remove the existing factorial() function.
1979
- 3. Update get_factorial() to call math.factorial instead.
1980
-
1981
- Here are the *SEARCH/REPLACE* blocks:
1982
-
1983
- ```python
1984
- ##File: /tmp/projects/mathweb/flask/app.py
1985
- <<<<<<< SEARCH
1986
- from flask import Flask
1987
- =======
1988
- import math
1989
- from flask import Flask
1990
- >>>>>>> REPLACE
1991
- ```
1992
-
1993
- ```python
1994
- ##File: /tmp/projects/mathweb/flask/app.py
1995
- <<<<<<< SEARCH
1996
- def factorial(n):
1997
- "compute factorial"
1998
-
1999
- if n == 0:
2000
- return 1
2001
- else:
2002
- return n * factorial(n-1)
2003
-
2004
- =======
2005
- >>>>>>> REPLACE
2006
- ```
2007
-
2008
- ```python
2009
- ##File: /tmp/projects/mathweb/flask/app.py
2010
- <<<<<<< SEARCH
2011
- return str(factorial(n))
2012
- =======
2013
- return str(math.factorial(n))
2014
- >>>>>>> REPLACE
2015
- ```
2016
-
2017
- 用户需求: Refactor hello() into its own file.
2018
-
2019
- 回答:To make this change we need to modify `main.py` and make a new file `hello.py`:
2020
-
2021
- 1. Make a new hello.py file with hello() in it.
2022
- 2. Remove hello() from main.py and replace it with an import.
2023
-
2024
- Here are the *SEARCH/REPLACE* blocks:
2025
-
2026
- ```python
2027
- ##File: /tmp/projects/mathweb/hello.py
2028
- <<<<<<< SEARCH
2029
- =======
2030
- def hello():
2031
- "print a greeting"
2032
-
2033
- print("hello")
2034
- >>>>>>> REPLACE
2035
- ```
2036
-
2037
- ```python
2038
- ##File: /tmp/projects/mathweb/main.py
2039
- <<<<<<< SEARCH
2040
- def hello():
2041
- "print a greeting"
2042
-
2043
- print("hello")
2044
- =======
2045
- from hello import hello
2046
- >>>>>>> REPLACE
2047
- ```
2048
-
2049
- 现在让我们开始一个新的任务:
2050
-
2051
- {%- if structure %}
2052
- {{ structure }}
2053
- {%- endif %}
2054
-
2055
- {%- if content %}
2056
- 下面是一些文件路径以及每个文件对应的源码:
2057
- <files>
2058
- {{ content }}
2059
- </files>
2060
- {%- endif %}
2061
-
2062
- {%- if context %}
2063
- <extra_context>
2064
- {{ context }}
2065
- </extra_context>
2066
- {%- endif %}
2067
-
2068
- 下面是用户的需求:
2069
-
2070
- {{ instruction }}
2071
-
2072
- """
2073
-
2074
- @prompt()
2075
- def auto_implement_function(self, instruction: str, content: str) -> str:
2076
- """
2077
- 下面是一些文件路径以及每个文件对应的源码:
2078
-
2079
- {{ content }}
2080
-
2081
- 请参考上面的内容,重新实现所有文件下方法体标记了如下内容的方法:
2082
-
2083
- ```python
2084
- raise NotImplementedError("This function should be implemented by the model.")
2085
- ```
2086
-
2087
- {{ instruction }}
2088
-
2089
- """
2090
-
2091
- def single_round_run(self, query: str, source_content: str) -> CodeGenerateResult:
2092
- init_prompt = ''
2093
- if self.args.template == "common":
2094
- init_prompt = self.single_round_instruction.prompt(
2095
- instruction=query, content=source_content, context=self.args.context
2096
- )
2097
- elif self.args.template == "auto_implement":
2098
- init_prompt = self.auto_implement_function.prompt(
2099
- instruction=query, content=source_content
2100
- )
2101
-
2102
- with open(self.args.target_file, "w") as file:
2103
- file.write(init_prompt)
2104
-
2105
- conversations = [{"role": "user", "content": init_prompt}]
2106
-
2107
- conversations_list = []
2108
- results = []
2109
-
2110
- for llm in self.llms:
2111
- v = llm.chat_ai(conversations=conversations)
2112
- results.append(v.output)
2113
- for result in results:
2114
- conversations_list.append(conversations + [{"role": "assistant", "content": result}])
2115
-
2116
- return CodeGenerateResult(contents=results, conversations=conversations_list)
2117
-
2118
- @prompt()
2119
- def multi_round_instruction(self, instruction: str, content: str, context: str = "") -> str:
2120
- """
2121
- 如果你需要生成代码,对于每个需要更改的文件,你需要按 *SEARCH/REPLACE block* 的格式进行生成。
2122
-
2123
- # *SEARCH/REPLACE block* Rules:
2124
-
2125
- Every *SEARCH/REPLACE block* must use this format:
2126
- 1. The opening fence and code language, eg: {{ fence_0 }}python
2127
- 2. The file path alone on a line, starting with "##File:" and verbatim. No bold asterisks, no quotes around it,
2128
- no escaping of characters, etc.
2129
- 3. The start of search block: <<<<<<< SEARCH
2130
- 4. A contiguous chunk of lines to search for in the existing source code
2131
- 5. The dividing line: =======
2132
- 6. The lines to replace into the source code
2133
- 7. The end of the replacement block: >>>>>>> REPLACE
2134
- 8. The closing fence: {{ fence_1 }}
2135
-
2136
- Every *SEARCH* section must *EXACTLY MATCH* the existing source code, character for character,
2137
- including all comments, docstrings, etc.
2138
-
2139
- *SEARCH/REPLACE* blocks will replace *all* matching occurrences.
2140
- Include enough lines to make the SEARCH blocks unique.
2141
-
2142
- Include *ALL* the code being searched and replaced!
2143
-
2144
- To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location,
2145
- 1 to insert it in the new location.
2146
-
2147
- If you want to put code in a new file, use a *SEARCH/REPLACE block* with:
2148
- - A new file path, including dir name if needed
2149
- - An empty `SEARCH` section
2150
- - The new file's contents in the `REPLACE` section
2151
-
2152
- ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
2153
-
2154
- 下面我们来看一个例子:
2155
-
2156
- 当前项目目录结构:
2157
- 1. 项目根目录: /tmp/projects/mathweb
2158
- 2. 项目子目录/文件列表(类似tree 命令输出)
2159
- flask/
2160
- app.py
2161
- templates/
2162
- index.html
2163
- static/
2164
- style.css
2165
-
2166
- 用户需求: Change get_factorial() to use math.factorial
2167
-
2168
- 回答: To make this change we need to modify `/tmp/projects/mathweb/flask/app.py` to:
2169
-
2170
- 1. Import the math package.
2171
- 2. Remove the existing factorial() function.
2172
- 3. Update get_factorial() to call math.factorial instead.
2173
-
2174
- Here are the *SEARCH/REPLACE* blocks:
2175
-
2176
- {{ fence_0 }}python
2177
- ##File: /tmp/projects/mathweb/flask/app.py
2178
- <<<<<<< SEARCH
2179
- from flask import Flask
2180
- =======
2181
- import math
2182
- from flask import Flask
2183
- >>>>>>> REPLACE
2184
- {{ fence_1 }}
2185
-
2186
- {{ fence_0 }}python
2187
- ##File: /tmp/projects/mathweb/flask/app.py
2188
- <<<<<<< SEARCH
2189
- def factorial(n):
2190
- "compute factorial"
2191
-
2192
- if n == 0:
2193
- return 1
2194
- else:
2195
- return n * factorial(n-1)
2196
-
2197
- =======
2198
- >>>>>>> REPLACE
2199
- {{ fence_1 }}
2200
-
2201
- {{ fence_0 }}python
2202
- ##File: /tmp/projects/mathweb/flask/app.py
2203
- <<<<<<< SEARCH
2204
- return str(factorial(n))
2205
- =======
2206
- return str(math.factorial(n))
2207
- >>>>>>> REPLACE
2208
- {{ fence_1 }}
2209
-
2210
- 用户需求: Refactor hello() into its own file.
2211
-
2212
- 回答:To make this change we need to modify `main.py` and make a new file `hello.py`:
2213
-
2214
- 1. Make a new hello.py file with hello() in it.
2215
- 2. Remove hello() from main.py and replace it with an import.
2216
-
2217
- Here are the *SEARCH/REPLACE* blocks:
2218
-
2219
-
2220
- {{ fence_0 }}python
2221
- ##File: /tmp/projects/mathweb/hello.py
2222
- <<<<<<< SEARCH
2223
- =======
2224
- def hello():
2225
- "print a greeting"
2226
-
2227
- print("hello")
2228
- >>>>>>> REPLACE
2229
- {{ fence_1 }}
2230
-
2231
- {{ fence_0 }}python
2232
- ##File: /tmp/projects/mathweb/main.py
2233
- <<<<<<< SEARCH
2234
- def hello():
2235
- "print a greeting"
2236
-
2237
- print("hello")
2238
- =======
2239
- from hello import hello
2240
- >>>>>>> REPLACE
2241
- {{ fence_1 }}
2242
-
2243
- 现在让我们开始一个新的任务:
2244
-
2245
- {%- if structure %}
2246
- {{ structure }}
2247
- {%- endif %}
2248
-
2249
- {%- if content %}
2250
- 下面是一些文件路径以及每个文件对应的源码:
2251
- <files>
2252
- {{ content }}
2253
- </files>
2254
- {%- endif %}
2255
-
2256
- {%- if context %}
2257
- <extra_context>
2258
- {{ context }}
2259
- </extra_context>
2260
- {%- endif %}
2261
-
2262
- 下面是用户的需求:
2263
-
2264
- {{ instruction }}
2265
-
2266
- 每次生成一个文件的*SEARCH/REPLACE* blocks,然后询问我是否继续,当我回复继续,
2267
- 继续生成下一个文件的*SEARCH/REPLACE* blocks。当没有后续任务时,请回复 "__完成__" 或者 "__EOF__"。
2268
- """
2269
-
2270
- def multi_round_run(self, query: str, source_content: str, max_steps: int = 3) -> CodeGenerateResult:
2271
- init_prompt = ''
2272
- if self.args.template == "common":
2273
- init_prompt = self.multi_round_instruction.prompt(
2274
- instruction=query, content=source_content, context=self.args.context
2275
- )
2276
- elif self.args.template == "auto_implement":
2277
- init_prompt = self.auto_implement_function.prompt(
2278
- instruction=query, content=source_content
2279
- )
2280
-
2281
- with open(self.args.target_file, "w") as file:
2282
- file.write(init_prompt)
2283
-
2284
- results = []
2285
- conversations = [{"role": "user", "content": init_prompt}]
2286
-
2287
- code_llm = self.llms[0]
2288
- v = code_llm.chat_ai(conversations=conversations)
2289
- results.append(v.output)
2290
-
2291
- conversations.append({"role": "assistant", "content": v.output})
2292
-
2293
- if "__完成__" in v.output or "/done" in v.output or "__EOF__" in v.output:
2294
- return CodeGenerateResult(contents=["\n\n".join(results)], conversations=[conversations])
2295
-
2296
- current_step = 0
2297
-
2298
- while current_step < max_steps:
2299
- conversations.append({"role": "user", "content": "继续"})
2300
-
2301
- with open(self.args.target_file, "w") as file:
2302
- file.write("继续")
2303
-
2304
- t = code_llm.chat_ai(conversations=conversations)
2305
-
2306
- results.append(t.output)
2307
- conversations.append({"role": "assistant", "content": t.output})
2308
- current_step += 1
2309
-
2310
- if "__完成__" in t.output or "/done" in t.output or "__EOF__" in t.output:
2311
- return CodeGenerateResult(contents=["\n\n".join(results)], conversations=[conversations])
2312
-
2313
- return CodeGenerateResult(contents=["\n\n".join(results)], conversations=[conversations])
2314
-
2315
-
2316
- class CodeModificationRanker:
2317
- def __init__(self, llm: AutoLLM):
2318
- self.llm = llm
2319
- self.llm.setup_default_model_name(memory["conf"]["current_code_model"])
2320
- self.args = args
2321
- self.llms = [self.llm]
2322
-
2323
- @prompt()
2324
- def _rank_modifications(self, s: CodeGenerateResult) -> str:
2325
- """
2326
- 对一组代码修改进行质量评估并排序。
2327
-
2328
- 下面是修改需求:
2329
-
2330
- <edit_requirement>
2331
- {{ s.conversations[0][-2]["content"] }}
2332
- </edit_requirement>
2333
-
2334
- 下面是相应的代码修改:
2335
- {% for content in s.contents %}
2336
- <edit_block id="{{ loop.index0 }}">
2337
- {{content}}
2338
- </edit_block>
2339
- {% endfor %}
2340
-
2341
- 请输出如下格式的评估结果,只包含 JSON 数据:
2342
-
2343
- ```json
2344
- {
2345
- "rank_result": [id1, id2, id3] // id 为 edit_block 的 id,按质量从高到低排序
2346
- }
2347
- ```
2348
-
2349
- 注意:
2350
- 1. 只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹
2351
- """
2352
-
2353
- def rank_modifications(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
2354
- import time
2355
- from collections import defaultdict
2356
-
2357
- start_time = time.time()
2358
- logger.info(f"开始对 {len(generate_result.contents)} 个候选结果进行排序")
2359
-
2360
- try:
2361
- results = []
2362
- for llm in self.llms:
2363
- v = self._rank_modifications.with_llm(llm).with_return_type(RankResult).run(generate_result)
2364
- results.append(v.rank_result)
2365
-
2366
- if not results:
2367
- raise Exception("All ranking requests failed")
2368
-
2369
- # 计算每个候选人的分数
2370
- candidate_scores = defaultdict(float)
2371
- for rank_result in results:
2372
- for idx, candidate_id in enumerate(rank_result):
2373
- # Score is 1/(position + 1) since position starts from 0
2374
- candidate_scores[candidate_id] += 1.0 / (idx + 1)
2375
- # 按分数降序对候选人进行排序
2376
- sorted_candidates = sorted(candidate_scores.keys(),
2377
- key=lambda x: candidate_scores[x],
2378
- reverse=True)
2379
-
2380
- elapsed = time.time() - start_time
2381
- score_details = ", ".join([f"candidate {i}: {candidate_scores[i]:.2f}" for i in sorted_candidates])
2382
- logger.info(
2383
- f"排序完成,耗时 {elapsed:.2f} 秒,最佳候选索引: {sorted_candidates[0]},评分详情: {score_details}"
2384
- )
2385
-
2386
- rerank_contents = [generate_result.contents[i] for i in sorted_candidates]
2387
- rerank_conversations = [generate_result.conversations[i] for i in sorted_candidates]
2388
-
2389
- return CodeGenerateResult(contents=rerank_contents, conversations=rerank_conversations)
2390
-
2391
- except Exception as e:
2392
- logger.error(f"排序过程失败: {str(e)}")
2393
- logger.debug(traceback.format_exc())
2394
- elapsed = time.time() - start_time
2395
- logger.warning(f"排序失败,耗时 {elapsed:.2f} 秒,将使用原始顺序")
2396
- return generate_result
2397
-
2398
-
2399
- class TextSimilarity:
2400
- """
2401
- 找到 text_b 中与 text_a 最相似的部分(滑动窗口)
2402
- 返回相似度分数和最相似的文本片段
2403
- """
2404
-
2405
- def __init__(self, text_a, text_b):
2406
- self.text_a = text_a
2407
- self.text_b = text_b
2408
- self.lines_a = self._split_into_lines(text_a)
2409
- self.lines_b = self._split_into_lines(text_b)
2410
- self.m = len(self.lines_a)
2411
- self.n = len(self.lines_b)
2412
-
2413
- @staticmethod
2414
- def _split_into_lines(text):
2415
- return text.splitlines()
2416
-
2417
- @staticmethod
2418
- def _levenshtein_ratio(s1, s2):
2419
- return SequenceMatcher(None, s1, s2).ratio()
2420
-
2421
- def get_best_matching_window(self):
2422
- best_similarity = 0
2423
- best_window = []
2424
-
2425
- for i in range(self.n - self.m + 1): # 滑动窗口
2426
- window_b = self.lines_b[i:i + self.m]
2427
- similarity = self._levenshtein_ratio("\n".join(self.lines_a), "\n".join(window_b))
2428
-
2429
- if similarity > best_similarity:
2430
- best_similarity = similarity
2431
- best_window = window_b
2432
-
2433
- return best_similarity, "\n".join(best_window)
2434
-
2435
-
2436
- class CodeAutoMergeEditBlock:
2437
- def __init__(self, llm: AutoLLM, fence_0: str = "```", fence_1: str = "```"):
2438
- self.llm = llm
2439
- self.llm.setup_default_model_name(memory["conf"]["current_code_model"])
2440
- self.args = args
2441
- self.fence_0 = fence_0
2442
- self.fence_1 = fence_1
2443
-
2444
- @staticmethod
2445
- def run_pylint(code: str) -> tuple[bool, str]:
2446
- """
2447
- --disable=all 禁用所有 Pylint 的检查规则
2448
- --enable=E0001,W0311,W0312 启用指定的 Pylint 检查规则,
2449
- E0001:语法错误(Syntax Error),
2450
- W0311:代码缩进使用了 Tab 而不是空格(Bad indentation)
2451
- W0312:代码缩进不一致(Mixed indentation)
2452
- :param code:
2453
- :return:
2454
- """
2455
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_file:
2456
- temp_file.write(code)
2457
- temp_file_path = temp_file.name
2458
-
2459
- try:
2460
- result = subprocess.run(
2461
- ["pylint", "--disable=all", "--enable=E0001,W0311,W0312", temp_file_path,],
2462
- capture_output=True,
2463
- text=True,
2464
- check=False,
2465
- )
2466
- os.unlink(temp_file_path)
2467
- if result.returncode != 0:
2468
- error_message = result.stdout.strip() or result.stderr.strip()
2469
- logger.warning(f"Pylint 检查代码失败: {error_message}")
2470
- return False, error_message
2471
- return True, ""
2472
- except subprocess.CalledProcessError as e:
2473
- error_message = f"运行 Pylint 时发生错误: {str(e)}"
2474
- logger.error(error_message)
2475
- os.unlink(temp_file_path)
2476
- return False, error_message
2477
-
2478
- def parse_whole_text(self, text: str) -> List[PathAndCode]:
2479
- """
2480
- 从文本中抽取如下格式代码(two_line_mode):
2481
-
2482
- ```python
2483
- ##File: /project/path/src/autocoder/index/index.py
2484
- <<<<<<< SEARCH
2485
- =======
2486
- >>>>>>> REPLACE
2487
- ```
2488
-
2489
- 或者 (one_line_mode)
2490
-
2491
- ```python:/project/path/src/autocoder/index/index.py
2492
- <<<<<<< SEARCH
2493
- =======
2494
- >>>>>>> REPLACE
2495
- ```
2496
- """
2497
- HEAD = "<<<<<<< SEARCH"
2498
- DIVIDER = "======="
2499
- UPDATED = ">>>>>>> REPLACE"
2500
- lines = text.split("\n")
2501
- lines_len = len(lines)
2502
- start_marker_count = 0
2503
- block = []
2504
- path_and_code_list = []
2505
- # two_line_mode or one_line_mode
2506
- current_editblock_mode = "two_line_mode"
2507
- current_editblock_path = None
2508
-
2509
- def guard(_index):
2510
- return _index + 1 < lines_len
2511
-
2512
- def start_marker(_line, _index):
2513
- nonlocal current_editblock_mode
2514
- nonlocal current_editblock_path
2515
- if _line.startswith(self.fence_0) and guard(_index) and ":" in _line and lines[_index + 1].startswith(HEAD):
2516
- current_editblock_mode = "one_line_mode"
2517
- current_editblock_path = _line.split(":", 1)[1].strip()
2518
- return True
2519
- if _line.startswith(self.fence_0) and guard(_index) and lines[_index + 1].startswith("##File:"):
2520
- current_editblock_mode = "two_line_mode"
2521
- current_editblock_path = None
2522
- return True
2523
- return False
2524
-
2525
- def end_marker(_line, _index):
2526
- return _line.startswith(self.fence_1) and UPDATED in lines[_index - 1]
2527
-
2528
- for index, line in enumerate(lines):
2529
- if start_marker(line, index) and start_marker_count == 0:
2530
- start_marker_count += 1
2531
- elif end_marker(line, index) and start_marker_count == 1:
2532
- start_marker_count -= 1
2533
- if block:
2534
- if current_editblock_mode == "two_line_mode":
2535
- path = block[0].split(":", 1)[1].strip()
2536
- content = "\n".join(block[1:])
2537
- else:
2538
- path = current_editblock_path
2539
- content = "\n".join(block)
2540
- block = []
2541
- path_and_code_list.append(PathAndCode(path=path, content=content))
2542
- elif start_marker_count > 0:
2543
- block.append(line)
2544
-
2545
- return path_and_code_list
2546
-
2547
- def get_edits(self, content: str):
2548
- edits = self.parse_whole_text(content)
2549
- HEAD = "<<<<<<< SEARCH"
2550
- DIVIDER = "======="
2551
- UPDATED = ">>>>>>> REPLACE"
2552
- result = []
2553
- for edit in edits:
2554
- heads = []
2555
- updates = []
2556
- c = edit.content
2557
- in_head = False
2558
- in_updated = False
2559
- for line in c.splitlines():
2560
- if line.strip() == HEAD:
2561
- in_head = True
2562
- continue
2563
- if line.strip() == DIVIDER:
2564
- in_head = False
2565
- in_updated = True
2566
- continue
2567
- if line.strip() == UPDATED:
2568
- in_head = False
2569
- in_updated = False
2570
- continue
2571
- if in_head:
2572
- heads.append(line)
2573
- if in_updated:
2574
- updates.append(line)
2575
- result.append((edit.path, "\n".join(heads), "\n".join(updates)))
2576
- return result
2577
-
2578
- @prompt()
2579
- def git_require_msg(self, source_dir: str, error: str) -> str:
2580
- """
2581
- auto_merge only works for git repositories.
2582
-
2583
- Try to use git init in the source directory.
2584
-
2585
- ```shell
2586
- cd {{ source_dir }}
2587
- git init .
2588
- ```
2589
-
2590
- Then try to run auto-coder again.
2591
- Error: {{ error }}
2592
- """
2593
-
2594
- def _merge_code_without_effect(self, content: str) -> MergeCodeWithoutEffect:
2595
- """
2596
- 合并代码时不会产生任何副作用,例如 Git 操作、代码检查或文件写入。
2597
- 返回一个元组,包含:
2598
- - 成功合并的代码块的列表,每个元素是一个 (file_path, new_content) 元组,
2599
- 其中 file_path 是文件路径,new_content 是合并后的新内容。
2600
- - 合并失败的代码块的列表,每个元素是一个 (file_path, head, update) 元组,
2601
- 其中:file_path 是文件路径,head 是原始内容,update 是尝试合并的内容。
2602
- """
2603
- codes = self.get_edits(content)
2604
- file_content_mapping = {}
2605
- failed_blocks = []
2606
-
2607
- for block in codes:
2608
- file_path, head, update = block
2609
- if not os.path.exists(file_path):
2610
- file_content_mapping[file_path] = update
2611
- else:
2612
- if file_path not in file_content_mapping:
2613
- with open(file_path, "r") as f:
2614
- temp = f.read()
2615
- file_content_mapping[file_path] = temp
2616
- existing_content = file_content_mapping[file_path]
2617
-
2618
- # First try exact match
2619
- new_content = (
2620
- existing_content.replace(head, update, 1)
2621
- if head
2622
- else existing_content + "\n" + update
2623
- )
2624
-
2625
- # If exact match fails, try similarity match
2626
- if new_content == existing_content and head:
2627
- similarity, best_window = TextSimilarity(
2628
- head, existing_content
2629
- ).get_best_matching_window()
2630
- if similarity > self.args.editblock_similarity:
2631
- new_content = existing_content.replace(
2632
- best_window, update, 1
2633
- )
2634
-
2635
- if new_content != existing_content:
2636
- file_content_mapping[file_path] = new_content
2637
- else:
2638
- failed_blocks.append((file_path, head, update))
2639
- return MergeCodeWithoutEffect(
2640
- success_blocks=[(path, content) for path, content in file_content_mapping.items()],
2641
- failed_blocks=failed_blocks
2642
- )
2643
-
2644
- def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
2645
- """ 选择最佳代码 """
2646
- if len(generate_result.contents) == 1: # 仅一份代码立即返回
2647
- logger.info("仅有一个候选结果,跳过排序")
2648
- return generate_result
2649
-
2650
- ranker = CodeModificationRanker(self.llm)
2651
- ranked_result = ranker.rank_modifications(generate_result)
2652
- # 过滤掉包含失败块的内容
2653
- for content, conversations in zip(ranked_result.contents, ranked_result.conversations):
2654
- merge_result = self._merge_code_without_effect(content)
2655
- if not merge_result.failed_blocks:
2656
- return CodeGenerateResult(contents=[content], conversations=[conversations])
2657
- # 如果所有内容都包含失败块,则返回第一个
2658
- return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])
2659
-
2660
- def _merge_code(self, content: str, force_skip_git: bool = False):
2661
- file_content = open(self.args.file).read()
2662
- md5 = hashlib.md5(file_content.encode("utf-8")).hexdigest()
2663
- file_name = os.path.basename(self.args.file)
2664
-
2665
- codes = self.get_edits(content)
2666
- changes_to_make = []
2667
- changes_made = False
2668
- unmerged_blocks = []
2669
- merged_blocks = []
2670
-
2671
- # First, check if there are any changes to be made
2672
- file_content_mapping = {}
2673
- for block in codes:
2674
- file_path, head, update = block
2675
- if not os.path.exists(file_path):
2676
- changes_to_make.append((file_path, None, update))
2677
- file_content_mapping[file_path] = update
2678
- merged_blocks.append((file_path, "", update, 1))
2679
- changes_made = True
2680
- else:
2681
- if file_path not in file_content_mapping:
2682
- with open(file_path, "r") as f:
2683
- temp = f.read()
2684
- file_content_mapping[file_path] = temp
2685
- existing_content = file_content_mapping[file_path]
2686
- new_content = (
2687
- existing_content.replace(head, update, 1)
2688
- if head
2689
- else existing_content + "\n" + update
2690
- )
2691
- if new_content != existing_content:
2692
- changes_to_make.append(
2693
- (file_path, existing_content, new_content))
2694
- file_content_mapping[file_path] = new_content
2695
- merged_blocks.append((file_path, head, update, 1))
2696
- changes_made = True
2697
- else:
2698
- # If the SEARCH BLOCK is not found exactly, then try to use
2699
- # the similarity ratio to find the best matching block
2700
- similarity, best_window = TextSimilarity(head, existing_content).get_best_matching_window()
2701
- if similarity > self.args.editblock_similarity: # 相似性比较
2702
- new_content = existing_content.replace(
2703
- best_window, update, 1)
2704
- if new_content != existing_content:
2705
- changes_to_make.append(
2706
- (file_path, existing_content, new_content)
2707
- )
2708
- file_content_mapping[file_path] = new_content
2709
- merged_blocks.append(
2710
- (file_path, head, update, similarity))
2711
- changes_made = True
2712
- else:
2713
- unmerged_blocks.append((file_path, head, update, similarity))
2714
-
2715
- if unmerged_blocks:
2716
- if self.args.request_id and not self.args.skip_events:
2717
- # collect unmerged blocks
2718
- event_data = []
2719
- for file_path, head, update, similarity in unmerged_blocks:
2720
- event_data.append(
2721
- {
2722
- "file_path": file_path,
2723
- "head": head,
2724
- "update": update,
2725
- "similarity": similarity,
2726
- }
2727
- )
2728
- return
2729
- logger.warning(f"发现 {len(unmerged_blocks)} 个未合并的代码块,更改将不会应用,请手动检查这些代码块后重试。")
2730
- self._print_unmerged_blocks(unmerged_blocks)
2731
- return
2732
-
2733
- # lint check
2734
- for file_path, new_content in file_content_mapping.items():
2735
- if file_path.endswith(".py"):
2736
- pylint_passed, error_message = self.run_pylint(new_content)
2737
- if not pylint_passed:
2738
- logger.warning(f"代码文件 {file_path} 的 Pylint 检查未通过,本次更改未应用。错误信息: {error_message}")
2739
-
2740
- if changes_made and not force_skip_git and not self.args.skip_commit:
2741
- try:
2742
- commit_changes(self.args.source_dir, f"auto_coder_pre_{file_name}_{md5}")
2743
- except Exception as e:
2744
- logger.error(
2745
- self.git_require_msg(
2746
- source_dir=self.args.source_dir, error=str(e))
2747
- )
2748
- return
2749
- # Now, apply the changes
2750
- for file_path, new_content in file_content_mapping.items():
2751
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
2752
- with open(file_path, "w") as f:
2753
- f.write(new_content)
2754
-
2755
- if self.args.request_id and not self.args.skip_events:
2756
- # collect modified files
2757
- event_data = []
2758
- for code in merged_blocks:
2759
- file_path, head, update, similarity = code
2760
- event_data.append(
2761
- {
2762
- "file_path": file_path,
2763
- "head": head,
2764
- "update": update,
2765
- "similarity": similarity,
2766
- }
2767
- )
2768
-
2769
- if changes_made:
2770
- if not force_skip_git and not self.args.skip_commit:
2771
- try:
2772
- commit_result = commit_changes(self.args.source_dir, f"auto_coder_{file_name}_{md5}")
2773
- git_print_commit_info(commit_result=commit_result)
2774
- except Exception as e:
2775
- logger.error(
2776
- self.git_require_msg(
2777
- source_dir=self.args.source_dir, error=str(e)
2778
- )
2779
- )
2780
- logger.info(
2781
- f"已在 {len(file_content_mapping.keys())} 个文件中合并更改,"
2782
- f"完成 {len(changes_to_make)}/{len(codes)} 个代码块。"
2783
- )
2784
- else:
2785
- logger.warning("未对任何文件进行更改。")
2786
-
2787
- def merge_code(self, generate_result: CodeGenerateResult, force_skip_git: bool = False):
2788
- result = self.choose_best_choice(generate_result)
2789
- self._merge_code(result.contents[0], force_skip_git)
2790
- return result
2791
-
2792
- @staticmethod
2793
- def _print_unmerged_blocks(unmerged_blocks: List[tuple]):
2794
- console.print(f"\n[bold red]未合并的代码块:[/bold red]")
2795
- for file_path, head, update, similarity in unmerged_blocks:
2796
- console.print(f"\n[bold blue]文件:[/bold blue] {file_path}")
2797
- console.print(
2798
- f"\n[bold green]搜索代码块(相似度:{similarity}):[/bold green]")
2799
- syntax = Syntax(head, "python", theme="monokai", line_numbers=True)
2800
- console.print(Panel(syntax, expand=False))
2801
- console.print("\n[bold yellow]替换代码块:[/bold yellow]")
2802
- syntax = Syntax(update, "python", theme="monokai",
2803
- line_numbers=True)
2804
- console.print(Panel(syntax, expand=False))
2805
- console.print(f"\n[bold red]未合并的代码块总数: {len(unmerged_blocks)}[/bold red]")
2806
-
2807
-
2808
- class BaseAction:
2809
- @staticmethod
2810
- def _get_content_length(content: str) -> int:
2811
- return len(content)
2812
-
2813
-
2814
- class ActionPyProject(BaseAction):
2815
- def __init__(self, llm: Optional[AutoLLM] = None) -> None:
2816
- self.args = args
2817
- self.llm = llm
2818
- self.pp = None
2819
-
2820
- def run(self):
2821
- if self.args.project_type != "py":
2822
- return False
2823
- pp = PyProject(llm=self.llm, args=args)
2824
- self.pp = pp
2825
- pp.run()
2826
- source_code = pp.output()
2827
- if self.llm:
2828
- source_code = build_index_and_filter_files(llm=self.llm, sources=pp.sources)
2829
- self.process_content(source_code)
2830
- return True
2831
-
2832
- def process_content(self, content: str):
2833
- # args = self.args
2834
- if self.args.execute and self.llm:
2835
- content_length = self._get_content_length(content)
2836
- if content_length > self.args.model_max_input_length:
2837
- logger.warning(
2838
- f"发送给模型的内容长度为 {content_length} 个 token(可能收集了过多文件),"
2839
- f"已超过最大输入长度限制 {self.args.model_max_input_length}。"
2840
- )
2841
-
2842
- if args.execute:
2843
- logger.info("正在自动生成代码...")
2844
- start_time = time.time()
2845
- # diff, strict_diff, editblock 是代码自动生成或合并的不同策略, 通常用于处理代码的变更或生成
2846
- # diff 模式,基于差异生成代码,生成最小的变更集,适用于局部优化,代码重构
2847
- # strict_diff 模式,严格验证差异,确保生成的代码符合规则,适用于代码审查,自动化测试
2848
- # editblock 模式,基于编辑块生成代码,支持较大范围的修改,适用于代码重构,功能扩展
2849
- if args.auto_merge == "editblock":
2850
- generate = CodeAutoGenerateEditBlock(llm=self.llm, action=self)
2851
- else:
2852
- generate = None
2853
-
2854
- if self.args.enable_multi_round_generate:
2855
- generate_result = generate.multi_round_run(query=args.query, source_content=content)
2856
- else:
2857
- generate_result = generate.single_round_run(query=args.query, source_content=content)
2858
- logger.info(f"代码生成完成,耗时 {time.time() - start_time:.2f} 秒")
2859
-
2860
- if args.auto_merge:
2861
- logger.info("正在自动合并代码...")
2862
- if args.auto_merge == "editblock":
2863
- code_merge = CodeAutoMergeEditBlock(llm=self.llm)
2864
- merge_result = code_merge.merge_code(generate_result=generate_result)
2865
- else:
2866
- merge_result = None
2867
-
2868
- content = merge_result.contents[0]
2869
- else:
2870
- content = generate_result.contents[0]
2871
- with open(args.target_file, "w") as file:
2872
- file.write(content)
2873
-
2874
-
2875
- class ActionSuffixProject(BaseAction):
2876
- def __init__(self, llm: Optional[AutoLLM] = None) -> None:
2877
- self.args = args
2878
- self.llm = llm
2879
- self.pp = None
2880
-
2881
- def run(self):
2882
- pp = SuffixProject(llm=self.llm, args=args)
2883
- self.pp = pp
2884
- pp.run()
2885
- source_code = pp.output()
2886
- if self.llm:
2887
- source_code = build_index_and_filter_files(llm=self.llm, sources=pp.sources)
2888
- self.process_content(source_code)
2889
-
2890
- def process_content(self, content: str):
2891
- if self.args.execute and self.llm:
2892
- content_length = self._get_content_length(content)
2893
- if content_length > self.args.model_max_input_length:
2894
- logger.warning(
2895
- f"发送给模型的内容长度为 {content_length} 个 token(可能收集了过多文件),"
2896
- f"已超过最大输入长度限制 {self.args.model_max_input_length}。"
2897
- )
2898
-
2899
- if args.execute:
2900
- logger.info("正在自动生成代码...")
2901
- start_time = time.time()
2902
- # diff, strict_diff, editblock 是代码自动生成或合并的不同策略, 通常用于处理代码的变更或生成
2903
- # diff 模式,基于差异生成代码,生成最小的变更集,适用于局部优化,代码重构
2904
- # strict_diff 模式,严格验证差异,确保生成的代码符合规则,适用于代码审查,自动化测试
2905
- # editblock 模式,基于编辑块生成代码,支持较大范围的修改,适用于代码重构,功能扩展
2906
- if args.auto_merge == "editblock":
2907
- generate = CodeAutoGenerateEditBlock(llm=self.llm, action=self)
2908
- else:
2909
- generate = None
2910
-
2911
- if self.args.enable_multi_round_generate:
2912
- generate_result = generate.multi_round_run(query=args.query, source_content=content)
2913
- else:
2914
- generate_result = generate.single_round_run(query=args.query, source_content=content)
2915
- logger.info(f"代码生成完成,耗时 {time.time() - start_time:.2f} 秒")
2916
-
2917
- if args.auto_merge:
2918
- logger.info("正在自动合并代码...")
2919
- if args.auto_merge == "editblock":
2920
- code_merge = CodeAutoMergeEditBlock(llm=self.llm)
2921
- merge_result = code_merge.merge_code(generate_result=generate_result)
2922
- else:
2923
- merge_result = None
2924
-
2925
- content = merge_result.contents[0]
2926
- else:
2927
- content = generate_result.contents[0]
2928
- with open(args.target_file, "w") as file:
2929
- file.write(content)
2930
-
2931
-
2932
- class Dispacher:
2933
- def __init__(self, llm: Optional[AutoLLM] = None):
2934
- self.args = args
2935
- self.llm = llm
2936
-
2937
- def dispach(self):
2938
- actions = [
2939
- ActionPyProject(llm=self.llm),
2940
- ActionSuffixProject(llm=self.llm)
2941
- ]
2942
- for action in actions:
2943
- if action.run():
2944
- return
2945
-
2946
-
2947
1253
  def prepare_chat_yaml():
2948
1254
  # auto_coder_main(["next", "chat_action"]) 准备聊天 yaml 文件
2949
1255
  actions_dir = os.path.join(args.source_dir, "actions")
@@ -3003,12 +1309,11 @@ def coding(query: str, llm: AutoLLM):
3003
1309
  if latest_yaml_file:
3004
1310
  yaml_config = {
3005
1311
  "include_file": ["./base/base.yml"],
3006
- "auto_merge": conf.get("auto_merge", "editblock"),
3007
- "human_as_model": conf.get("human_as_model", "false") == "true",
3008
1312
  "skip_build_index": conf.get("skip_build_index", "true") == "true",
3009
1313
  "skip_confirm": conf.get("skip_confirm", "true") == "true",
3010
- "silence": conf.get("silence", "true") == "true",
3011
- "include_project_structure": conf.get("include_project_structure", "true") == "true",
1314
+ "chat_model": conf.get("chat_model", ""),
1315
+ "code_model": conf.get("code_model", ""),
1316
+ "auto_merge": conf.get("auto_merge", "editblock"),
3012
1317
  "context": ""
3013
1318
  }
3014
1319
 
@@ -3071,7 +1376,7 @@ def coding(query: str, llm: AutoLLM):
3071
1376
  f.write(yaml_content)
3072
1377
  convert_yaml_to_config(execute_file)
3073
1378
 
3074
- dispacher = Dispacher(llm)
1379
+ dispacher = Dispacher(args=args, llm=llm)
3075
1380
  dispacher.dispach()
3076
1381
  else:
3077
1382
  logger.warning("创建新的 YAML 文件失败。")
@@ -3133,28 +1438,26 @@ def commit_info(query: str, llm: AutoLLM):
3133
1438
  prepare_chat_yaml() # 复制上一个序号的 yaml 文件, 生成一个新的聊天 yaml 文件
3134
1439
 
3135
1440
  latest_yaml_file = get_last_yaml_file(os.path.join(args.source_dir, "actions"))
3136
-
3137
- conf = memory.get("conf", {})
3138
- current_files = memory["current_files"]["files"]
3139
1441
  execute_file = None
3140
1442
 
3141
1443
  if latest_yaml_file:
3142
1444
  try:
3143
1445
  execute_file = os.path.join(args.source_dir, "actions", latest_yaml_file)
1446
+ conf = memory.get("conf", {})
3144
1447
  yaml_config = {
3145
1448
  "include_file": ["./base/base.yml"],
3146
- "auto_merge": conf.get("auto_merge", "editblock"),
3147
- "human_as_model": conf.get("human_as_model", "false") == "true",
3148
1449
  "skip_build_index": conf.get("skip_build_index", "true") == "true",
3149
1450
  "skip_confirm": conf.get("skip_confirm", "true") == "true",
3150
- "silence": conf.get("silence", "true") == "true",
3151
- "include_project_structure": conf.get("include_project_structure", "true") == "true",
1451
+ "chat_model": conf.get("chat_model", ""),
1452
+ "code_model": conf.get("code_model", ""),
1453
+ "auto_merge": conf.get("auto_merge", "editblock")
3152
1454
  }
3153
1455
  for key, value in conf.items():
3154
1456
  converted_value = convert_config_value(key, value)
3155
1457
  if converted_value is not None:
3156
1458
  yaml_config[key] = converted_value
3157
1459
 
1460
+ current_files = memory["current_files"]["files"]
3158
1461
  yaml_config["urls"] = current_files
3159
1462
 
3160
1463
  # 临时保存yaml文件,然后读取yaml文件,更新args
@@ -3169,7 +1472,7 @@ def commit_info(query: str, llm: AutoLLM):
3169
1472
 
3170
1473
  # commit_message = ""
3171
1474
  commit_llm = llm
3172
- commit_llm.setup_default_model_name(memory["conf"]["current_chat_model"])
1475
+ commit_llm.setup_default_model_name(args.chat_model)
3173
1476
  console.print(f"Commit 信息生成中...", style="yellow")
3174
1477
 
3175
1478
  try:
@@ -3239,20 +1542,7 @@ def _generate_shell_script(user_input: str) -> str:
3239
1542
 
3240
1543
 
3241
1544
  def generate_shell_command(input_text: str, llm: AutoLLM) -> str | None:
3242
- conf = memory.get("conf", {})
3243
- yaml_config = {
3244
- "include_file": ["./base/base.yml"],
3245
- }
3246
- if "model" in conf:
3247
- yaml_config["model"] = conf["model"]
3248
- yaml_config["query"] = input_text
3249
-
3250
- yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
3251
-
3252
- execute_file = os.path.join(args.source_dir, "actions", f"{uuid.uuid4()}.yml")
3253
-
3254
- with open(os.path.join(execute_file), "w") as f:
3255
- f.write(yaml_content)
1545
+ update_config_to_args(query=input_text, delete_execute_file=True)
3256
1546
 
3257
1547
  try:
3258
1548
  console.print(
@@ -3262,7 +1552,7 @@ def generate_shell_command(input_text: str, llm: AutoLLM) -> str | None:
3262
1552
  border_style="green",
3263
1553
  )
3264
1554
  )
3265
- llm.setup_default_model_name(memory["conf"]["current_code_model"])
1555
+ llm.setup_default_model_name(args.code_model)
3266
1556
  result = _generate_shell_script.with_llm(llm).run(user_input=input_text)
3267
1557
  shell_script = extract_code(result.output)[0][1]
3268
1558
  console.print(
@@ -3274,7 +1564,8 @@ def generate_shell_command(input_text: str, llm: AutoLLM) -> str | None:
3274
1564
  )
3275
1565
  return shell_script
3276
1566
  finally:
3277
- os.remove(execute_file)
1567
+ pass
1568
+ # os.remove(execute_file)
3278
1569
 
3279
1570
 
3280
1571
  def execute_shell_command(command: str):
@@ -3884,10 +2175,10 @@ def manage_models(models_args, models_data, llm: AutoLLM):
3884
2175
  logger.info(f"正在卸载 {remove_model_name} 模型")
3885
2176
  if llm.get_sub_client(remove_model_name):
3886
2177
  llm.remove_sub_client(remove_model_name)
3887
- if remove_model_name == memory["conf"]["current_chat_model"]:
3888
- logger.warning(f"当前首选 Chat 模型 {remove_model_name} 已被删除, 请立即 /conf current_chat_model: 调整 !!!")
3889
- if remove_model_name == memory["conf"]["current_code_model"]:
3890
- logger.warning(f"当前首选 Code 模型 {remove_model_name} 已被删除, 请立即 /conf current_code_model: 调整 !!!")
2178
+ if remove_model_name == memory["conf"]["chat_model"]:
2179
+ logger.warning(f"当前首选 Chat 模型 {remove_model_name} 已被删除, 请立即 /conf chat_model: 调整 !!!")
2180
+ if remove_model_name == memory["conf"]["code_model"]:
2181
+ logger.warning(f"当前首选 Code 模型 {remove_model_name} 已被删除, 请立即 /conf code_model: 调整 !!!")
3891
2182
 
3892
2183
 
3893
2184
  def configure_project_model():
@@ -3958,58 +2249,71 @@ def configure_project_model():
3958
2249
  )
3959
2250
 
3960
2251
 
3961
- def new_project(query, llm):
3962
- console.print(f"正在基于你的需求 {query} 构建项目 ...", style="bold green")
3963
- env_info = detect_env()
3964
- project = BuildNewProject(args=args, llm=llm,
3965
- chat_model=memory["conf"]["current_chat_model"],
3966
- code_model=memory["conf"]["current_code_model"])
3967
-
3968
- console.print(f"正在完善项目需求 ...", style="bold green")
3969
-
3970
- information = project.build_project_information(query, env_info, args.project_type)
3971
- if not information:
3972
- raise Exception(f"项目需求未正常生成 .")
3973
-
3974
- table = Table(title=f"{query}")
3975
- table.add_column("需求说明", style="cyan")
3976
- table.add_row(f"{information[:50]}...")
3977
- console.print(table)
3978
-
3979
- console.print(f"正在完善项目架构 ...", style="bold green")
3980
- architecture = project.build_project_architecture(query, env_info, args.project_type, information)
3981
-
3982
- console.print(f"正在构建项目索引 ...", style="bold green")
3983
- index_file_list = project.build_project_index(query, env_info, args.project_type, information, architecture)
3984
-
3985
- table = Table(title=f"索引列表")
3986
- table.add_column("路径", style="cyan")
3987
- table.add_column("用途", style="cyan")
3988
- for index_file in index_file_list.file_list:
3989
- table.add_row(index_file.file_path, index_file.purpose)
3990
- console.print(table)
3991
-
3992
- for index_file in index_file_list.file_list:
3993
- full_path = os.path.join(args.source_dir, index_file.file_path)
3994
-
3995
- # 获取目录路径
3996
- full_dir_path = os.path.dirname(full_path)
3997
- if not os.path.exists(full_dir_path):
3998
- os.makedirs(full_dir_path)
3999
-
4000
- console.print(f"正在编码: {full_path} ...", style="bold green")
4001
- code = project.build_single_code(query, env_info, args.project_type, information, architecture, index_file)
4002
-
4003
- with open(full_path, "w") as fp:
4004
- fp.write(code)
4005
-
4006
- # 生成 readme
4007
- readme_context = information + architecture
4008
- readme_path = os.path.join(args.source_dir, "README.md")
4009
- with open(readme_path, "w") as fp:
4010
- fp.write(readme_context)
4011
-
4012
- console.print(f"项目构建完成", style="bold green")
2252
+ # def new_project(query, llm):
2253
+ # console.print(f"正在基于你的需求 {query} 构建项目 ...", style="bold green")
2254
+ # env_info = detect_env()
2255
+ # project = BuildNewProject(args=args, llm=llm,
2256
+ # chat_model=memory["conf"]["chat_model"],
2257
+ # code_model=memory["conf"]["code_model"])
2258
+ #
2259
+ # console.print(f"正在完善项目需求 ...", style="bold green")
2260
+ #
2261
+ # information = project.build_project_information(query, env_info, args.project_type)
2262
+ # if not information:
2263
+ # raise Exception(f"项目需求未正常生成 .")
2264
+ #
2265
+ # table = Table(title=f"{query}")
2266
+ # table.add_column("需求说明", style="cyan")
2267
+ # table.add_row(f"{information[:50]}...")
2268
+ # console.print(table)
2269
+ #
2270
+ # console.print(f"正在完善项目架构 ...", style="bold green")
2271
+ # architecture = project.build_project_architecture(query, env_info, args.project_type, information)
2272
+ #
2273
+ # console.print(f"正在构建项目索引 ...", style="bold green")
2274
+ # index_file_list = project.build_project_index(query, env_info, args.project_type, information, architecture)
2275
+ #
2276
+ # table = Table(title=f"索引列表")
2277
+ # table.add_column("路径", style="cyan")
2278
+ # table.add_column("用途", style="cyan")
2279
+ # for index_file in index_file_list.file_list:
2280
+ # table.add_row(index_file.file_path, index_file.purpose)
2281
+ # console.print(table)
2282
+ #
2283
+ # for index_file in index_file_list.file_list:
2284
+ # full_path = os.path.join(args.source_dir, index_file.file_path)
2285
+ #
2286
+ # # 获取目录路径
2287
+ # full_dir_path = os.path.dirname(full_path)
2288
+ # if not os.path.exists(full_dir_path):
2289
+ # os.makedirs(full_dir_path)
2290
+ #
2291
+ # console.print(f"正在编码: {full_path} ...", style="bold green")
2292
+ # code = project.build_single_code(query, env_info, args.project_type, information, architecture, index_file)
2293
+ #
2294
+ # with open(full_path, "w") as fp:
2295
+ # fp.write(code)
2296
+ #
2297
+ # # 生成 readme
2298
+ # readme_context = information + architecture
2299
+ # readme_path = os.path.join(args.source_dir, "README.md")
2300
+ # with open(readme_path, "w") as fp:
2301
+ # fp.write(readme_context)
2302
+ #
2303
+ # console.print(f"项目构建完成", style="bold green")
2304
+
2305
+
2306
+ def is_old_version():
2307
+ """
2308
+ __version__ = "0.1.26" 开始使用兼容 AutoCoder 的 chat_model, code_model 参数
2309
+ 不再使用 current_chat_model 和 current_code_model
2310
+ """
2311
+ if 'current_chat_model' in memory['conf'] and 'current_code_model' in memory['conf']:
2312
+ logger.warning(f"您当前版本使用的版本偏低, 正在进行配置兼容性处理")
2313
+ memory['conf']['chat_model'] = memory['conf']['current_chat_model']
2314
+ memory['conf']['code_model'] = memory['conf']['current_code_model']
2315
+ del memory['conf']['current_chat_model']
2316
+ del memory['conf']['current_code_model']
4013
2317
 
4014
2318
 
4015
2319
  def main():
@@ -4021,14 +2325,15 @@ def main():
4021
2325
  initialize_system()
4022
2326
 
4023
2327
  load_memory()
2328
+ is_old_version()
4024
2329
 
4025
2330
  if len(memory["models"]) == 0:
4026
2331
  _model_pass = input(f" 是否跳过模型配置(y/n): ").strip().lower()
4027
2332
  if _model_pass == "n":
4028
2333
  m1, m2, m3, m4 = configure_project_model()
4029
2334
  print_status(f"正在更新缓存...", "warning")
4030
- memory["conf"]["current_chat_model"] = m1
4031
- memory["conf"]["current_code_model"] = m1
2335
+ memory["conf"]["chat_model"] = m1
2336
+ memory["conf"]["code_model"] = m1
4032
2337
  memory["models"][m1] = {"base_url": m3, "api_key": m4, "model": m2}
4033
2338
  print_status(f"供应商配置已成功完成!后续你可以使用 /models 命令, 查看, 新增和修改所有模型", "success")
4034
2339
  else:
@@ -4046,10 +2351,10 @@ def main():
4046
2351
 
4047
2352
  print_status("初始化完成。", "success")
4048
2353
 
4049
- if memory["conf"]["current_chat_model"] not in memory["models"].keys():
4050
- print_status("首选 Chat 模型与部署模型不一致, 请使用 /conf current_chat_model:xxx 设置", "error")
4051
- if memory["conf"]["current_code_model"] not in memory["models"].keys():
4052
- print_status("首选 Code 模型与部署模型不一致, 请使用 /conf current_code_model:xxx 设置", "error")
2354
+ if memory["conf"]["chat_model"] not in memory["models"].keys():
2355
+ print_status("首选 Chat 模型与部署模型不一致, 请使用 /conf chat_model:xxx 设置", "error")
2356
+ if memory["conf"]["code_model"] not in memory["models"].keys():
2357
+ print_status("首选 Code 模型与部署模型不一致, 请使用 /conf code_model:xxx 设置", "error")
4053
2358
 
4054
2359
  MODES = {
4055
2360
  "normal": "正常模式",
@@ -4180,12 +2485,12 @@ def main():
4180
2485
  print("\033[91mPlease enter your request.\033[0m")
4181
2486
  continue
4182
2487
  coding(query=query, llm=auto_llm)
4183
- elif user_input.startswith("/new"):
4184
- query = user_input[len("/new"):].strip()
4185
- if not query:
4186
- print("\033[91mPlease enter your request.\033[0m")
4187
- continue
4188
- new_project(query=query, llm=auto_llm)
2488
+ # elif user_input.startswith("/new"):
2489
+ # query = user_input[len("/new"):].strip()
2490
+ # if not query:
2491
+ # print("\033[91mPlease enter your request.\033[0m")
2492
+ # continue
2493
+ # new_project(query=query, llm=auto_llm)
4189
2494
  elif user_input.startswith("/chat"):
4190
2495
  query = user_input[len("/chat"):].strip()
4191
2496
  if not query: