auto-coder 0.1.268__py3-none-any.whl → 0.1.270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.270.dist-info}/METADATA +2 -2
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.270.dist-info}/RECORD +23 -20
- autocoder/agent/auto_learn_from_commit.py +209 -0
- autocoder/auto_coder.py +4 -0
- autocoder/auto_coder_runner.py +2647 -0
- autocoder/chat_auto_coder.py +54 -2630
- autocoder/commands/auto_command.py +23 -33
- autocoder/common/__init__.py +6 -2
- autocoder/common/auto_coder_lang.py +21 -4
- autocoder/common/auto_configure.py +41 -30
- autocoder/common/code_modification_ranker.py +55 -11
- autocoder/common/command_templates.py +2 -3
- autocoder/common/context_pruner.py +214 -14
- autocoder/common/conversation_pruner.py +11 -10
- autocoder/index/entry.py +44 -22
- autocoder/index/index.py +1 -1
- autocoder/utils/auto_project_type.py +120 -0
- autocoder/utils/model_provider_selector.py +23 -23
- autocoder/version.py +1 -1
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.270.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.270.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.270.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.270.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from typing import List, Dict, Any, Union
|
|
2
|
+
from typing import Tuple
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
import json
|
|
4
5
|
from loguru import logger
|
|
@@ -19,6 +20,100 @@ class PruneContext:
|
|
|
19
20
|
self.llm = llm
|
|
20
21
|
self.printer = Printer()
|
|
21
22
|
|
|
23
|
+
def _split_content_with_sliding_window(self, content: str, window_size=100, overlap=20) -> List[Tuple[int, int, str]]:
|
|
24
|
+
"""使用滑动窗口分割大文件内容,返回包含行号信息的文本块
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
content: 要分割的文件内容
|
|
28
|
+
window_size: 每个窗口包含的行数
|
|
29
|
+
overlap: 相邻窗口的重叠行数
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
List[Tuple[int, int, str]]: 返回元组列表,每个元组包含:
|
|
33
|
+
- 起始行号(从1开始),在原始文件的绝对行号
|
|
34
|
+
- 结束行号,在原始文件的绝对行号
|
|
35
|
+
- 带行号的内容文本
|
|
36
|
+
"""
|
|
37
|
+
# 按行分割内容
|
|
38
|
+
lines = content.splitlines()
|
|
39
|
+
chunks = []
|
|
40
|
+
start = 0
|
|
41
|
+
|
|
42
|
+
while start < len(lines):
|
|
43
|
+
# 计算当前窗口的结束位置
|
|
44
|
+
end = min(start + window_size, len(lines))
|
|
45
|
+
|
|
46
|
+
# 计算实际的起始位置(考虑重叠)
|
|
47
|
+
actual_start = max(0, start - overlap)
|
|
48
|
+
|
|
49
|
+
# 提取当前窗口的行
|
|
50
|
+
chunk_lines = lines[actual_start:end]
|
|
51
|
+
|
|
52
|
+
# 为每一行添加行号
|
|
53
|
+
# 行号从actual_start+1开始,保持与原文件的绝对行号一致
|
|
54
|
+
chunk_content = "\n".join([
|
|
55
|
+
f"{i+1} {line}" for i, line in enumerate(chunk_lines, start=actual_start)
|
|
56
|
+
])
|
|
57
|
+
|
|
58
|
+
# 保存分块信息:(起始行号, 结束行号, 带行号的内容)
|
|
59
|
+
# 行号从1开始计数
|
|
60
|
+
chunks.append((actual_start + 1, end, chunk_content))
|
|
61
|
+
|
|
62
|
+
# 移动到下一个窗口的起始位置
|
|
63
|
+
# 减去overlap确保窗口重叠
|
|
64
|
+
start += (window_size - overlap)
|
|
65
|
+
|
|
66
|
+
return chunks
|
|
67
|
+
|
|
68
|
+
def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
|
|
69
|
+
"""合并重叠或相邻的代码片段
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
snippets: 代码片段列表,每个片段是包含start_line和end_line的字典
|
|
73
|
+
|
|
74
|
+
Returns:
|
|
75
|
+
List[dict]: 合并后的代码片段列表
|
|
76
|
+
|
|
77
|
+
示例:
|
|
78
|
+
输入: [
|
|
79
|
+
{"start_line": 1, "end_line": 5},
|
|
80
|
+
{"start_line": 4, "end_line": 8},
|
|
81
|
+
{"start_line": 10, "end_line": 12}
|
|
82
|
+
]
|
|
83
|
+
输出: [
|
|
84
|
+
{"start_line": 1, "end_line": 8},
|
|
85
|
+
{"start_line": 10, "end_line": 12}
|
|
86
|
+
]
|
|
87
|
+
"""
|
|
88
|
+
if not snippets:
|
|
89
|
+
return []
|
|
90
|
+
|
|
91
|
+
# 按起始行排序
|
|
92
|
+
sorted_snippets = sorted(snippets, key=lambda x: x["start_line"])
|
|
93
|
+
|
|
94
|
+
merged = [sorted_snippets[0]]
|
|
95
|
+
|
|
96
|
+
for current in sorted_snippets[1:]:
|
|
97
|
+
last = merged[-1]
|
|
98
|
+
|
|
99
|
+
# 判断是否需要合并:
|
|
100
|
+
# 1. 如果当前片段的起始行小于等于上一个片段的结束行+1
|
|
101
|
+
# 2. +1是为了合并相邻的片段,比如1-5和6-8应该合并为1-8
|
|
102
|
+
if current["start_line"] <= last["end_line"] + 1:
|
|
103
|
+
# 合并区间:
|
|
104
|
+
# - 起始行取两者最小值
|
|
105
|
+
# - 结束行取两者最大值
|
|
106
|
+
merged[-1] = {
|
|
107
|
+
"start_line": min(last["start_line"], current["start_line"]),
|
|
108
|
+
"end_line": max(last["end_line"], current["end_line"])
|
|
109
|
+
}
|
|
110
|
+
else:
|
|
111
|
+
# 如果不重叠且不相邻,则作为新片段添加
|
|
112
|
+
merged.append(current)
|
|
113
|
+
|
|
114
|
+
return merged
|
|
115
|
+
|
|
116
|
+
|
|
22
117
|
def _delete_overflow_files(self, file_paths: List[str]) -> List[SourceCode]:
|
|
23
118
|
"""直接删除超出 token 限制的文件"""
|
|
24
119
|
total_tokens = 0
|
|
@@ -40,6 +135,8 @@ class PruneContext:
|
|
|
40
135
|
selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=token_count))
|
|
41
136
|
|
|
42
137
|
return selected_files
|
|
138
|
+
|
|
139
|
+
|
|
43
140
|
|
|
44
141
|
def _extract_code_snippets(self, file_paths: List[str], conversations: List[Dict[str, str]]) -> List[SourceCode]:
|
|
45
142
|
"""抽取关键代码片段策略"""
|
|
@@ -48,7 +145,7 @@ class PruneContext:
|
|
|
48
145
|
full_file_tokens = int(self.max_tokens * 0.8)
|
|
49
146
|
|
|
50
147
|
@byzerllm.prompt()
|
|
51
|
-
def extract_code_snippets(conversations: List[Dict[str, str]], content: str) -> str:
|
|
148
|
+
def extract_code_snippets(conversations: List[Dict[str, str]], content: str, is_partial_content: bool = False) -> str:
|
|
52
149
|
"""
|
|
53
150
|
根据提供的代码文件和对话历史提取相关代码片段。
|
|
54
151
|
|
|
@@ -111,6 +208,13 @@ class PruneContext:
|
|
|
111
208
|
{{ content }}
|
|
112
209
|
</code_file>
|
|
113
210
|
|
|
211
|
+
<% if is_partial_content: %>
|
|
212
|
+
<partial_content_process_note>
|
|
213
|
+
当前处理的是文件的局部内容(行号{start_line}-{end_line}),
|
|
214
|
+
请仅基于当前可见内容判断相关性,返回标注的行号区间。
|
|
215
|
+
</partial_content_process_note>
|
|
216
|
+
<% endif %>
|
|
217
|
+
|
|
114
218
|
2. 对话历史:
|
|
115
219
|
<conversation_history>
|
|
116
220
|
{% for msg in conversations %}
|
|
@@ -131,15 +235,17 @@ class PruneContext:
|
|
|
131
235
|
4. 如果没有相关代码段,返回空数组[]。
|
|
132
236
|
|
|
133
237
|
输出格式:
|
|
134
|
-
严格的JSON数组,不包含其他文字或解释。
|
|
135
|
-
|
|
238
|
+
严格的JSON数组,不包含其他文字或解释。
|
|
239
|
+
|
|
136
240
|
```json
|
|
137
241
|
[
|
|
138
242
|
{"start_line": 第一个代码段的起始行号, "end_line": 第一个代码段的结束行号},
|
|
139
243
|
{"start_line": 第二个代码段的起始行号, "end_line": 第二个代码段的结束行号}
|
|
140
244
|
]
|
|
141
|
-
```
|
|
245
|
+
```
|
|
246
|
+
|
|
142
247
|
"""
|
|
248
|
+
|
|
143
249
|
|
|
144
250
|
for file_path in file_paths:
|
|
145
251
|
try:
|
|
@@ -152,22 +258,76 @@ class PruneContext:
|
|
|
152
258
|
selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=tokens))
|
|
153
259
|
token_count += tokens
|
|
154
260
|
continue
|
|
155
|
-
|
|
261
|
+
|
|
262
|
+
## 如果单个文件太大,那么先按滑动窗口分割,然后对窗口抽取代码片段
|
|
263
|
+
if tokens > self.max_tokens:
|
|
264
|
+
self.printer.print_in_terminal("file_sliding_window_processing", file_path=file_path, tokens=tokens)
|
|
265
|
+
chunks = self._split_content_with_sliding_window(content,
|
|
266
|
+
self.args.context_prune_sliding_window_size,
|
|
267
|
+
self.args.context_prune_sliding_window_overlap)
|
|
268
|
+
all_snippets = []
|
|
269
|
+
for chunk_start, chunk_end, chunk_content in chunks:
|
|
270
|
+
extracted = extract_code_snippets.with_llm(self.llm).run(
|
|
271
|
+
conversations=conversations,
|
|
272
|
+
content=chunk_content,
|
|
273
|
+
is_partial_content=True
|
|
274
|
+
)
|
|
275
|
+
if extracted:
|
|
276
|
+
json_str = extract_code(extracted)[0][1]
|
|
277
|
+
snippets = json.loads(json_str)
|
|
278
|
+
|
|
279
|
+
# 获取到的本来就是在原始文件里的绝对行号
|
|
280
|
+
# 后续在构建代码片段内容时,会为了适配数组操作修改行号,这里无需处理
|
|
281
|
+
adjusted_snippets = [{
|
|
282
|
+
"start_line": snippet["start_line"],
|
|
283
|
+
"end_line": snippet["end_line"]
|
|
284
|
+
} for snippet in snippets]
|
|
285
|
+
all_snippets.extend(adjusted_snippets)
|
|
286
|
+
merged_snippets = self._merge_overlapping_snippets(all_snippets)
|
|
287
|
+
content_snippets = self._build_snippet_content(file_path, content, merged_snippets)
|
|
288
|
+
snippet_tokens = count_tokens(content_snippets)
|
|
289
|
+
if token_count + snippet_tokens <= self.max_tokens:
|
|
290
|
+
selected_files.append(SourceCode(module_name=file_path,source_code=content_snippets,tokens=snippet_tokens))
|
|
291
|
+
token_count += snippet_tokens
|
|
292
|
+
self.printer.print_in_terminal("file_snippet_procesed", file_path=file_path,
|
|
293
|
+
total_tokens=token_count,
|
|
294
|
+
tokens=tokens,
|
|
295
|
+
snippet_tokens=snippet_tokens)
|
|
296
|
+
continue
|
|
297
|
+
else:
|
|
298
|
+
break
|
|
299
|
+
|
|
156
300
|
# 抽取关键片段
|
|
301
|
+
lines = content.splitlines()
|
|
302
|
+
new_content = ""
|
|
303
|
+
|
|
304
|
+
## 将文件内容按行编号
|
|
305
|
+
for index,line in enumerate(lines):
|
|
306
|
+
new_content += f"{index+1} {line}\n"
|
|
307
|
+
|
|
308
|
+
## 抽取代码片段
|
|
309
|
+
self.printer.print_in_terminal("file_snippet_processing", file_path=file_path)
|
|
157
310
|
extracted = extract_code_snippets.with_llm(self.llm).run(
|
|
158
311
|
conversations=conversations,
|
|
159
|
-
content=
|
|
312
|
+
content=new_content
|
|
160
313
|
)
|
|
161
314
|
|
|
315
|
+
## 构建代码片段内容
|
|
162
316
|
if extracted:
|
|
163
317
|
json_str = extract_code(extracted)[0][1]
|
|
164
318
|
snippets = json.loads(json_str)
|
|
165
|
-
|
|
319
|
+
content_snippets = self._build_snippet_content(file_path, content, snippets)
|
|
166
320
|
|
|
167
|
-
snippet_tokens = count_tokens(
|
|
168
|
-
if token_count + snippet_tokens <= self.max_tokens:
|
|
169
|
-
selected_files.append(SourceCode(module_name=file_path,
|
|
321
|
+
snippet_tokens = count_tokens(content_snippets)
|
|
322
|
+
if token_count + snippet_tokens <= self.max_tokens:
|
|
323
|
+
selected_files.append(SourceCode(module_name=file_path,
|
|
324
|
+
source_code=content_snippets,
|
|
325
|
+
tokens=snippet_tokens))
|
|
170
326
|
token_count += snippet_tokens
|
|
327
|
+
self.printer.print_in_terminal("file_snippet_procesed", file_path=file_path,
|
|
328
|
+
total_tokens = token_count,
|
|
329
|
+
tokens=tokens,
|
|
330
|
+
snippet_tokens=snippet_tokens)
|
|
171
331
|
else:
|
|
172
332
|
break
|
|
173
333
|
except Exception as e:
|
|
@@ -175,10 +335,32 @@ class PruneContext:
|
|
|
175
335
|
continue
|
|
176
336
|
|
|
177
337
|
return selected_files
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
|
|
341
|
+
if not snippets:
|
|
342
|
+
return []
|
|
343
|
+
|
|
344
|
+
# 按起始行排序
|
|
345
|
+
sorted_snippets = sorted(snippets, key=lambda x: x["start_line"])
|
|
346
|
+
|
|
347
|
+
merged = [sorted_snippets[0]]
|
|
348
|
+
for current in sorted_snippets[1:]:
|
|
349
|
+
last = merged[-1]
|
|
350
|
+
if current["start_line"] <= last["end_line"] + 1: # 允许1行间隔
|
|
351
|
+
# 合并区间
|
|
352
|
+
merged[-1] = {
|
|
353
|
+
"start_line": min(last["start_line"], current["start_line"]),
|
|
354
|
+
"end_line": max(last["end_line"], current["end_line"])
|
|
355
|
+
}
|
|
356
|
+
else:
|
|
357
|
+
merged.append(current)
|
|
358
|
+
|
|
359
|
+
return merged
|
|
178
360
|
|
|
179
361
|
def _build_snippet_content(self, file_path: str, full_content: str, snippets: List[dict]) -> str:
|
|
180
362
|
"""构建包含代码片段的文件内容"""
|
|
181
|
-
lines = full_content.
|
|
363
|
+
lines = full_content.splitlines()
|
|
182
364
|
header = f"Snippets:\n"
|
|
183
365
|
|
|
184
366
|
content = []
|
|
@@ -205,7 +387,26 @@ class PruneContext:
|
|
|
205
387
|
total_tokens,sources = self._count_tokens(file_paths)
|
|
206
388
|
if total_tokens <= self.max_tokens:
|
|
207
389
|
return sources
|
|
208
|
-
|
|
390
|
+
|
|
391
|
+
self.printer.print_in_terminal(
|
|
392
|
+
"context_pruning_reason",
|
|
393
|
+
total_tokens=total_tokens,
|
|
394
|
+
max_tokens=self.max_tokens,
|
|
395
|
+
style="yellow"
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
self.printer.print_in_terminal(
|
|
399
|
+
"sorted_files_message",
|
|
400
|
+
files=file_paths
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
self.printer.print_in_terminal(
|
|
404
|
+
"context_pruning_start",
|
|
405
|
+
total_tokens=total_tokens,
|
|
406
|
+
max_tokens=self.max_tokens,
|
|
407
|
+
strategy=strategy
|
|
408
|
+
)
|
|
409
|
+
|
|
209
410
|
if strategy == "score":
|
|
210
411
|
return self._score_and_filter_files(file_paths, conversations)
|
|
211
412
|
if strategy == "delete":
|
|
@@ -214,7 +415,7 @@ class PruneContext:
|
|
|
214
415
|
return self._extract_code_snippets(file_paths, conversations)
|
|
215
416
|
else:
|
|
216
417
|
raise ValueError(f"无效策略: {strategy}. 可选值: delete/extract/score")
|
|
217
|
-
|
|
418
|
+
|
|
218
419
|
def _count_tokens(self, file_paths: List[str]) -> int:
|
|
219
420
|
"""计算文件总token数"""
|
|
220
421
|
total_tokens = 0
|
|
@@ -312,4 +513,3 @@ class PruneContext:
|
|
|
312
513
|
break
|
|
313
514
|
|
|
314
515
|
return selected_files
|
|
315
|
-
|
|
@@ -3,8 +3,9 @@ import json
|
|
|
3
3
|
from pydantic import BaseModel
|
|
4
4
|
import byzerllm
|
|
5
5
|
from autocoder.common.printer import Printer
|
|
6
|
-
from autocoder.
|
|
6
|
+
from autocoder.rag.token_counter import count_tokens
|
|
7
7
|
from loguru import logger
|
|
8
|
+
from autocoder.common import AutoCoderArgs
|
|
8
9
|
|
|
9
10
|
class PruneStrategy(BaseModel):
|
|
10
11
|
name: str
|
|
@@ -12,25 +13,25 @@ class PruneStrategy(BaseModel):
|
|
|
12
13
|
config: Dict[str, Any] = {"safe_zone_tokens": 0, "group_size": 4}
|
|
13
14
|
|
|
14
15
|
class ConversationPruner:
|
|
15
|
-
def __init__(self, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]
|
|
16
|
-
|
|
16
|
+
def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
|
|
17
|
+
self.args = args
|
|
17
18
|
self.llm = llm
|
|
18
19
|
self.printer = Printer()
|
|
19
20
|
self.strategies = {
|
|
20
21
|
"summarize": PruneStrategy(
|
|
21
22
|
name="summarize",
|
|
22
23
|
description="对早期对话进行分组摘要,保留关键信息",
|
|
23
|
-
config={"safe_zone_tokens":
|
|
24
|
+
config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
|
|
24
25
|
),
|
|
25
26
|
"truncate": PruneStrategy(
|
|
26
27
|
name="truncate",
|
|
27
28
|
description="分组截断最早的部分对话",
|
|
28
|
-
config={"safe_zone_tokens":
|
|
29
|
+
config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
|
|
29
30
|
),
|
|
30
31
|
"hybrid": PruneStrategy(
|
|
31
32
|
name="hybrid",
|
|
32
33
|
description="先尝试分组摘要,如果仍超限则分组截断",
|
|
33
|
-
config={"safe_zone_tokens":
|
|
34
|
+
config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
|
|
34
35
|
)
|
|
35
36
|
}
|
|
36
37
|
|
|
@@ -57,7 +58,7 @@ class ConversationPruner:
|
|
|
57
58
|
if strategy.name == "summarize":
|
|
58
59
|
return self._summarize_prune(conversations, strategy.config)
|
|
59
60
|
elif strategy.name == "truncate":
|
|
60
|
-
return self._truncate_prune
|
|
61
|
+
return self._truncate_prune(conversations)
|
|
61
62
|
elif strategy.name == "hybrid":
|
|
62
63
|
pruned = self._summarize_prune(conversations, strategy.config)
|
|
63
64
|
if count_tokens(json.dumps(pruned, ensure_ascii=False)) > self.args.conversation_prune_safe_zone_tokens:
|
|
@@ -80,8 +81,8 @@ class ConversationPruner:
|
|
|
80
81
|
break
|
|
81
82
|
|
|
82
83
|
# 找到要处理的对话组
|
|
83
|
-
early_conversations = processed_conversations[
|
|
84
|
-
recent_conversations = processed_conversations[
|
|
84
|
+
early_conversations = processed_conversations[:group_size]
|
|
85
|
+
recent_conversations = processed_conversations[group_size:]
|
|
85
86
|
|
|
86
87
|
if not early_conversations:
|
|
87
88
|
break
|
|
@@ -90,7 +91,7 @@ class ConversationPruner:
|
|
|
90
91
|
group_summary = self._generate_summary.with_llm(self.llm).run(early_conversations[-group_size:])
|
|
91
92
|
|
|
92
93
|
# 更新对话历史
|
|
93
|
-
processed_conversations =
|
|
94
|
+
processed_conversations = [
|
|
94
95
|
{"role": "user", "content": f"历史对话摘要:\n{group_summary}"},
|
|
95
96
|
{"role": "assistant", "content": f"收到"}
|
|
96
97
|
] + recent_conversations
|
autocoder/index/entry.py
CHANGED
|
@@ -58,8 +58,12 @@ def build_index_and_filter_files(
|
|
|
58
58
|
return file_path.strip()[2:]
|
|
59
59
|
return file_path
|
|
60
60
|
|
|
61
|
+
# 文件路径 -> TargetFile
|
|
61
62
|
final_files: Dict[str, TargetFile] = {}
|
|
62
63
|
|
|
64
|
+
# 文件路径 -> 文件在文件列表中的位置(越前面表示越相关)
|
|
65
|
+
file_positions:Dict[str,int] = {}
|
|
66
|
+
|
|
63
67
|
# Phase 1: Process REST/RAG/Search sources
|
|
64
68
|
printer = Printer()
|
|
65
69
|
printer.print_in_terminal("phase1_processing_sources")
|
|
@@ -102,25 +106,20 @@ def build_index_and_filter_files(
|
|
|
102
106
|
})
|
|
103
107
|
)
|
|
104
108
|
)
|
|
105
|
-
|
|
109
|
+
|
|
110
|
+
|
|
106
111
|
if not args.skip_filter_index and args.index_filter_model:
|
|
107
112
|
model_name = getattr(index_manager.index_filter_llm, 'default_model_name', None)
|
|
108
113
|
if not model_name:
|
|
109
114
|
model_name = "unknown(without default model name)"
|
|
110
115
|
printer.print_in_terminal("quick_filter_start", style="blue", model_name=model_name)
|
|
111
116
|
quick_filter = QuickFilter(index_manager,stats,sources)
|
|
112
|
-
quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
|
|
119
|
-
pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
|
|
120
|
-
for source_file in pruned_files:
|
|
121
|
-
final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
|
|
122
|
-
else:
|
|
123
|
-
final_files.update(quick_filter_result.files)
|
|
117
|
+
quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
|
|
118
|
+
|
|
119
|
+
final_files.update(quick_filter_result.files)
|
|
120
|
+
|
|
121
|
+
if quick_filter_result.file_positions:
|
|
122
|
+
file_positions.update(quick_filter_result.file_positions)
|
|
124
123
|
|
|
125
124
|
if not args.skip_filter_index and not args.index_filter_model:
|
|
126
125
|
model_name = getattr(index_manager.llm, 'default_model_name', None)
|
|
@@ -261,32 +260,55 @@ def build_index_and_filter_files(
|
|
|
261
260
|
for file in final_filenames:
|
|
262
261
|
print(f"{file} - {final_files[file].reason}")
|
|
263
262
|
|
|
264
|
-
source_code = ""
|
|
263
|
+
# source_code = ""
|
|
265
264
|
source_code_list = SourceCodeList(sources=[])
|
|
266
265
|
depulicated_sources = set()
|
|
267
|
-
|
|
266
|
+
|
|
267
|
+
## 先去重
|
|
268
|
+
temp_sources = []
|
|
268
269
|
for file in sources:
|
|
269
270
|
if file.module_name in final_filenames:
|
|
270
271
|
if file.module_name in depulicated_sources:
|
|
271
272
|
continue
|
|
272
273
|
depulicated_sources.add(file.module_name)
|
|
273
|
-
source_code += f"##File: {file.module_name}\n"
|
|
274
|
-
source_code += f"{file.source_code}\n\n"
|
|
275
|
-
|
|
274
|
+
# source_code += f"##File: {file.module_name}\n"
|
|
275
|
+
# source_code += f"{file.source_code}\n\n"
|
|
276
|
+
temp_sources.append(file)
|
|
277
|
+
|
|
278
|
+
## 开启了裁剪,则需要做裁剪,不过目前只针对 quick filter 生效
|
|
279
|
+
if args.context_prune:
|
|
280
|
+
context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
|
|
281
|
+
# 如果 file_positions 不为空,则通过 file_positions 来获取文件
|
|
282
|
+
if file_positions:
|
|
283
|
+
## 拿到位置列表,然后根据位置排序 得到 [(pos,file_path)]
|
|
284
|
+
## 将 [(pos,file_path)] 转换为 [file_path]
|
|
285
|
+
## 通过 [file_path] 顺序调整 temp_sources 的顺序
|
|
286
|
+
## MARK
|
|
287
|
+
# 将 file_positions 转换为 [(pos, file_path)] 的列表
|
|
288
|
+
position_file_pairs = [(pos, file_path) for file_path, pos in file_positions.items()]
|
|
289
|
+
# 按位置排序
|
|
290
|
+
position_file_pairs.sort(key=lambda x: x[0])
|
|
291
|
+
# 提取排序后的文件路径列表
|
|
292
|
+
sorted_file_paths = [file_path for _, file_path in position_file_pairs]
|
|
293
|
+
# 根据 sorted_file_paths 重新排序 temp_sources
|
|
294
|
+
temp_sources.sort(key=lambda x: sorted_file_paths.index(x.module_name) if x.module_name in sorted_file_paths else len(sorted_file_paths))
|
|
295
|
+
|
|
296
|
+
pruned_files = context_pruner.handle_overflow([source.module_name for source in temp_sources], [{"role":"user","content":args.query}], args.context_prune_strategy)
|
|
297
|
+
source_code_list.sources = pruned_files
|
|
298
|
+
|
|
299
|
+
|
|
276
300
|
if args.request_id and not args.skip_events:
|
|
277
301
|
queue_communicate.send_event(
|
|
278
302
|
request_id=args.request_id,
|
|
279
303
|
event=CommunicateEvent(
|
|
280
304
|
event_type=CommunicateEventType.CODE_INDEX_FILTER_FILE_SELECTED.value,
|
|
281
305
|
data=json.dumps([
|
|
282
|
-
(file
|
|
283
|
-
for file in final_files.values()
|
|
284
|
-
if file.file_path in depulicated_sources
|
|
306
|
+
(file.module_name, "") for file in source_code_list.sources
|
|
285
307
|
])
|
|
286
308
|
)
|
|
287
309
|
)
|
|
288
310
|
|
|
289
|
-
stats["final_files"] = len(
|
|
311
|
+
stats["final_files"] = len(source_code_list.sources)
|
|
290
312
|
phase_end = time.monotonic()
|
|
291
313
|
stats["timings"]["prepare_output"] = phase_end - phase_start
|
|
292
314
|
|
autocoder/index/index.py
CHANGED
|
@@ -400,7 +400,7 @@ class IndexManager:
|
|
|
400
400
|
|
|
401
401
|
# 删除被排除的文件
|
|
402
402
|
try:
|
|
403
|
-
exclude_patterns = self.parse_exclude_files(self.args.exclude_files)
|
|
403
|
+
exclude_patterns = self.parse_exclude_files(self.args.exclude_files)
|
|
404
404
|
for file_path in index_data:
|
|
405
405
|
if self.filter_exclude_files(file_path, exclude_patterns):
|
|
406
406
|
keys_to_remove.append(file_path)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Dict, List, Set, Tuple
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from loguru import logger
|
|
7
|
+
import byzerllm
|
|
8
|
+
from autocoder.common import AutoCoderArgs
|
|
9
|
+
from autocoder.common.printer import Printer
|
|
10
|
+
from typing import Union
|
|
11
|
+
import pydantic
|
|
12
|
+
from autocoder.common.result_manager import ResultManager
|
|
13
|
+
|
|
14
|
+
class ExtensionClassifyResult(pydantic.BaseModel):
    # Structured result of classifying file extensions into categories.
    # Used as the return type of the `classify_extensions` prompt in
    # ProjectTypeAnalyzer; every field defaults to an empty list.

    # Extensions classified as source code.
    code: List[str] = []
    # Extensions classified as configuration files.
    config: List[str] = []
    # Extensions classified as data files.
    data: List[str] = []
    # Extensions classified as documentation.
    document: List[str] = []
    # Extensions that fit none of the other categories.
    other: List[str] = []
    # Extensions classified as framework-related (semantics are decided by
    # the model's answer to the prompt).
    framework: List[str] = []
|
|
21
|
+
|
|
22
|
+
class ProjectTypeAnalyzer:
    """Infer a project's type from the file extensions under its source dir.

    The analyzer walks ``args.source_dir``, counts file extensions, asks the
    LLM to classify those extensions, and reports the extensions classified
    as code (comma-joined) as the project type.
    """

    def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
        """Store the configuration and prepare empty statistics.

        Args:
            args: Auto-coder arguments; ``source_dir`` is the project root.
            llm: LLM client used to classify the extensions.
        """
        self.args = args
        self.llm = llm
        self.printer = Printer()
        # Names pruned from the directory walk (VCS metadata, build output,
        # dependency and cache folders, ...).
        self.default_exclude_dirs = [
            ".git", ".svn", ".hg", "build", "dist", "__pycache__",
            "node_modules", ".auto-coder", ".vscode", ".idea", "venv",
            ".next", ".nuxt", ".svelte-kit", "out", "cache", "logs",
            "temp", "tmp", "coverage", ".DS_Store", "public", "static"
        ]
        # Lower-cased extension (including the dot) -> number of files.
        self.extension_counts = defaultdict(int)
        self.stats_file = Path(args.source_dir) / ".auto-coder" / "project_type_stats.json"
        self.result_manager = ResultManager()

    def traverse_project(self) -> None:
        """Walk the project directory and count files per extension.

        Counts are reset first so that calling this method (or ``analyze``)
        more than once does not count the same files repeatedly.
        """
        self.extension_counts.clear()
        excluded = set(self.default_exclude_dirs)

        for _root, dirs, files in os.walk(self.args.source_dir):
            # Prune in place so os.walk does not descend into excluded dirs.
            dirs[:] = [d for d in dirs if d not in excluded]

            for file in files:
                _, ext = os.path.splitext(file)
                if ext:  # Only files that have an extension are counted.
                    self.extension_counts[ext.lower()] += 1

    def count_extensions(self) -> Dict[str, int]:
        """Return the extension counts ordered from most to least frequent."""
        return dict(sorted(self.extension_counts.items(), key=lambda x: x[1], reverse=True))

    # NOTE: the docstring below is runtime text -- byzerllm.prompt presumably
    # renders it as the prompt template, filled with the returned dict -- so
    # it must stay exactly as written.
    @byzerllm.prompt()
    def classify_extensions(self, extensions: str) -> str:
        """
        根据文件后缀列表,将后缀分类为代码、配置、数据、文档等类型。

        文件后缀列表:
        {{ extensions }}

        请返回如下JSON格式:
        {
        "code": ["后缀1", "后缀2"],
        "config": ["后缀3", "后缀4"],
        "data": ["后缀5", "后缀6"],
        "document": ["后缀7", "后缀8"],
        "other": ["后缀9", "后缀10"],
        "framework": ["后缀11", "后缀12"]
        }
        """
        return {
            "extensions": extensions
        }

    def save_stats(self) -> None:
        """Write the extension counts and detected project type to disk.

        Calls ``detect_project_type`` (and therefore the LLM) to obtain the
        project type, then writes JSON to ``self.stats_file``.
        """
        stats = {
            "extension_counts": self.extension_counts,
            "project_type": self.detect_project_type()
        }

        # Make sure the target directory exists before writing.
        self.stats_file.parent.mkdir(parents=True, exist_ok=True)

        with open(self.stats_file, "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2)

        self.printer.print_in_terminal("stats_saved", path=str(self.stats_file))

    def load_stats(self) -> Dict[str, object]:
        """Load previously saved statistics.

        Returns:
            The parsed JSON content of ``self.stats_file``, or an empty dict
            (after printing a notice) when the file does not exist.
        """
        if not self.stats_file.exists():
            self.printer.print_in_terminal("stats_not_found", path=str(self.stats_file))
            return {}

        with open(self.stats_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def detect_project_type(self) -> str:
        """Ask the LLM to classify the counted extensions.

        Returns:
            The extensions classified as code, joined with commas.
        """
        ext_counts = self.count_extensions()
        # The counts are sent as a JSON string; the answer is parsed into an
        # ExtensionClassifyResult.
        classification = self.classify_extensions.with_llm(self.llm).with_return_type(ExtensionClassifyResult).run(json.dumps(ext_counts,ensure_ascii=False))
        return ",".join(classification.code)

    def analyze(self) -> str:
        """Run the full analysis: traverse, detect, and record the result.

        Returns:
            The detected project type (comma-joined code extensions), which
            is also stored through ``self.result_manager``.
        """
        self.traverse_project()

        project_type = self.detect_project_type()

        self.result_manager.add_result(content=project_type, meta={
            "action": "get_project_type",
            "input": {

            }
        })
        return project_type
|