auto-coder 0.1.280__py3-none-any.whl → 0.1.282__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of auto-coder might be problematic.
- {auto_coder-0.1.280.dist-info → auto_coder-0.1.282.dist-info}/METADATA +2 -2
- {auto_coder-0.1.280.dist-info → auto_coder-0.1.282.dist-info}/RECORD +14 -13
- autocoder/auto_coder.py +2 -1
- autocoder/auto_coder_rag.py +93 -29
- autocoder/common/context_pruner.py +168 -206
- autocoder/index/entry.py +1 -1
- autocoder/rag/cache/local_byzer_storage_cache.py +457 -0
- autocoder/rag/document_retriever.py +22 -53
- autocoder/rag/long_context_rag.py +18 -1
- autocoder/version.py +1 -1
- {auto_coder-0.1.280.dist-info → auto_coder-0.1.282.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.280.dist-info → auto_coder-0.1.282.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.280.dist-info → auto_coder-0.1.282.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.280.dist-info → auto_coder-0.1.282.dist-info}/top_level.txt +0 -0
autocoder/common/context_pruner.py
CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 import json
 from loguru import logger
 from autocoder.rag.token_counter import count_tokens
-from autocoder.common import AutoCoderArgs,SourceCode
+from autocoder.common import AutoCoderArgs, SourceCode
 from byzerllm.utils.client.code_utils import extract_code
 from autocoder.index.types import VerifyFileRelevance
 import byzerllm
@@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from autocoder.common.printer import Printer
 from autocoder.common.auto_coder_lang import get_message_with_format
 
+
 class PruneContext:
     def __init__(self, max_tokens: int, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
         self.max_tokens = max_tokens
@@ -22,12 +23,12 @@ class PruneContext:
 
     def _split_content_with_sliding_window(self, content: str, window_size=100, overlap=20) -> List[Tuple[int, int, str]]:
         """Split large file content with a sliding window, returning chunks that carry line-number information.
 
         Args:
             content: the file content to split
             window_size: number of lines in each window
             overlap: number of overlapping lines between adjacent windows
 
         Returns:
             List[Tuple[int, int, str]]: a list of tuples, each containing:
             - the starting line number (1-based), i.e. the absolute line number in the original file
@@ -38,107 +39,58 @@ class PruneContext:
         lines = content.splitlines()
         chunks = []
         start = 0
 
         while start < len(lines):
             # Compute the end position of the current window
             end = min(start + window_size, len(lines))
 
             # Compute the actual start position (accounting for the overlap)
             actual_start = max(0, start - overlap)
 
             # Take the lines of the current window
             chunk_lines = lines[actual_start:end]
 
             # Prefix every line with its line number
             # Numbering starts at actual_start+1 to stay consistent with the absolute line numbers of the original file
             chunk_content = "\n".join([
                 f"{i+1} {line}" for i, line in enumerate(chunk_lines, start=actual_start)
             ])
 
             # Record the chunk as (start line, end line, numbered content)
             # Line numbers are 1-based
             chunks.append((actual_start + 1, end, chunk_content))
 
             # Move to the start of the next window
             # Subtracting overlap keeps consecutive windows overlapping
             start += (window_size - overlap)
-
-        return chunks
-
-    def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
-        """Merge overlapping or adjacent code snippets
-
-        Args:
-            snippets: a list of snippets, each a dict with start_line and end_line
-
-        Returns:
-            List[dict]: the merged snippet list
-
-        Example:
-        Input: [
-            {"start_line": 1, "end_line": 5},
-            {"start_line": 4, "end_line": 8},
-            {"start_line": 10, "end_line": 12}
-        ]
-        Output: [
-            {"start_line": 1, "end_line": 8},
-            {"start_line": 10, "end_line": 12}
-        ]
-        """
-        if not snippets:
-            return []
-
-        # Sort by starting line
-        sorted_snippets = sorted(snippets, key=lambda x: x["start_line"])
-
-        merged = [sorted_snippets[0]]
-
-        for current in sorted_snippets[1:]:
-            last = merged[-1]
-
-            # Decide whether to merge:
-            # 1. merge if the current snippet starts at or before the previous snippet's end line + 1
-            # 2. the +1 also merges adjacent snippets, e.g. 1-5 and 6-8 become 1-8
-            if current["start_line"] <= last["end_line"] + 1:
-                # Merge the intervals:
-                # - the start line is the minimum of the two
-                # - the end line is the maximum of the two
-                merged[-1] = {
-                    "start_line": min(last["start_line"], current["start_line"]),
-                    "end_line": max(last["end_line"], current["end_line"])
-                }
-            else:
-                # Not overlapping and not adjacent: keep it as a new snippet
-                merged.append(current)
 
-        return
+        return chunks
+
 
-    def _delete_overflow_files(self,
+    def _delete_overflow_files(self, file_sources: List[SourceCode]) -> List[SourceCode]:
         """Delete files that exceed the token limit outright"""
         total_tokens = 0
         selected_files = []
         token_count = 0
-        for
-            try:
-                token_count = count_tokens(
+        for file_source in file_sources:
+            try:
+                token_count = file_source.tokens
+                if token_count <= 0:
+                    token_count = count_tokens(file_source.source_code)
+
+                if total_tokens + token_count <= self.max_tokens:
+                    total_tokens += token_count
+                    print(f"{file_source.module_name} {token_count}")
+                    selected_files.append(file_source)
+                else:
+                    break
             except Exception as e:
-                logger.error(f"Failed to read file {
-                selected_files.append(
+                logger.error(f"Failed to read file {file_source.module_name}: {e}")
+                selected_files.append(file_source)
 
         return selected_files
 
-    def _extract_code_snippets(self,
+    def _extract_code_snippets(self, file_sources: List[SourceCode], conversations: List[Dict[str, str]]) -> List[SourceCode]:
         """Strategy: extract the key code snippets"""
         token_count = 0
         selected_files = []
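For intuition, here is a minimal self-contained sketch of the sliding-window chunking used above: each window covers `window_size` lines, is extended backwards by `overlap` lines, and every line is prefixed with its absolute 1-based line number. Function and variable names in the sketch are illustrative, not part of the package.

```python
from typing import List, Tuple


def split_with_sliding_window(content: str, window_size: int = 100, overlap: int = 20) -> List[Tuple[int, int, str]]:
    """Standalone sketch of the chunking above: returns (start_line, end_line, numbered_text) tuples."""
    lines = content.splitlines()
    chunks = []
    start = 0
    while start < len(lines):
        end = min(start + window_size, len(lines))
        actual_start = max(0, start - overlap)            # extend backwards to create the overlap
        numbered = "\n".join(
            f"{i + 1} {line}" for i, line in enumerate(lines[actual_start:end], start=actual_start)
        )
        chunks.append((actual_start + 1, end, numbered))  # absolute, 1-based line numbers
        start += window_size - overlap                    # advance so consecutive windows overlap
    return chunks


# A 250-line file with the default window of 100 lines and overlap of 20
# yields chunks covering lines 1-100, 61-180, 141-250 and 221-250.
demo = "\n".join(f"line {n}" for n in range(1, 251))
print([(s, e) for s, e, _ in split_with_sliding_window(demo)])
```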
@@ -236,108 +188,111 @@ class PruneContext:
 
         输出格式:
         严格的JSON数组,不包含其他文字或解释。
 
         ```json
         [
             {"start_line": 第一个代码段的起始行号, "end_line": 第一个代码段的结束行号},
             {"start_line": 第二个代码段的起始行号, "end_line": 第二个代码段的结束行号}
         ]
         ```
 
         """
 
-        for
-            try:
+        for file_source in file_sources:
+            try:
+                # Whole files take priority
+                tokens = file_source.tokens
+                if token_count + tokens <= full_file_tokens:
+                    selected_files.append(SourceCode(
+                        module_name=file_source.module_name, source_code=file_source.source_code, tokens=tokens))
+                    token_count += tokens
+                    continue
+
+                # If a single file is too large, split it with the sliding window first,
+                # then extract code snippets from each window
+                if tokens > self.max_tokens:
+                    self.printer.print_in_terminal(
+                        "file_sliding_window_processing", file_path=file_source.module_name, tokens=tokens)
+
+                    chunks = self._split_content_with_sliding_window(file_source.source_code,
+                                                                     self.args.context_prune_sliding_window_size,
+                                                                     self.args.context_prune_sliding_window_overlap)
+                    all_snippets = []
+                    for chunk_start, chunk_end, chunk_content in chunks:
+                        extracted = extract_code_snippets.with_llm(self.llm).run(
+                            conversations=conversations,
+                            content=chunk_content,
+                            is_partial_content=True
+                        )
+                        if extracted:
+                            json_str = extract_code(extracted)[0][1]
+                            snippets = json.loads(json_str)
+
+                            # The returned line numbers are already absolute line numbers in the original file
+                            # They are only adjusted later, when the snippet content is built, so nothing to do here
+                            adjusted_snippets = [{
+                                "start_line": snippet["start_line"],
+                                "end_line": snippet["end_line"]
+                            } for snippet in snippets]
+                            all_snippets.extend(adjusted_snippets)
+                    merged_snippets = self._merge_overlapping_snippets(
+                        all_snippets)
+                    content_snippets = self._build_snippet_content(
+                        file_source.module_name, file_source.source_code, merged_snippets)
+                    snippet_tokens = count_tokens(content_snippets)
+                    if token_count + snippet_tokens <= self.max_tokens:
+                        selected_files.append(SourceCode(
+                            module_name=file_source.module_name, source_code=content_snippets, tokens=snippet_tokens))
+                        token_count += snippet_tokens
+                        self.printer.print_in_terminal("file_snippet_procesed", file_path=file_source.module_name,
+                                                       total_tokens=token_count,
+                                                       tokens=tokens,
+                                                       snippet_tokens=snippet_tokens)
                     continue
-                # Extract the key snippets
-                lines = content.splitlines()
-                new_content = ""
-
-                ## Number the file content line by line
-                for index,line in enumerate(lines):
-                    new_content += f"{index+1} {line}\n"
-
-                ## Extract code snippets
-                self.printer.print_in_terminal("file_snippet_processing", file_path=file_path)
-                extracted = extract_code_snippets.with_llm(self.llm).run(
-                    conversations=conversations,
-                    content=new_content
-                )
-
-                ## Build the snippet content
-                if extracted:
-                    json_str = extract_code(extracted)[0][1]
-                    snippets = json.loads(json_str)
-                    content_snippets = self._build_snippet_content(file_path, content, snippets)
-
-                    snippet_tokens = count_tokens(content_snippets)
-                    if token_count + snippet_tokens <= self.max_tokens:
-                        selected_files.append(SourceCode(module_name=file_path,
-                                                         source_code=content_snippets,
-                                                         tokens=snippet_tokens))
-                        token_count += snippet_tokens
-                        self.printer.print_in_terminal("file_snippet_procesed", file_path=file_path,
-                                                       total_tokens = token_count,
-                                                       tokens=tokens,
-                                                       snippet_tokens=snippet_tokens)
-                    else:
-                        break
+                else:
+                    break
+
+                # Extract the key snippets
+                lines = file_source.source_code.splitlines()
+                new_content = ""
+
+                # Number the file content line by line
+                for index, line in enumerate(lines):
+                    new_content += f"{index+1} {line}\n"
+
+                # Extract code snippets
+                self.printer.print_in_terminal(
+                    "file_snippet_processing", file_path=file_source.module_name)
+                extracted = extract_code_snippets.with_llm(self.llm).run(
+                    conversations=conversations,
+                    content=new_content
+                )
+
+                # Build the snippet content
+                if extracted:
+                    json_str = extract_code(extracted)[0][1]
+                    snippets = json.loads(json_str)
+                    content_snippets = self._build_snippet_content(
+                        file_source.module_name, file_source.source_code, snippets)
+
+                    snippet_tokens = count_tokens(content_snippets)
+                    if token_count + snippet_tokens <= self.max_tokens:
+                        selected_files.append(SourceCode(module_name=file_source.module_name,
+                                                         source_code=content_snippets,
+                                                         tokens=snippet_tokens))
+                        token_count += snippet_tokens
+                        self.printer.print_in_terminal("file_snippet_procesed", file_path=file_source.module_name,
+                                                       total_tokens=token_count,
+                                                       tokens=tokens,
+                                                       snippet_tokens=snippet_tokens)
+                    else:
+                        break
             except Exception as e:
-                logger.error(f"Failed to process {
+                logger.error(f"Failed to process {file_source.module_name}: {e}")
                 continue
 
         return selected_files
 
-    def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
+    def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
         if not snippets:
             return []
 
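`_merge_overlapping_snippets` (moved within the file by this change rather than rewritten) is ordinary interval merging over `start_line`/`end_line` pairs. A standalone sketch, using the example from the old docstring:

```python
from typing import Dict, List


def merge_overlapping_snippets(snippets: List[Dict[str, int]]) -> List[Dict[str, int]]:
    """Merge snippets whose line ranges overlap or touch (e.g. 1-5 and 6-8 become 1-8)."""
    if not snippets:
        return []
    ordered = sorted(snippets, key=lambda s: s["start_line"])
    merged = [dict(ordered[0])]
    for current in ordered[1:]:
        last = merged[-1]
        if current["start_line"] <= last["end_line"] + 1:   # overlapping or adjacent
            last["start_line"] = min(last["start_line"], current["start_line"])
            last["end_line"] = max(last["end_line"], current["end_line"])
        else:
            merged.append(dict(current))
    return merged


print(merge_overlapping_snippets([
    {"start_line": 1, "end_line": 5},
    {"start_line": 4, "end_line": 8},
    {"start_line": 10, "end_line": 12},
]))
# [{'start_line': 1, 'end_line': 8}, {'start_line': 10, 'end_line': 12}]
```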
@@ -367,27 +322,29 @@ class PruneContext:
         for snippet in snippets:
             start = max(0, snippet["start_line"] - 1)
             end = min(len(lines), snippet["end_line"])
-            content.append(
+            content.append(
+                f"# Lines {start+1}-{end} ({snippet.get('reason','')})")
             content.extend(lines[start:end])
 
         return header + "\n".join(content)
 
     def handle_overflow(
         self,
+        file_sources: List[SourceCode],
         conversations: List[Dict[str, str]],
-        strategy: str = "score"
+        strategy: str = "score"
     ) -> List[SourceCode]:
         """
         Handle files that exceed the token limit
-        :param
+        :param file_sources: the files to process
         :param conversations: the conversation context (used by the extract strategy)
         :param strategy: the pruning strategy (delete/extract/score)
         """
+        file_paths = [file_source.module_name for file_source in file_sources]
+        total_tokens, sources = self._count_tokens(file_sources=file_sources)
         if total_tokens <= self.max_tokens:
             return sources
 
         self.printer.print_in_terminal(
             "context_pruning_reason",
             total_tokens=total_tokens,
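Only a fragment of `_build_snippet_content` is visible in this hunk. The sketch below reproduces just that visible step, rendering each snippet as a `# Lines a-b (reason)` header followed by the matching source lines; the `header` prefix used by the real method is not shown in the diff, so a placeholder value is assumed here.

```python
from typing import Dict, List


def render_snippets(source: str, snippets: List[Dict], header: str = "# selected snippets\n") -> str:
    """Illustrative rendering of snippet dicts, mirroring the loop shown in the hunk above."""
    lines = source.splitlines()
    content: List[str] = []
    for snippet in snippets:
        start = max(0, snippet["start_line"] - 1)   # 1-based start -> 0-based index
        end = min(len(lines), snippet["end_line"])
        content.append(f"# Lines {start + 1}-{end} ({snippet.get('reason', '')})")
        content.extend(lines[start:end])
    return header + "\n".join(content)


demo = "\n".join(f"line {n}" for n in range(1, 13))
print(render_snippets(demo, [{"start_line": 2, "end_line": 4, "reason": "referenced by the query"}]))
```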
@@ -396,9 +353,9 @@ class PruneContext:
         )
 
         self.printer.print_in_terminal(
+            "sorted_files_message",
+            files=file_paths
+        )
 
         self.printer.print_in_terminal(
             "context_pruning_start",
@@ -407,35 +364,42 @@ class PruneContext:
             strategy=strategy
         )
 
-        if strategy == "score":
-            return self._score_and_filter_files(
+        if strategy == "score":
+            return self._score_and_filter_files(sources, conversations)
         if strategy == "delete":
-            return self._delete_overflow_files(
+            return self._delete_overflow_files(sources)
         elif strategy == "extract":
-            return self._extract_code_snippets(
+            return self._extract_code_snippets(sources, conversations)
         else:
             raise ValueError(f"无效策略: {strategy}. 可选值: delete/extract/score")
 
-    def _count_tokens(self,
+    def _count_tokens(self, file_sources: List[SourceCode]) -> int:
         """Count the total number of tokens across the files"""
         total_tokens = 0
         sources = []
-        for
+        for file_source in file_sources:
             try:
+                if file_source.tokens > 0:
+                    tokens = file_source.tokens
+                    total_tokens += file_source.tokens
+                else:
+                    tokens = count_tokens(file_source.source_code)
+                    total_tokens += tokens
+
+                sources.append(SourceCode(module_name=file_source.module_name,
+                                          source_code=file_source.source_code, tokens=tokens))
+
             except Exception as e:
-                logger.error(f"Failed to
+                logger.error(f"Failed to count tokens for {file_source.module_name}: {e}")
+                sources.append(SourceCode(module_name=file_source.module_name,
+                                          source_code=file_source.source_code, tokens=0))
+        return total_tokens, sources
 
-    def _score_and_filter_files(self,
+    def _score_and_filter_files(self, file_sources: List[SourceCode], conversations: List[Dict[str, str]]) -> List[SourceCode]:
         """Score files by relevance and keep appending them until the token count would exceed max_tokens"""
         selected_files = []
         total_tokens = 0
-        scored_files = []
+        scored_files = []
 
         @byzerllm.prompt()
         def verify_file_relevance(file_content: str, conversations: List[Dict[str, str]]) -> str:
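A small sketch of the token-counting fallback introduced in `_count_tokens`: a precomputed `tokens` value is reused when positive, otherwise the content is recounted. The `Source` dataclass and the whitespace-based counter below are stand-ins for autocoder's `SourceCode` and `count_tokens`, not the real implementations.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Source:                      # stand-in for autocoder's SourceCode
    module_name: str
    source_code: str
    tokens: int = -1


def naive_count_tokens(text: str) -> int:
    """Placeholder for autocoder.rag.token_counter.count_tokens."""
    return len(text.split())


def count_total_tokens(sources: List[Source]) -> Tuple[int, List[Source]]:
    total = 0
    counted: List[Source] = []
    for src in sources:
        # Reuse the precomputed count when available, otherwise recount
        tokens = src.tokens if src.tokens > 0 else naive_count_tokens(src.source_code)
        total += tokens
        counted.append(Source(src.module_name, src.source_code, tokens))
    return total, counted


total, counted = count_total_tokens([
    Source("a.py", "def a():\n    return 1"),           # tokens unknown, recounted
    Source("b.py", "def b():\n    return 2", tokens=7),  # precomputed value reused
])
print(total, [s.tokens for s in counted])
```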
@@ -463,30 +427,28 @@ class PruneContext:
             ```
             """
 
-        def _score_file(
-            try:
-                    "tokens": tokens,
-                    "content": content
-                }
+        def _score_file(file_source: SourceCode) -> dict:
+            try:
+                result = verify_file_relevance.with_llm(self.llm).with_return_type(VerifyFileRelevance).run(
+                    file_content=file_source.source_code,
+                    conversations=conversations
+                )
+                return {
+                    "file_path": file_source.module_name,
+                    "score": result.relevant_score,
+                    "tokens": file_source.tokens,
+                    "content": file_source.source_code
+                }
             except Exception as e:
-                logger.error(f"Failed to score file {
+                logger.error(f"Failed to score file {file_source.module_name}: {e}")
                 return None
 
         # Score the files in parallel with a thread pool
         with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(_score_file,
+            futures = [executor.submit(_score_file, file_source)
+                       for file_source in file_sources]
             for future in as_completed(futures):
-                result = future.result()
+                result = future.result()
                 if result:
                     self.printer.print_str_in_terminal(
                         get_message_with_format(
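The score strategy fans the per-file relevance check out over a thread pool and collects results as they complete. Below is a minimal sketch of that pattern with a dummy scorer in place of the `verify_file_relevance` LLM call; sorting by score at the end is this sketch's choice, since the selection logic that follows is outside this hunk.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional


def score_file(name: str, content: str) -> Optional[Dict]:
    """Stand-in for the verify_file_relevance LLM call; uses a trivial length heuristic."""
    try:
        return {"file_path": name, "score": min(10, len(content) % 11), "content": content}
    except Exception:
        return None


files = {"a.py": "print('a')", "b.py": "value = [n * n for n in range(10)]", "c.py": "# placeholder"}

scored: List[Dict] = []
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(score_file, name, content) for name, content in files.items()]
    for future in as_completed(futures):   # results arrive in completion order
        result = future.result()
        if result:
            scored.append(result)

scored.sort(key=lambda item: item["score"], reverse=True)   # most relevant first (sketch's choice)
print([item["file_path"] for item in scored])
```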
autocoder/index/entry.py
CHANGED
@@ -293,7 +293,7 @@ def build_index_and_filter_files(
     # Re-sort temp_sources according to sorted_file_paths
     temp_sources.sort(key=lambda x: sorted_file_paths.index(x.module_name) if x.module_name in sorted_file_paths else len(sorted_file_paths))
 
-    pruned_files = context_pruner.handle_overflow(
+    pruned_files = context_pruner.handle_overflow(temp_sources, [{"role":"user","content":args.query}], args.context_prune_strategy)
     source_code_list.sources = pruned_files
 
 