auto-coder 0.1.263-py3-none-any.whl → 0.1.264-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of auto-coder might be problematic.
- {auto_coder-0.1.263.dist-info → auto_coder-0.1.264.dist-info}/METADATA +1 -1
- {auto_coder-0.1.263.dist-info → auto_coder-0.1.264.dist-info}/RECORD +24 -23
- autocoder/chat_auto_coder.py +53 -49
- autocoder/common/__init__.py +6 -0
- autocoder/common/auto_coder_lang.py +6 -2
- autocoder/common/code_auto_generate_diff.py +9 -9
- autocoder/common/code_auto_merge.py +23 -3
- autocoder/common/code_auto_merge_diff.py +28 -3
- autocoder/common/code_auto_merge_editblock.py +24 -4
- autocoder/common/code_auto_merge_strict_diff.py +23 -3
- autocoder/common/code_modification_ranker.py +65 -3
- autocoder/common/conf_validator.py +6 -0
- autocoder/common/context_pruner.py +305 -0
- autocoder/index/entry.py +8 -2
- autocoder/index/filter/normal_filter.py +13 -2
- autocoder/index/filter/quick_filter.py +127 -13
- autocoder/index/index.py +3 -2
- autocoder/utils/project_structure.py +258 -3
- autocoder/utils/thread_utils.py +6 -1
- autocoder/version.py +1 -1
- {auto_coder-0.1.263.dist-info → auto_coder-0.1.264.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.263.dist-info → auto_coder-0.1.264.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.263.dist-info → auto_coder-0.1.264.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.263.dist-info → auto_coder-0.1.264.dist-info}/top_level.txt +0 -0
autocoder/common/code_auto_merge_editblock.py CHANGED

@@ -164,15 +164,35 @@ class CodeAutoMergeEditBlock:
     def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
         if len(generate_result.contents) == 1:
             return generate_result
-
+
+        merge_results = []
+        for content,conversations in zip(generate_result.contents,generate_result.conversations):
+            merge_result = self._merge_code_without_effect(content)
+            merge_results.append(merge_result)
+
+        # If all merge results failed, return the first one
+        if all(len(result.failed_blocks) != 0 for result in merge_results):
+            self.printer.print_in_terminal("all_merge_results_failed")
+            return CodeGenerateResult(contents=[generate_result.contents[0]], conversations=[generate_result.conversations[0]])
+
+        # If exactly one merge result succeeded, return that one
+        not_none_indices = [i for i, result in enumerate(merge_results) if len(result.failed_blocks) == 0]
+        if len(not_none_indices) == 1:
+            idx = not_none_indices[0]
+            self.printer.print_in_terminal("only_one_merge_result_success")
+            return CodeGenerateResult(contents=[generate_result.contents[idx]], conversations=[generate_result.conversations[idx]])
+
+        # Finally, if several candidates remain, rank them by quality before returning
         ranker = CodeModificationRanker(self.llm, self.args)
-        ranked_result = ranker.rank_modifications(generate_result)
-
+        ranked_result = ranker.rank_modifications(generate_result,merge_results)
+
+        ## Merge the ranked results once more and return the first one that passes; merging again here is somewhat redundant and inefficient, to be improved later.
         for content,conversations in zip(ranked_result.contents,ranked_result.conversations):
             merge_result = self._merge_code_without_effect(content)
             if not merge_result.failed_blocks:
                 return CodeGenerateResult(contents=[content], conversations=[conversations])
-
+
+        # Final fallback; in practice this should not be reached
         return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])

     @byzerllm.prompt()
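The new selection flow in choose_best_choice is easiest to see in isolation: dry-run merge every candidate, short-circuit when all fail or exactly one succeeds, and rank only when several survive. A minimal sketch of that decision, using a hypothetical FakeMergeResult stand-in (only failed_blocks matters) and stubbing out the ranking step:

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeMergeResult:
    # Hypothetical stand-in for the real merge result; only failed_blocks is used here.
    failed_blocks: List[str] = field(default_factory=list)

def pick_best(candidates: List[str], merge_results: List[FakeMergeResult]) -> str:
    # Sketch of the 0.1.264 selection logic with the ranker stubbed out.
    if len(candidates) == 1:
        return candidates[0]
    if all(r.failed_blocks for r in merge_results):
        return candidates[0]  # every dry-run merge failed: fall back to the first
    ok = [i for i, r in enumerate(merge_results) if not r.failed_blocks]
    if len(ok) == 1:
        return candidates[ok[0]]  # exactly one survived: no ranking needed
    return candidates[ok[0]]  # several survived: the real code ranks them first

print(pick_best(["candidate A", "candidate B"],
                [FakeMergeResult(failed_blocks=["f1"]), FakeMergeResult()]))  # -> candidate B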
autocoder/common/code_auto_merge_strict_diff.py CHANGED

@@ -138,15 +138,35 @@ class CodeAutoMergeStrictDiff:
     def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
         if len(generate_result.contents) == 1:
             return generate_result
+
+        merge_results = []
+        for content,conversations in zip(generate_result.contents,generate_result.conversations):
+            merge_result = self._merge_code_without_effect(content)
+            merge_results.append(merge_result)
+
+        # If all merge results failed, return the first one
+        if all(len(result.failed_blocks) != 0 for result in merge_results):
+            self.printer.print_in_terminal("all_merge_results_failed")
+            return CodeGenerateResult(contents=[generate_result.contents[0]], conversations=[generate_result.conversations[0]])
+
+        # If exactly one merge result succeeded, return that one
+        not_none_indices = [i for i, result in enumerate(merge_results) if len(result.failed_blocks) == 0]
+        if len(not_none_indices) == 1:
+            idx = not_none_indices[0]
+            self.printer.print_in_terminal("only_one_merge_result_success")
+            return CodeGenerateResult(contents=[generate_result.contents[idx]], conversations=[generate_result.conversations[idx]])

+        # Finally, if several candidates remain, rank them by quality before returning
         ranker = CodeModificationRanker(self.llm, self.args)
-        ranked_result = ranker.rank_modifications(generate_result)
-
+        ranked_result = ranker.rank_modifications(generate_result,merge_results)
+
+        ## Merge the ranked results once more and return the first one that passes; merging again here is somewhat redundant and inefficient, to be improved later.
         for content,conversations in zip(ranked_result.contents,ranked_result.conversations):
             merge_result = self._merge_code_without_effect(content)
             if not merge_result.failed_blocks:
                 return CodeGenerateResult(contents=[content], conversations=[conversations])
-
+
+        # Final fallback; in practice this should not be reached
         return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])

autocoder/common/code_modification_ranker.py CHANGED

@@ -9,6 +9,8 @@ import traceback
 from autocoder.common.utils_code_auto_generate import chat_with_continue
 from byzerllm.utils.str2model import to_model
 from autocoder.utils.llms import get_llm_names, get_model_info
+from autocoder.common.types import CodeGenerateResult, MergeCodeWithoutEffect
+import os

 class RankResult(BaseModel):
     rank_result: List[int]
@@ -51,12 +53,67 @@ class CodeModificationRanker:
         }
         ```

-        Note:
+        Note:
         1. id is the id of the edit_block, ranked by quality from high to low, and the id must be a number
         2. Output only the JSON format required above and nothing else; the JSON must be wrapped in ```json ```
         '''

-
+    @byzerllm.prompt()
+    def _rank_modifications_with_merge_result(self, s: CodeGenerateResult,merge_results: List[MergeCodeWithoutEffect]) -> str:
+        '''
+        Evaluate and rank a set of code modifications by quality.
+
+        Here is the modification requirement:
+
+        <edit_requirement>
+        {{ s.conversations[0][-2]["content"] }}
+        </edit_requirement>
+
+        Below are the corresponding code modifications. If Before is empty, the file is newly added; if After is empty, the file is deleted; if both Before and After are non-empty, the file is modified:
+        {% for change in changes %}
+        <edit_file id="{{ loop.index0 }}">
+        {{change}}
+        </edit_file>
+        {% endfor %}
+
+        Please output the evaluation result in the following format, containing only the JSON data:
+
+        ```json
+        {
+            "rank_result": [id1, id2, id3]
+        }
+        ```
+
+        Note:
+        1. Details such as Python indentation and tag matching in front-end frameworks like React and Vue matter a lot and must be weighted heavily in the ranking.
+        2. id is the id of the edit_file, ranked by quality from high to low, and the id must be a number
+        3. Output only the JSON format required above and nothing else; the JSON must be wrapped in ```json ```
+        '''
+        changes = []
+        for merge_result in merge_results:
+            s = ""
+            for block in merge_result.success_blocks:
+                file_path,content = block
+                s += f"##File: {file_path}\n\n"
+                if not os.path.exists(file_path):
+                    s += f"##Before: \n\n"
+                    s += f"##File: {file_path}\n\n"
+                    s += f"##After: \n\n"
+                    s += content
+                else:
+                    with open(file_path, "r",encoding="utf-8") as f:
+                        original_content = f.read()
+                    s += f"##Before: \n\n"
+                    s += original_content
+                    s += f"##File: {file_path}\n\n"
+                    s += f"##After: \n\n"
+                    s += content
+            changes.append(s)
+        return {
+            "changes": changes
+        }
+
+    def rank_modifications(self, generate_result: CodeGenerateResult, merge_result: List[MergeCodeWithoutEffect]) -> CodeGenerateResult:
         import time
         from collections import defaultdict

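For the file-level ranking prompt above, every candidate is summarized as ##File / ##Before / ##After blocks built from the dry-run merge's success_blocks. A rough in-memory sketch of that formatting (format_change is a hypothetical helper; the real method reads Before from disk and detects new files via os.path.exists):

def format_change(file_path: str, before: str, after: str) -> str:
    # Empty before means a new file; empty after would mean a deletion.
    s = f"##File: {file_path}\n\n"
    s += "##Before: \n\n"
    s += before
    s += f"\n##File: {file_path}\n\n"
    s += "##After: \n\n"
    s += after
    return s

print(format_change("src/math.py", "", "def add(a, b):\n    return a + b\n"))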
@@ -69,8 +126,13 @@ class CodeModificationRanker:

         rank_times = self.args.rank_times_same_model
         total_tasks = len(self.llms) * rank_times
+        if self.args.rank_strategy == "block":
+            query = self._rank_modifications.prompt(generate_result)
+        elif self.args.rank_strategy == "file":
+            query = self._rank_modifications_with_merge_result.prompt(generate_result, merge_result)
+        else:
+            raise Exception(f"Invalid rank strategy: {self.args.rank_strategy}")

-        query = self._rank_modifications.prompt(generate_result)
         input_tokens_count = 0
         generated_tokens_count = 0
         try:
autocoder/common/conf_validator.py CHANGED

@@ -132,6 +132,12 @@ class ConfigValidator:
             "type": str,
             "default": "v3_chat",
             "description": "Name of the model used to generate commit messages"
+        },
+        "rank_strategy": {
+            "type": str,
+            "allowed": ["block", "file"],
+            "default": "block",
+            "description": "Ranking strategy (block/file)"
         }
     }

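A hedged sketch of how a schema entry like the new rank_strategy one is typically enforced: check the type, check membership in allowed, fall back to default. The validate helper below is illustrative only and is not the package's actual ConfigValidator API:

SCHEMA = {
    "rank_strategy": {
        "type": str,
        "allowed": ["block", "file"],
        "default": "block",
        "description": "Ranking strategy (block/file)",
    }
}

def validate(key: str, value):
    # Illustrative validation against the schema shape shown above.
    spec = SCHEMA[key]
    if value is None:
        return spec["default"]
    if not isinstance(value, spec["type"]):
        raise TypeError(f"{key} must be {spec['type'].__name__}")
    if "allowed" in spec and value not in spec["allowed"]:
        raise ValueError(f"{key} must be one of {spec['allowed']}")
    return value

print(validate("rank_strategy", "file"))  # ok -> "file"
print(validate("rank_strategy", None))    # falls back -> "block"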
autocoder/common/context_pruner.py ADDED

@@ -0,0 +1,305 @@
+from typing import List, Dict, Any, Union
+from pathlib import Path
+import json
+from loguru import logger
+from autocoder.rag.token_counter import count_tokens
+from autocoder.common import AutoCoderArgs,SourceCode
+from byzerllm.utils.client.code_utils import extract_code
+from autocoder.index.types import VerifyFileRelevance
+import byzerllm
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+class PruneContext:
+    def __init__(self, max_tokens: int, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
+        self.max_tokens = max_tokens
+        self.args = args
+        self.llm = llm
+
+    def _delete_overflow_files(self, file_paths: List[str]) -> List[SourceCode]:
+        """Simply drop the files that overflow the token limit"""
+        total_tokens = 0
+        selected_files = []
+        token_count = 0
+        for file_path in file_paths:
+            try:
+                with open(file_path, "r", encoding="utf-8") as f:
+                    content = f.read()
+                    token_count = count_tokens(content)
+                    if total_tokens + token_count <= self.max_tokens:
+                        total_tokens += token_count
+                        print(f"{file_path} {token_count} {content}")
+                        selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=token_count))
+                    else:
+                        break
+            except Exception as e:
+                logger.error(f"Failed to read file {file_path}: {e}")
+                selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=token_count))
+
+        return selected_files
+
+    def _extract_code_snippets(self, file_paths: List[str], conversations: List[Dict[str, str]]) -> List[SourceCode]:
+        """Strategy: extract only the key code snippets"""
+        token_count = 0
+        selected_files = []
+        full_file_tokens = int(self.max_tokens * 0.8)
+
+        @byzerllm.prompt()
+        def extract_code_snippets(conversations: List[Dict[str, str]], content: str) -> str:
+            """
+            Extract relevant code snippets from the provided code file based on the conversation history.
+
+            Worked examples:
+            <examples>
+            1. Code file:
+            <code_file>
+            1 def add(a, b):
+            2     return a + b
+            3 def sub(a, b):
+            4     return a - b
+            </code_file>
+            <conversation_history>
+            <user>: How do I implement addition?
+            </conversation_history>
+
+            Output:
+            ```json
+            [
+                {"start_line": 1, "end_line": 2}
+            ]
+            ```
+
+            2. Code file:
+            1 class User:
+            2     def __init__(self, name):
+            3         self.name = name
+            4     def greet(self):
+            5         return f"Hello, {self.name}"
+            </code_file>
+            <conversation_history>
+            <user>: How do I create a User object?
+            </conversation_history>
+
+            Output:
+            ```json
+            [
+                {"start_line": 1, "end_line": 3}
+            ]
+            ```
+
+            3. Code file:
+            <code_file>
+            1 def foo():
+            2     pass
+            </code_file>
+            <conversation_history>
+            <user>: How do I implement subtraction?
+            </conversation_history>
+
+            Output:
+            ```json
+            []
+            ```
+            </examples>
+
+            Input:
+            1. Code file content:
+            <code_file>
+            {{ content }}
+            </code_file>
+
+            2. Conversation history:
+            <conversation_history>
+            {% for msg in conversations %}
+            <{{ msg.role }}>: {{ msg.content }}
+            {% endfor %}
+            </conversation_history>
+
+            Task:
+            1. Analyze the last user question and its context.
+            2. Find one or more important code sections in the code file that are relevant to the question.
+            3. For each relevant section, determine its start line number (start_line) and end line number (end_line).
+            4. Return no more than 4 code sections.
+
+            Output requirements:
+            1. Return a JSON array in which each element contains "start_line" and "end_line".
+            2. start_line and end_line must be integers denoting line numbers in the code file.
+            3. Line numbers start from 1.
+            4. If there are no relevant code sections, return an empty array [].
+
+            Output format:
+            A strict JSON array, with no other text or explanation.
+
+            ```json
+            [
+                {"start_line": start line of the first section, "end_line": end line of the first section},
+                {"start_line": start line of the second section, "end_line": end line of the second section}
+            ]
+            ```
+            """
+
+        for file_path in file_paths:
+            try:
+                with open(file_path, "r", encoding="utf-8") as f:
+                    content = f.read()
+
+                # Prefer keeping the whole file
+                tokens = count_tokens(content)
+                if token_count + tokens <= full_file_tokens:
+                    selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=tokens))
+                    token_count += tokens
+                    continue
+
+                # Otherwise extract the key snippets
+                extracted = extract_code_snippets.with_llm(self.llm).run(
+                    conversations=conversations,
+                    content=content
+                )
+
+                if extracted:
+                    json_str = extract_code(extracted)[0][1]
+                    snippets = json.loads(json_str)
+                    new_content = self._build_snippet_content(file_path, content, snippets)
+
+                    snippet_tokens = count_tokens(new_content)
+                    if token_count + snippet_tokens <= self.max_tokens:
+                        selected_files.append(SourceCode(module_name=file_path,source_code=new_content,tokens=snippet_tokens))
+                        token_count += snippet_tokens
+                    else:
+                        break
+            except Exception as e:
+                logger.error(f"Failed to process {file_path}: {e}")
+                continue
+
+        return selected_files
+
+    def _build_snippet_content(self, file_path: str, full_content: str, snippets: List[dict]) -> str:
+        """Build the file content that contains only the selected snippets"""
+        lines = full_content.split("\n")
+        header = f"Snippets:\n"
+
+        content = []
+        for snippet in snippets:
+            start = max(0, snippet["start_line"] - 1)
+            end = min(len(lines), snippet["end_line"])
+            content.append(f"# Lines {start+1}-{end} ({snippet.get('reason','')})")
+            content.extend(lines[start:end])
+
+        return header + "\n".join(content)
+
+    def handle_overflow(
+        self,
+        file_paths: List[str],
+        conversations: List[Dict[str, str]],
+        strategy: str = "score"
+    ) -> List[SourceCode]:
+        """
+        Handle files that exceed the token limit
+        :param file_paths: list of file paths to process
+        :param conversations: conversation context (used by the extract strategy)
+        :param strategy: handling strategy (delete/extract/score)
+        """
+        total_tokens,sources = self._count_tokens(file_paths)
+        if total_tokens <= self.max_tokens:
+            return sources
+
+        if strategy == "score":
+            return self._score_and_filter_files(file_paths, conversations)
+        if strategy == "delete":
+            return self._delete_overflow_files(file_paths)
+        elif strategy == "extract":
+            return self._extract_code_snippets(file_paths, conversations)
+        else:
+            raise ValueError(f"Invalid strategy: {strategy}. Allowed values: delete/extract/score")
+
+    def _count_tokens(self, file_paths: List[str]) -> int:
+        """Count the total number of tokens across the files"""
+        total_tokens = 0
+        sources = []
+        for file_path in file_paths:
+            try:
+                with open(file_path, "r", encoding="utf-8") as f:
+                    content = f.read()
+                    sources.append(SourceCode(module_name=file_path,source_code=content,tokens=count_tokens(content)))
+                    total_tokens += count_tokens(content)
+            except Exception as e:
+                logger.error(f"Failed to read file {file_path}: {e}")
+                total_tokens += 0
+        return total_tokens,sources
+
+    def _score_and_filter_files(self, file_paths: List[str], conversations: List[Dict[str, str]]) -> List[SourceCode]:
+        """Filter files by relevance score, stopping once the token count would exceed max_tokens"""
+        selected_files = []
+        total_tokens = 0
+        scored_files = []
+
+        @byzerllm.prompt()
+        def verify_file_relevance(file_content: str, conversations: List[Dict[str, str]]) -> str:
+            """
+            Please verify whether the file content below is relevant to the user conversation:
+
+            File content:
+            {{ file_content }}
+
+            Conversation history:
+            <conversation_history>
+            {% for msg in conversations %}
+            <{{ msg.role }}>: {{ msg.content }}
+            {% endfor %}
+            </conversation_history>
+
+            Relevant means this file is needed as context, or must be modified, in order to solve the user's problem.
+            Give a likelihood score from 0 to 10 and, with the user's question in mind, keep the reason within 50 characters. Format:
+
+            ```json
+            {
+                "relevant_score": 0-10,
+                "reason": "why this is relevant (within 10 Chinese characters)..."
+            }
+            ```
+            """
+
+        def _score_file(file_path: str) -> dict:
+            try:
+                with open(file_path, "r", encoding="utf-8") as f:
+                    content = f.read()
+                    tokens = count_tokens(content)
+                    result = verify_file_relevance.with_llm(self.llm).with_return_type(VerifyFileRelevance).run(
+                        file_content=content,
+                        conversations=conversations
+                    )
+                    return {
+                        "file_path": file_path,
+                        "score": result.relevant_score,
+                        "tokens": tokens,
+                        "content": content
+                    }
+            except Exception as e:
+                logger.error(f"Failed to score file {file_path}: {e}")
+                return None
+
+        # Score the files in parallel with a thread pool
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(_score_file, file_path) for file_path in file_paths]
+            for future in as_completed(futures):
+                result = future.result()
+                print(f"score file {result['file_path']} {result['score']}")
+                if result:
+                    scored_files.append(result)
+
+        # Step 2: sort by score from high to low
+        scored_files.sort(key=lambda x: x["score"], reverse=True)
+
+        # Step 3: take files from the highest score down, stopping once the token count would exceed max_tokens
+        for file_info in scored_files:
+            if total_tokens + file_info["tokens"] <= self.max_tokens:
+                selected_files.append(SourceCode(
+                    module_name=file_info["file_path"],
+                    source_code=file_info["content"],
+                    tokens=file_info["tokens"]
+                ))
+                total_tokens += file_info["tokens"]
+            else:
+                break
+
+        return selected_files
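handle_overflow is the entry point of the new pruner: it counts tokens first and applies one of the three strategies only when the total exceeds max_tokens. A usage sketch; the args and llm arguments are stubbed with None here purely for illustration, since real call sites pass an AutoCoderArgs instance and a byzerllm client:

from autocoder.common.context_pruner import PruneContext

# delete: keep whole files in order until the budget runs out
# extract: ask the LLM for relevant line ranges in oversized files
# score: rate each file's relevance in parallel, keep the best first
pruner = PruneContext(max_tokens=8000, args=None, llm=None)  # stubs, illustration only
sources = pruner.handle_overflow(
    ["src/a.py", "src/b.py"],
    [{"role": "user", "content": "How do I implement addition?"}],
    strategy="delete",
)
for src in sources:
    print(src.module_name, src.tokens)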
autocoder/index/entry.py CHANGED

@@ -24,6 +24,7 @@ from autocoder.index.filter.normal_filter import NormalFilter
 from autocoder.index.index import IndexManager
 from loguru import logger
 from autocoder.common import SourceCodeList
+from autocoder.common.context_pruner import PruneContext

 def build_index_and_filter_files(
     llm, args: AutoCoderArgs, sources: List[SourceCode]

@@ -113,8 +114,13 @@ def build_index_and_filter_files(
         raise KeyboardInterrupt(printer.get_message_from_key_with_format("quick_filter_failed",error=quick_filter_result.error_message))

     # Merge quick filter results into final_files
-
-
+    if args.context_prune:
+        context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
+        pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
+        for source_file in pruned_files:
+            final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
+    else:
+        final_files.update(quick_filter_result.files)

     if not args.skip_filter_index and not args.index_filter_model:
         model_name = getattr(index_manager.llm, 'default_model_name', None)
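The integration above is gated on new settings that presumably arrived with the autocoder/common/__init__.py change listed at the top: context_prune (on/off), context_prune_strategy (delete/extract/score), and the conversation_prune_safe_zone_tokens budget. A minimal sketch of the gate with a stand-in namespace (field names taken from the diff; the object itself is hypothetical):

from types import SimpleNamespace

args = SimpleNamespace(
    context_prune=True,
    context_prune_strategy="score",
    conversation_prune_safe_zone_tokens=8000,
    query="add a subtract function",
)

# Mirrors the branch in build_index_and_filter_files: prune only when enabled.
if args.context_prune:
    print(f"prune with {args.context_prune_strategy} to "
          f"{args.conversation_prune_safe_zone_tokens} tokens")
else:
    print("merge quick-filter results unchanged")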
autocoder/index/filter/normal_filter.py CHANGED

@@ -1,4 +1,6 @@
-from typing import List, Union,Dict,Any
+from typing import List, Union,Dict,Any,Optional
+
+from pydantic import BaseModel
 from autocoder.index.types import IndexItem
 from autocoder.common import SourceCode, AutoCoderArgs
 import byzerllm

@@ -25,6 +27,11 @@ def get_file_path(file_path):
         return file_path.strip()[2:]
     return file_path

+class NormalFilterResult(BaseModel):
+    files: Dict[str, TargetFile]
+    has_error: bool
+    error_message: Optional[str] = None
+    file_positions: Optional[Dict[str, int]]

 class NormalFilter():
     def __init__(self, index_manager: IndexManager,stats:Dict[str,Any],sources:List[SourceCode]):

@@ -167,4 +174,8 @@ class NormalFilter():
         # Keep all files, not just verified ones
         final_files = verified_files

-        return
+        return NormalFilterResult(
+            files=final_files,
+            has_error=False,
+            error_message=None
+        )