autocoder-nano 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autocoder_nano/agent/agent_base.py +376 -63
- autocoder_nano/auto_coder_nano.py +147 -1842
- autocoder_nano/edit/__init__.py +20 -0
- autocoder_nano/edit/actions.py +136 -0
- autocoder_nano/edit/code/__init__.py +0 -0
- autocoder_nano/edit/code/generate_editblock.py +403 -0
- autocoder_nano/edit/code/merge_editblock.py +418 -0
- autocoder_nano/edit/code/modification_ranker.py +90 -0
- autocoder_nano/edit/text.py +38 -0
- autocoder_nano/index/__init__.py +0 -0
- autocoder_nano/index/entry.py +166 -0
- autocoder_nano/index/index_manager.py +410 -0
- autocoder_nano/index/symbols_utils.py +43 -0
- autocoder_nano/llm_types.py +12 -8
- autocoder_nano/version.py +1 -1
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/METADATA +1 -1
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/RECORD +21 -10
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/LICENSE +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/WHEEL +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/entry_points.txt +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,166 @@
|
|
1
|
+
import os
|
2
|
+
import time
|
3
|
+
from typing import List, Dict
|
4
|
+
|
5
|
+
from loguru import logger
|
6
|
+
from rich.console import Console
|
7
|
+
from rich.table import Table
|
8
|
+
|
9
|
+
from autocoder_nano.index.index_manager import IndexManager
|
10
|
+
from autocoder_nano.llm_types import SourceCode, TargetFile, VerifyFileRelevance, AutoCoderArgs
|
11
|
+
from autocoder_nano.llm_client import AutoLLM
|
12
|
+
|
13
|
+
console = Console()
|
14
|
+
|
15
|
+
|
16
|
+
def build_index_and_filter_files(args: AutoCoderArgs, llm: AutoLLM, sources: List[SourceCode]) -> str:
    """Build the project index, narrow ``sources`` down to the files relevant to
    ``args.query``, and concatenate the selected files into one prompt string.

    Pipeline (up to seven stages):
      1. REST/RAG/SEARCH tagged sources are always kept.
      2. Build/update the symbol index (skipped when ``args.skip_build_index``).
      3. Level 1 filter: select files matching the query.
      4. Level 2 filter: add files related to the Level 1 hits.
      5. LLM relevance verification of every candidate.
      6. Truncate to ``args.index_filter_file_num`` when positive.
      7. Render selected files as ``##File: <path>`` blocks.

    :param args: run configuration (filter levels, limits, query, rate limit).
    :param llm: LLM client used for indexing and relevance verification.
    :param sources: candidate source files.
    :return: concatenated source text of the selected files.
    """
    def get_file_path(_file_path):
        # Index output may prefix paths with "##"; strip it to get the real path.
        if _file_path.startswith("##"):
            return _file_path.strip()[2:]
        return _file_path

    final_files: Dict[str, TargetFile] = {}
    logger.info("第一阶段:处理 REST/RAG/Search 资源...")
    for source in sources:
        if source.tag in ["REST", "RAG", "SEARCH"]:
            final_files[get_file_path(source.module_name)] = TargetFile(
                file_path=source.module_name, reason="Rest/Rag/Search"
            )

    if not args.skip_build_index and llm:
        logger.info("第二阶段:为所有文件构建索引...")
        index_manager = IndexManager(args=args, llm=llm, source_codes=sources)
        index_data = index_manager.build_index()
        indexed_files_count = len(index_data) if index_data else 0
        logger.info(f"总索引文件数: {indexed_files_count}")

        if not args.skip_filter_index and args.index_filter_level >= 1:
            logger.info("第三阶段:执行 Level 1 过滤(基于查询) ...")
            target_files = index_manager.get_target_files_by_query(args.query)
            if target_files:
                for file in target_files.file_list:
                    file_path = file.file_path.strip()
                    final_files[get_file_path(file_path)] = file

            if target_files is not None and args.index_filter_level >= 2:
                logger.info("第四阶段:执行 Level 2 过滤(基于相关文件)...")
                related_files = index_manager.get_related_files(
                    [file.file_path for file in target_files.file_list]
                )
                if related_files is not None:
                    for file in related_files.file_list:
                        file_path = file.file_path.strip()
                        final_files[get_file_path(file_path)] = file

            # If neither Level 1 nor Level 2 filtering found anything, fall back
            # to using every source file.
            if not final_files:
                logger.warning("Level 1, Level 2 过滤未找到相关文件, 将使用所有文件 ...")
                for source in sources:
                    final_files[get_file_path(source.module_name)] = TargetFile(
                        file_path=source.module_name,
                        reason="No related files found, use all files",
                    )

        logger.info("第五阶段:执行相关性验证 ...")
        verified_files = {}
        temp_files = list(final_files.values())
        verification_results = []

        def _print_verification_results(results):
            # Render one row per (path, score, status, reason) tuple.
            table = Table(title="文件相关性验证结果", expand=True, show_lines=True)
            table.add_column("文件路径", style="cyan", no_wrap=True)
            table.add_column("得分", justify="right", style="green")
            table.add_column("状态", style="yellow")
            table.add_column("原因/错误")
            # BUGFIX: this previously tested the *outer* loop variable `result`
            # (a NameError when no file was verified) instead of the parameter.
            if results:
                for _file_path, _score, _status, _reason in results:
                    table.add_row(_file_path,
                                  str(_score) if _score is not None else "N/A", _status, _reason)
            console.print(table)

        def _verify_single_file(single_file: TargetFile):
            # Returns a (path, score, status, reason) tuple, or None when the
            # file is not among the loaded sources.
            for _source in sources:
                if _source.module_name == single_file.file_path:
                    file_content = _source.source_code
                    try:
                        _result = index_manager.verify_file_relevance.with_llm(llm).with_return_type(
                            VerifyFileRelevance).run(
                            file_content=file_content,
                            query=args.query
                        )
                        if _result.relevant_score >= args.verify_file_relevance_score:
                            verified_files[single_file.file_path] = TargetFile(
                                file_path=single_file.file_path,
                                reason=f"Score:{_result.relevant_score}, {_result.reason}"
                            )
                            return single_file.file_path, _result.relevant_score, "PASS", _result.reason
                        else:
                            return single_file.file_path, _result.relevant_score, "FAIL", _result.reason
                    except Exception as e:
                        # On verification failure keep the file (best-effort) and
                        # record the error for the report table.
                        error_msg = str(e)
                        verified_files[single_file.file_path] = TargetFile(
                            file_path=single_file.file_path,
                            reason=f"Verification failed: {error_msg}"
                        )
                        return single_file.file_path, None, "ERROR", error_msg
            return None

        for pending_verify_file in temp_files:
            result = _verify_single_file(pending_verify_file)
            if result:
                verification_results.append(result)
            time.sleep(args.anti_quota_limit)

        _print_verification_results(verification_results)
        # Only files that PASSed verification (or errored, which are kept
        # defensively) survive; FAILed files are dropped here.
        final_files = verified_files

    logger.info("第六阶段:筛选文件并应用限制条件 ...")
    if args.index_filter_file_num > 0:
        logger.info(f"从 {len(final_files)} 个文件中获取前 {args.index_filter_file_num} 个文件(Limit)")
    final_filenames = [file.file_path for file in final_files.values()]
    if not final_filenames:
        logger.warning("未找到目标文件,你可能需要重新编写查询并重试.")
    if args.index_filter_file_num > 0:
        final_filenames = final_filenames[: args.index_filter_file_num]

    def _shorten_path(path: str, keep_levels: int = 3) -> str:
        """
        Shorten a long path for display, keeping the last *keep_levels* parts.
        Example: /a/b/c/d/e/f.py -> .../c/d/e/f.py
        """
        parts = path.split(os.sep)
        if len(parts) > keep_levels:
            return ".../" + os.sep.join(parts[-keep_levels:])
        return path

    def _print_selected(data):
        # Summary table of the final (path, reason) selection.
        table = Table(title="代码上下文文件", expand=True, show_lines=True)
        table.add_column("文件路径", style="cyan")
        table.add_column("原因", style="cyan")
        for _file, _reason in data:
            # Keep only the last 3 path levels so the table stays readable.
            _processed_path = _shorten_path(_file, keep_levels=3)
            table.add_row(_processed_path, _reason)
        console.print(table)

    logger.info("第七阶段:准备最终输出 ...")
    _print_selected(
        [
            (file.file_path, file.reason)
            for file in final_files.values()
            if file.file_path in final_filenames
        ]
    )
    result_source_code = ""
    deduplicated_sources = set()

    for file in sources:
        if file.module_name in final_filenames:
            if file.module_name in deduplicated_sources:
                continue
            deduplicated_sources.add(file.module_name)
            result_source_code += f"##File: {file.module_name}\n"
            result_source_code += f"{file.source_code}\n\n"

    return result_source_code
|
@@ -0,0 +1,410 @@
|
|
1
|
+
import hashlib
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
from typing import List, Optional
|
6
|
+
|
7
|
+
from loguru import logger
|
8
|
+
|
9
|
+
from autocoder_nano.index.symbols_utils import extract_symbols, symbols_info_to_str
|
10
|
+
from autocoder_nano.llm_client import AutoLLM
|
11
|
+
from autocoder_nano.llm_prompt import prompt
|
12
|
+
from autocoder_nano.llm_types import SourceCode, AutoCoderArgs, IndexItem, SymbolType, FileList
|
13
|
+
|
14
|
+
|
15
|
+
class IndexManager:
    """Builds and queries a per-project symbol index.

    The index lives at ``<source_dir>/.auto-coder/index.json`` and maps each
    file path to the symbols the LLM extracted from it, plus the file's md5,
    so unchanged files are not re-indexed on subsequent runs.

    NOTE: the ``@prompt()`` docstrings below are runtime prompt templates and
    must not be edited casually — the LLM consumes them verbatim.
    """

    def __init__(self, args: AutoCoderArgs, source_codes: List[SourceCode], llm: AutoLLM = None):
        self.args = args
        self.sources = source_codes
        self.source_dir = args.source_dir
        self.index_dir = os.path.join(self.source_dir, ".auto-coder")
        self.index_file = os.path.join(self.index_dir, "index.json")
        self.llm = llm
        # BUGFIX: llm defaults to None in the signature; guard before use so a
        # missing client fails where it is *used*, not at construction.
        if self.llm is not None:
            self.llm.setup_default_model_name(args.chat_model)
        self.max_input_length = args.model_max_input_length  # model max input length
        # time.sleep(self.anti_quota_limit) between LLM calls to stay under API rate limits
        self.anti_quota_limit = args.anti_quota_limit
        # Create the index directory on first use (no-op if it already exists).
        os.makedirs(self.index_dir, exist_ok=True)

    def build_index(self):
        """Build or update the index and persist it to ``index.json``.

        Files whose md5 is unchanged are skipped. Returns the complete index
        dict (module_name -> metadata), whether or not anything was rebuilt.
        """
        if os.path.exists(self.index_file):
            # Read the cached index. encoding fix: the file is written with
            # ensure_ascii=False (contains non-ASCII), so a locale-dependent
            # default encoding could fail on read.
            with open(self.index_file, "r", encoding="utf-8") as file:
                index_data = json.load(file)
        else:  # first build
            logger.info("首次生成索引.")
            index_data = {}

        @prompt()
        def error_message(source_dir: str, file_path: str):
            """
            The source_dir is different from the path in index file (e.g. file_path:{{ file_path }} source_dir:{{
            source_dir }}). You may need to replace the prefix with the source_dir in the index file or Just delete
            the index file to rebuild it.
            """

        # Warn (once) if the cached index was built from a different source_dir.
        for item in index_data.keys():
            if not item.startswith(self.source_dir):
                logger.warning(error_message(source_dir=self.source_dir, file_path=item))
                break

        updated_sources = []
        wait_to_build_files = []
        for source in self.sources:
            source_code = source.source_code
            md5 = hashlib.md5(source_code.encode("utf-8")).hexdigest()
            # Re-index when the file is new or its content hash changed.
            if source.module_name not in index_data or index_data[source.module_name]["md5"] != md5:
                wait_to_build_files.append(source)
        counter = 0
        num_files = len(wait_to_build_files)
        total_files = len(self.sources)
        logger.info(f"总文件数: {total_files}, 需要索引文件数: {num_files}")

        for source in wait_to_build_files:
            build_result = self.build_index_for_single_source(source)
            if build_result is not None:
                counter += 1
                logger.info(f"正在构建索引:{counter}/{num_files}...")
                module_name = build_result["module_name"]
                index_data[module_name] = build_result
                updated_sources.append(module_name)
        if updated_sources:
            with open(self.index_file, "w", encoding="utf-8") as fp:
                json_str = json.dumps(index_data, indent=2, ensure_ascii=False)
                fp.write(json_str)
        return index_data

    def split_text_into_chunks(self, text):
        """Split *text* into chunks no longer than ``self.max_input_length``.

        Splitting is line-based; a single line longer than the limit still
        becomes its own (oversized) chunk.
        """
        lines = text.split("\n")
        chunks = []
        current_chunk = []
        current_length = 0
        for line in lines:
            # +1 accounts for the newline re-inserted by "\n".join below.
            if current_length + len(line) + 1 <= self.max_input_length:
                current_chunk.append(line)
                current_length += len(line) + 1
            else:
                chunks.append("\n".join(current_chunk))
                current_chunk = [line]
                current_length = len(line) + 1
        if current_chunk:
            chunks.append("\n".join(current_chunk))
        return chunks

    @prompt()
    def get_all_file_symbols(self, path: str, code: str) -> str:
        """
        你的目标是从给定的代码中获取代码里的符号,需要获取的符号类型包括:

        1. 函数
        2. 类
        3. 变量
        4. 所有导入语句

        如果没有任何符号,返回空字符串就行。
        如果有符号,按如下格式返回:

        ```
        {符号类型}: {符号名称}, {符号名称}, ...
        ```

        注意:
        1. 直接输出结果,不要尝试使用任何代码
        2. 不要分析代码的内容和目的
        3. 用途的长度不能超过100字符
        4. 导入语句的分隔符为^^

        下面是一段示例:

        ## 输入
        下列是文件 /test.py 的源码:

        import os
        import time
        from loguru import logger
        import byzerllm

        a = ""

        @byzerllm.prompt(render="jinja")
        def auto_implement_function_template(instruction:str, content:str)->str:

        ## 输出
        用途:主要用于提供自动实现函数模板的功能。
        函数:auto_implement_function_template
        变量:a
        类:
        导入语句:import os^^import time^^from loguru import logger^^import byzerllm

        现在,让我们开始一个新的任务:

        ## 输入
        下列是文件 {{ path }} 的源码:

        {{ code }}

        ## 输出
        """

    def build_index_for_single_source(self, source: SourceCode):
        """Index one source file: extract its symbols via the LLM.

        Returns the index entry dict, or ``None`` when the file is skipped
        (missing on disk, a document format, empty, or the LLM call failed).
        """
        file_path = source.module_name
        if not os.path.exists(file_path):  # skip files that no longer exist
            return None

        ext = os.path.splitext(file_path)[1].lower()
        if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:  # skip document files
            return None

        if source.source_code.strip() == "":
            return None

        md5 = hashlib.md5(source.source_code.encode("utf-8")).hexdigest()

        try:
            start_time = time.monotonic()
            source_code = source.source_code
            if len(source.source_code) > self.max_input_length:
                # Oversized sources are split and each chunk is summarised
                # separately, then the summaries are concatenated.
                logger.warning(
                    f"警告[构建索引]: 源代码({source.module_name})长度过长 "
                    f"({len(source.source_code)}) > 模型最大输入长度({self.max_input_length}),"
                    f"正在分割为多个块..."
                )
                chunks = self.split_text_into_chunks(source_code)
                symbols_list = []
                for chunk in chunks:
                    chunk_symbols = self.get_all_file_symbols.with_llm(self.llm).run(source.module_name, chunk)
                    time.sleep(self.anti_quota_limit)
                    symbols_list.append(chunk_symbols.output)
                symbols = "\n".join(symbols_list)
            else:
                single_symbols = self.get_all_file_symbols.with_llm(self.llm).run(source.module_name, source_code)
                symbols = single_symbols.output
                time.sleep(self.anti_quota_limit)

            logger.info(f"解析并更新索引:文件 {file_path}(MD5: {md5}),耗时 {time.monotonic() - start_time:.2f} 秒")
        except Exception as e:
            # Best-effort: a failed file is skipped, not fatal for the run.
            logger.warning(f"源文件 {file_path} 处理失败: {e}")
            return None

        return {
            "module_name": source.module_name,
            "symbols": symbols,
            "last_modified": os.path.getmtime(file_path),
            "md5": md5,
        }

    @prompt()
    def _get_target_files_by_query(self, indices: str, query: str) -> str:
        """
        下面是已知文件以及对应的符号信息:

        {{ indices }}

        用户的问题是:

        {{ query }}

        现在,请根据用户的问题以及前面的文件和符号信息,寻找相关文件路径。返回结果按如下格式:

        ```json
        {
            "file_list": [
                {
                    "file_path": "path/to/file.py",
                    "reason": "The reason why the file is the target file"
                },
                {
                    "file_path": "path/to/file.py",
                    "reason": "The reason why the file is the target file"
                }
            ]
        }
        ```

        如果没有找到,返回如下 json 即可:

        ```json
        {"file_list": []}
        ```

        请严格遵循以下步骤:

        1. 识别特殊标记:
           - 查找query中的 `@` 符号,它后面的内容是用户关注的文件路径。
           - 查找query中的 `@@` 符号,它后面的内容是用户关注的符号(如函数名、类名、变量名)。

        2. 匹配文件路径:
           - 对于 `@` 标记,在indices中查找包含该路径的所有文件。
           - 路径匹配应该是部分匹配,因为用户可能只提供了路径的一部分。

        3. 匹配符号:
           - 对于 `@@` 标记,在indices中所有文件的符号信息中查找该符号。
           - 检查函数、类、变量等所有符号类型。

        4. 分析依赖关系:
           - 利用 "导入语句" 信息确定文件间的依赖关系。
           - 如果找到了相关文件,也包括与之直接相关的依赖文件。

        5. 考虑文件用途:
           - 使用每个文件的 "用途" 信息来判断其与查询的相关性。

        6. 请严格按格式要求返回结果,无需额外的说明

        请确保结果的准确性和完整性,包括所有可能相关的文件。
        """

    def read_index(self) -> List[IndexItem]:
        """Read ``index.json`` and return its entries as ``IndexItem`` objects.

        Returns an empty list when no index file exists yet.
        """
        if not os.path.exists(self.index_file):
            return []

        # encoding fix: written with ensure_ascii=False, must be read as UTF-8.
        with open(self.index_file, "r", encoding="utf-8") as file:
            index_data = json.load(file)

        index_items = []
        for module_name, data in index_data.items():
            index_item = IndexItem(
                module_name=module_name,
                symbols=data["symbols"],
                last_modified=data["last_modified"],
                md5=data["md5"]
            )
            index_items.append(index_item)

        return index_items

    def _get_meta_str(self, includes: Optional[List[SymbolType]] = None):
        """Yield index metadata as prompt-sized text batches.

        Each yielded string holds up to ``args.filter_batch_size`` file entries
        of the form ``##<module>\\n<symbols>``. When *includes* is given, only
        those symbol categories are rendered.
        """
        index_items = self.read_index()
        current_chunk = []
        for item in index_items:
            symbols_str = item.symbols
            if includes:
                symbol_info = extract_symbols(symbols_str)
                symbols_str = symbols_info_to_str(symbol_info, includes)

            item_str = f"##{item.module_name}\n{symbols_str}\n\n"
            if len(current_chunk) > self.args.filter_batch_size:
                yield "".join(current_chunk)
                current_chunk = [item_str]
            else:
                current_chunk.append(item_str)
        if current_chunk:
            yield "".join(current_chunk)

    def get_target_files_by_query(self, query: str):
        """Find files relevant to *query*, honouring the configured filter level.

        Returns a ``FileList`` deduplicated by path and truncated to
        ``args.index_filter_file_num`` when that limit is positive.
        """
        all_results = []
        completed = 0
        total = 0

        includes = None
        if self.args.index_filter_level == 0:
            # Level 0: send only each file's "usage" summary to the LLM.
            includes = [SymbolType.USAGE]
        if self.args.index_filter_level >= 1:
            includes = None

        for chunk in self._get_meta_str(includes=includes):
            result = self._get_target_files_by_query.with_llm(self.llm).with_return_type(FileList).run(chunk, query)
            if result is not None:
                all_results.extend(result.file_list)
                completed += 1
            else:
                logger.warning("无法找到分块的目标文件。原因可能是模型响应未返回 JSON 格式数据,或返回的 JSON 为空。")
            total += 1
            time.sleep(self.anti_quota_limit)

        logger.info(f"已完成 {completed}/{total} 个分块(基于查询条件)")
        # Deduplicate by file_path (later entries win).
        all_results = list({file.file_path: file for file in all_results}.values())
        if self.args.index_filter_file_num > 0:
            limited_results = all_results[: self.args.index_filter_file_num]
            return FileList(file_list=limited_results)
        return FileList(file_list=all_results)

    @prompt()
    def _get_related_files(self, indices: str, file_paths: str) -> str:
        """
        下面是所有文件以及对应的符号信息:

        {{ indices }}

        请参考上面的信息,找到被下列文件使用或者引用到的文件列表:

        {{ file_paths }}

        请按如下格式进行输出:

        ```json
        {
            "file_list": [
                {
                    "file_path": "path/to/file.py",
                    "reason": "The reason why the file is the target file"
                },
                {
                    "file_path": "path/to/file.py",
                    "reason": "The reason why the file is the target file"
                }
            ]
        }
        ```

        如果没有相关的文件,输出如下 json 即可:

        ```json
        {"file_list": []}
        ```

        注意,
        1. 找到的文件名必须出现在上面的文件列表中
        2. 原因控制在20字以内, 且使用中文
        3. 请严格按格式要求返回结果,无需额外的说明
        """

    def get_related_files(self, file_paths: List[str]):
        """Find files used or referenced by *file_paths* via the LLM.

        Returns a ``FileList`` deduplicated by path.
        """
        all_results = []

        completed = 0
        total = 0

        for chunk in self._get_meta_str():
            result = self._get_related_files.with_llm(self.llm).with_return_type(
                FileList).run(chunk, "\n".join(file_paths))
            if result is not None:
                all_results.extend(result.file_list)
                completed += 1
            else:
                logger.warning("无法找到与分块相关的文件。原因可能是模型限制或查询条件与文件不匹配。")
            total += 1
            time.sleep(self.anti_quota_limit)
        logger.info(f"已完成 {completed}/{total} 个分块(基于相关文件)")
        # Deduplicate by file_path (later entries win).
        all_results = list({file.file_path: file for file in all_results}.values())
        return FileList(file_list=all_results)

    @prompt()
    def verify_file_relevance(self, file_content: str, query: str) -> str:
        """
        请验证下面的文件内容是否与用户问题相关:

        文件内容:
        {{ file_content }}

        用户问题:
        {{ query }}

        相关是指,需要依赖这个文件提供上下文,或者需要修改这个文件才能解决用户的问题。
        请给出相应的可能性分数:0-10,并结合用户问题,理由控制在50字以内,并且使用中文。
        请严格按格式要求返回结果。
        格式如下:

        ```json
        {
            "relevant_score": 0-10,
            "reason": "这是相关的原因..."
        }
        ```
        """
@@ -0,0 +1,43 @@
|
|
1
|
+
import re
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
from autocoder_nano.llm_types import SymbolsInfo, SymbolType
|
5
|
+
|
6
|
+
|
7
|
+
def extract_symbols(text: str) -> SymbolsInfo:
    """Parse an LLM symbol-summary text into a ``SymbolsInfo`` object.

    Each known category is located by its Chinese label; list-valued
    categories are split on their separator ("^^" for imports, "," for
    functions/variables/classes). Missing categories are left untouched.
    """
    field_patterns = (
        ("usage", r"用途:(.+)"),
        ("functions", r"函数:(.+)"),
        ("variables", r"变量:(.+)"),
        ("classes", r"类:(.+)"),
        ("import_statements", r"导入语句:(.+)"),
    )

    info = SymbolsInfo()
    for field_name, pattern in field_patterns:
        hit = re.search(pattern, text)
        if not hit:
            continue
        raw = hit.group(1).strip()
        if field_name == "import_statements":
            parsed = [part.strip() for part in raw.split("^^")]
        elif field_name in ("functions", "variables", "classes"):
            parsed = [part.strip() for part in raw.split(",")]
        else:
            parsed = raw
        setattr(info, field_name, parsed)

    return info
|
28
|
+
|
29
|
+
|
30
|
+
def symbols_info_to_str(info: SymbolsInfo, symbol_types: List[SymbolType]) -> str:
    """Render the selected categories of *info* back into summary text.

    The inverse of the parse step: imports are joined with "^^",
    functions/variables/classes with ",", and scalar values are emitted
    as-is. Empty categories are omitted. One ``label:value`` line per
    category, newline-separated.
    """
    lines = []
    for stype in symbol_types:
        value = getattr(info, stype.value)
        if not value:
            continue
        if stype == SymbolType.IMPORT_STATEMENTS:
            rendered = "^^".join(value)
        elif stype in (SymbolType.FUNCTIONS, SymbolType.VARIABLES, SymbolType.CLASSES):
            rendered = ",".join(value)
        else:
            rendered = value
        lines.append(f"{stype.value}:{rendered}")

    return "\n".join(lines)
|
autocoder_nano/llm_types.py
CHANGED
@@ -30,9 +30,9 @@ class AutoCoderArgs(BaseModel):
|
|
30
30
|
context: Optional[str] = None #
|
31
31
|
human_as_model: Optional[bool] = False #
|
32
32
|
human_model_num: Optional[int] = 1 #
|
33
|
-
include_project_structure: Optional[bool] = False #
|
33
|
+
include_project_structure: Optional[bool] = False # 在生成代码的提示中是否包含项目目录结构
|
34
34
|
urls: Optional[Union[str, List[str]]] = "" # 一些文档的URL/路径,可以帮助模型了解你当前的工作
|
35
|
-
model: Optional[str] = "" # 您要驱动运行的模型
|
35
|
+
# model: Optional[str] = "" # 您要驱动运行的模型
|
36
36
|
model_max_input_length: Optional[int] = 6000 # 模型最大输入长度
|
37
37
|
skip_confirm: Optional[bool] = False
|
38
38
|
silence: Optional[bool] = False
|
@@ -66,12 +66,16 @@ class AutoCoderArgs(BaseModel):
|
|
66
66
|
# 模型相关参数
|
67
67
|
current_chat_model: Optional[str] = ""
|
68
68
|
current_code_model: Optional[str] = ""
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
69
|
+
model: Optional[str] = "" # 默认模型
|
70
|
+
chat_model: Optional[str] = "" # AI Chat交互模型
|
71
|
+
index_model: Optional[str] = "" # 代码索引生成模型
|
72
|
+
code_model: Optional[str] = "" # 编码模型
|
73
|
+
commit_model: Optional[str] = "" # Git Commit 模型
|
74
|
+
emb_model: Optional[str] = "" # RAG Emb 模型
|
75
|
+
recall_model: Optional[str] = "" # RAG 召回阶段模型
|
76
|
+
chunk_model: Optional[str] = "" # 段落重排序模型
|
77
|
+
qa_model: Optional[str] = "" # RAG 提问模型
|
78
|
+
vl_model: Optional[str] = "" # 多模态模型
|
75
79
|
|
76
80
|
class Config:
|
77
81
|
protected_namespaces = ()
|
autocoder_nano/version.py
CHANGED