auto-coder 0.1.267__py3-none-any.whl → 0.1.269__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.267.dist-info → auto_coder-0.1.269.dist-info}/METADATA +2 -2
- {auto_coder-0.1.267.dist-info → auto_coder-0.1.269.dist-info}/RECORD +20 -18
- autocoder/auto_coder_runner.py +2635 -0
- autocoder/chat_auto_coder.py +54 -2630
- autocoder/commands/auto_command.py +34 -54
- autocoder/common/__init__.py +6 -2
- autocoder/common/auto_coder_lang.py +9 -1
- autocoder/common/auto_configure.py +41 -30
- autocoder/common/command_templates.py +2 -3
- autocoder/common/context_pruner.py +198 -15
- autocoder/common/conversation_pruner.py +11 -10
- autocoder/index/entry.py +42 -22
- autocoder/index/index.py +97 -38
- autocoder/utils/auto_project_type.py +120 -0
- autocoder/utils/model_provider_selector.py +23 -23
- autocoder/version.py +1 -1
- {auto_coder-0.1.267.dist-info → auto_coder-0.1.269.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.267.dist-info → auto_coder-0.1.269.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.267.dist-info → auto_coder-0.1.269.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.267.dist-info → auto_coder-0.1.269.dist-info}/top_level.txt +0 -0
|
@@ -3,8 +3,9 @@ import json
|
|
|
3
3
|
from pydantic import BaseModel
|
|
4
4
|
import byzerllm
|
|
5
5
|
from autocoder.common.printer import Printer
|
|
6
|
-
from autocoder.
|
|
6
|
+
from autocoder.rag.token_counter import count_tokens
|
|
7
7
|
from loguru import logger
|
|
8
|
+
from autocoder.common import AutoCoderArgs
|
|
8
9
|
|
|
9
10
|
class PruneStrategy(BaseModel):
|
|
10
11
|
name: str
|
|
@@ -12,25 +13,25 @@ class PruneStrategy(BaseModel):
|
|
|
12
13
|
config: Dict[str, Any] = {"safe_zone_tokens": 0, "group_size": 4}
|
|
13
14
|
|
|
14
15
|
class ConversationPruner:
|
|
15
|
-
def __init__(self, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]
|
|
16
|
-
|
|
16
|
+
def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
|
|
17
|
+
self.args = args
|
|
17
18
|
self.llm = llm
|
|
18
19
|
self.printer = Printer()
|
|
19
20
|
self.strategies = {
|
|
20
21
|
"summarize": PruneStrategy(
|
|
21
22
|
name="summarize",
|
|
22
23
|
description="对早期对话进行分组摘要,保留关键信息",
|
|
23
|
-
config={"safe_zone_tokens":
|
|
24
|
+
config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
|
|
24
25
|
),
|
|
25
26
|
"truncate": PruneStrategy(
|
|
26
27
|
name="truncate",
|
|
27
28
|
description="分组截断最早的部分对话",
|
|
28
|
-
config={"safe_zone_tokens":
|
|
29
|
+
config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
|
|
29
30
|
),
|
|
30
31
|
"hybrid": PruneStrategy(
|
|
31
32
|
name="hybrid",
|
|
32
33
|
description="先尝试分组摘要,如果仍超限则分组截断",
|
|
33
|
-
config={"safe_zone_tokens":
|
|
34
|
+
config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
|
|
34
35
|
)
|
|
35
36
|
}
|
|
36
37
|
|
|
@@ -57,7 +58,7 @@ class ConversationPruner:
|
|
|
57
58
|
if strategy.name == "summarize":
|
|
58
59
|
return self._summarize_prune(conversations, strategy.config)
|
|
59
60
|
elif strategy.name == "truncate":
|
|
60
|
-
return self._truncate_prune
|
|
61
|
+
return self._truncate_prune(conversations)
|
|
61
62
|
elif strategy.name == "hybrid":
|
|
62
63
|
pruned = self._summarize_prune(conversations, strategy.config)
|
|
63
64
|
if count_tokens(json.dumps(pruned, ensure_ascii=False)) > self.args.conversation_prune_safe_zone_tokens:
|
|
@@ -80,8 +81,8 @@ class ConversationPruner:
|
|
|
80
81
|
break
|
|
81
82
|
|
|
82
83
|
# 找到要处理的对话组
|
|
83
|
-
early_conversations = processed_conversations[
|
|
84
|
-
recent_conversations = processed_conversations[
|
|
84
|
+
early_conversations = processed_conversations[:group_size]
|
|
85
|
+
recent_conversations = processed_conversations[group_size:]
|
|
85
86
|
|
|
86
87
|
if not early_conversations:
|
|
87
88
|
break
|
|
@@ -90,7 +91,7 @@ class ConversationPruner:
|
|
|
90
91
|
group_summary = self._generate_summary.with_llm(self.llm).run(early_conversations[-group_size:])
|
|
91
92
|
|
|
92
93
|
# 更新对话历史
|
|
93
|
-
processed_conversations =
|
|
94
|
+
processed_conversations = [
|
|
94
95
|
{"role": "user", "content": f"历史对话摘要:\n{group_summary}"},
|
|
95
96
|
{"role": "assistant", "content": f"收到"}
|
|
96
97
|
] + recent_conversations
|
autocoder/index/entry.py
CHANGED
|
@@ -58,8 +58,12 @@ def build_index_and_filter_files(
|
|
|
58
58
|
return file_path.strip()[2:]
|
|
59
59
|
return file_path
|
|
60
60
|
|
|
61
|
+
# 文件路径 -> TargetFile
|
|
61
62
|
final_files: Dict[str, TargetFile] = {}
|
|
62
63
|
|
|
64
|
+
# 文件路径 -> 文件在文件列表中的位置(越前面表示越相关)
|
|
65
|
+
file_positions:Dict[str,int] = {}
|
|
66
|
+
|
|
63
67
|
# Phase 1: Process REST/RAG/Search sources
|
|
64
68
|
printer = Printer()
|
|
65
69
|
printer.print_in_terminal("phase1_processing_sources")
|
|
@@ -102,25 +106,20 @@ def build_index_and_filter_files(
|
|
|
102
106
|
})
|
|
103
107
|
)
|
|
104
108
|
)
|
|
105
|
-
|
|
109
|
+
|
|
110
|
+
|
|
106
111
|
if not args.skip_filter_index and args.index_filter_model:
|
|
107
112
|
model_name = getattr(index_manager.index_filter_llm, 'default_model_name', None)
|
|
108
113
|
if not model_name:
|
|
109
114
|
model_name = "unknown(without default model name)"
|
|
110
115
|
printer.print_in_terminal("quick_filter_start", style="blue", model_name=model_name)
|
|
111
116
|
quick_filter = QuickFilter(index_manager,stats,sources)
|
|
112
|
-
quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
|
|
119
|
-
pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
|
|
120
|
-
for source_file in pruned_files:
|
|
121
|
-
final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
|
|
122
|
-
else:
|
|
123
|
-
final_files.update(quick_filter_result.files)
|
|
117
|
+
quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
|
|
118
|
+
|
|
119
|
+
final_files.update(quick_filter_result.files)
|
|
120
|
+
|
|
121
|
+
if quick_filter_result.file_positions:
|
|
122
|
+
file_positions.update(quick_filter_result.file_positions)
|
|
124
123
|
|
|
125
124
|
if not args.skip_filter_index and not args.index_filter_model:
|
|
126
125
|
model_name = getattr(index_manager.llm, 'default_model_name', None)
|
|
@@ -261,32 +260,53 @@ def build_index_and_filter_files(
|
|
|
261
260
|
for file in final_filenames:
|
|
262
261
|
print(f"{file} - {final_files[file].reason}")
|
|
263
262
|
|
|
264
|
-
source_code = ""
|
|
263
|
+
# source_code = ""
|
|
265
264
|
source_code_list = SourceCodeList(sources=[])
|
|
266
265
|
depulicated_sources = set()
|
|
267
|
-
|
|
266
|
+
|
|
267
|
+
## 先去重
|
|
268
|
+
temp_sources = []
|
|
268
269
|
for file in sources:
|
|
269
270
|
if file.module_name in final_filenames:
|
|
270
271
|
if file.module_name in depulicated_sources:
|
|
271
272
|
continue
|
|
272
273
|
depulicated_sources.add(file.module_name)
|
|
273
|
-
source_code += f"##File: {file.module_name}\n"
|
|
274
|
-
source_code += f"{file.source_code}\n\n"
|
|
275
|
-
|
|
274
|
+
# source_code += f"##File: {file.module_name}\n"
|
|
275
|
+
# source_code += f"{file.source_code}\n\n"
|
|
276
|
+
temp_sources.append(file)
|
|
277
|
+
|
|
278
|
+
## 开启了裁剪,则需要做裁剪,不过目前只针对 quick filter 生效
|
|
279
|
+
if args.context_prune:
|
|
280
|
+
context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
|
|
281
|
+
# 如果 file_positions 不为空,则通过 file_positions 来获取文件
|
|
282
|
+
if file_positions:
|
|
283
|
+
## 拿到位置列表,然后根据位置排序 得到 [(pos,file_path)]
|
|
284
|
+
## 将 [(pos,file_path)] 转换为 [file_path]
|
|
285
|
+
## 通过 [file_path] 顺序调整 temp_sources 的顺序
|
|
286
|
+
## MARK
|
|
287
|
+
# 将 file_positions 转换为 [(pos, file_path)] 的列表
|
|
288
|
+
position_file_pairs = [(pos, file_path) for file_path, pos in file_positions.items()]
|
|
289
|
+
# 按位置排序
|
|
290
|
+
position_file_pairs.sort(key=lambda x: x[0])
|
|
291
|
+
# 提取排序后的文件路径列表
|
|
292
|
+
sorted_file_paths = [file_path for _, file_path in position_file_pairs]
|
|
293
|
+
# 根据 sorted_file_paths 重新排序 temp_sources
|
|
294
|
+
temp_sources.sort(key=lambda x: sorted_file_paths.index(x.module_name) if x.module_name in sorted_file_paths else len(sorted_file_paths))
|
|
295
|
+
pruned_files = context_pruner.handle_overflow([source.module_name for source in temp_sources], [{"role":"user","content":args.query}], args.context_prune_strategy)
|
|
296
|
+
source_code_list.sources = pruned_files
|
|
297
|
+
|
|
276
298
|
if args.request_id and not args.skip_events:
|
|
277
299
|
queue_communicate.send_event(
|
|
278
300
|
request_id=args.request_id,
|
|
279
301
|
event=CommunicateEvent(
|
|
280
302
|
event_type=CommunicateEventType.CODE_INDEX_FILTER_FILE_SELECTED.value,
|
|
281
303
|
data=json.dumps([
|
|
282
|
-
(file
|
|
283
|
-
for file in final_files.values()
|
|
284
|
-
if file.file_path in depulicated_sources
|
|
304
|
+
(file.module_name, "") for file in source_code_list.sources
|
|
285
305
|
])
|
|
286
306
|
)
|
|
287
307
|
)
|
|
288
308
|
|
|
289
|
-
stats["final_files"] = len(
|
|
309
|
+
stats["final_files"] = len(source_code_list.sources)
|
|
290
310
|
phase_end = time.monotonic()
|
|
291
311
|
stats["timings"]["prepare_output"] = phase_end - phase_start
|
|
292
312
|
|
autocoder/index/index.py
CHANGED
|
@@ -12,6 +12,7 @@ from autocoder.index.symbols_utils import (
|
|
|
12
12
|
from autocoder.privacy.model_filter import ModelPathFilter
|
|
13
13
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
14
14
|
import threading
|
|
15
|
+
import re
|
|
15
16
|
|
|
16
17
|
import byzerllm
|
|
17
18
|
import hashlib
|
|
@@ -27,6 +28,8 @@ from autocoder.index.types import (
|
|
|
27
28
|
from autocoder.common.global_cancel import global_cancel
|
|
28
29
|
from autocoder.utils.llms import get_llm_names
|
|
29
30
|
from autocoder.rag.token_counter import count_tokens
|
|
31
|
+
|
|
32
|
+
|
|
30
33
|
class IndexManager:
|
|
31
34
|
def __init__(
|
|
32
35
|
self, llm: byzerllm.ByzerLLM, sources: List[SourceCode], args: AutoCoderArgs
|
|
@@ -52,12 +55,14 @@ class IndexManager:
|
|
|
52
55
|
self.index_filter_llm = llm
|
|
53
56
|
|
|
54
57
|
self.llm = llm
|
|
55
|
-
|
|
58
|
+
|
|
56
59
|
# Initialize model filters
|
|
57
60
|
if self.index_llm:
|
|
58
|
-
self.index_model_filter = ModelPathFilter.from_model_object(
|
|
61
|
+
self.index_model_filter = ModelPathFilter.from_model_object(
|
|
62
|
+
self.index_llm, args)
|
|
59
63
|
if self.index_filter_llm:
|
|
60
|
-
self.index_filter_model_filter = ModelPathFilter.from_model_object(
|
|
64
|
+
self.index_filter_model_filter = ModelPathFilter.from_model_object(
|
|
65
|
+
self.index_filter_llm, args)
|
|
61
66
|
self.args = args
|
|
62
67
|
self.max_input_length = (
|
|
63
68
|
args.index_model_max_input_length or args.model_max_input_length
|
|
@@ -68,7 +73,6 @@ class IndexManager:
|
|
|
68
73
|
if not os.path.exists(self.index_dir):
|
|
69
74
|
os.makedirs(self.index_dir)
|
|
70
75
|
|
|
71
|
-
|
|
72
76
|
@byzerllm.prompt()
|
|
73
77
|
def verify_file_relevance(self, file_content: str, query: str) -> str:
|
|
74
78
|
"""
|
|
@@ -201,12 +205,12 @@ class IndexManager:
|
|
|
201
205
|
if current_chunk:
|
|
202
206
|
chunks.append("\n".join(current_chunk))
|
|
203
207
|
return chunks
|
|
204
|
-
|
|
208
|
+
|
|
205
209
|
def should_skip(self, file_path: str):
|
|
206
210
|
ext = os.path.splitext(file_path)[1].lower()
|
|
207
211
|
if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:
|
|
208
212
|
return True
|
|
209
|
-
|
|
213
|
+
|
|
210
214
|
# Check model filter restrictions
|
|
211
215
|
if self.index_model_filter and not self.index_model_filter.is_accessible(file_path):
|
|
212
216
|
self.printer.print_in_terminal(
|
|
@@ -216,10 +220,10 @@ class IndexManager:
|
|
|
216
220
|
model_name=",".join(get_llm_names(self.index_llm))
|
|
217
221
|
)
|
|
218
222
|
return True
|
|
219
|
-
|
|
223
|
+
|
|
220
224
|
return False
|
|
221
225
|
|
|
222
|
-
def build_index_for_single_source(self, source: SourceCode):
|
|
226
|
+
def build_index_for_single_source(self, source: SourceCode):
|
|
223
227
|
if global_cancel.requested:
|
|
224
228
|
return None
|
|
225
229
|
|
|
@@ -251,13 +255,13 @@ class IndexManager:
|
|
|
251
255
|
|
|
252
256
|
start_time = time.monotonic()
|
|
253
257
|
source_code = source.source_code
|
|
254
|
-
|
|
258
|
+
|
|
255
259
|
# 统计token和成本
|
|
256
260
|
total_input_tokens = 0
|
|
257
261
|
total_output_tokens = 0
|
|
258
262
|
total_input_cost = 0.0
|
|
259
263
|
total_output_cost = 0.0
|
|
260
|
-
|
|
264
|
+
|
|
261
265
|
if count_tokens(source.source_code) > self.args.conversation_prune_safe_zone_tokens:
|
|
262
266
|
self.printer.print_in_terminal(
|
|
263
267
|
"index_file_too_large",
|
|
@@ -276,34 +280,40 @@ class IndexManager:
|
|
|
276
280
|
self.index_llm).with_meta(meta_holder).run(source.module_name, chunk)
|
|
277
281
|
time.sleep(self.anti_quota_limit)
|
|
278
282
|
symbols.append(chunk_symbols)
|
|
279
|
-
|
|
283
|
+
|
|
280
284
|
if meta_holder.get_meta():
|
|
281
285
|
meta_dict = meta_holder.get_meta()
|
|
282
|
-
total_input_tokens += meta_dict.get(
|
|
283
|
-
|
|
284
|
-
|
|
286
|
+
total_input_tokens += meta_dict.get(
|
|
287
|
+
"input_tokens_count", 0)
|
|
288
|
+
total_output_tokens += meta_dict.get(
|
|
289
|
+
"generated_tokens_count", 0)
|
|
290
|
+
|
|
285
291
|
symbols = "\n".join(symbols)
|
|
286
292
|
else:
|
|
287
293
|
meta_holder = byzerllm.MetaHolder()
|
|
288
294
|
symbols = self.get_all_file_symbols.with_llm(
|
|
289
295
|
self.index_llm).with_meta(meta_holder).run(source.module_name, source_code)
|
|
290
296
|
time.sleep(self.anti_quota_limit)
|
|
291
|
-
|
|
297
|
+
|
|
292
298
|
if meta_holder.get_meta():
|
|
293
299
|
meta_dict = meta_holder.get_meta()
|
|
294
|
-
total_input_tokens += meta_dict.get(
|
|
295
|
-
|
|
296
|
-
|
|
300
|
+
total_input_tokens += meta_dict.get(
|
|
301
|
+
"input_tokens_count", 0)
|
|
302
|
+
total_output_tokens += meta_dict.get(
|
|
303
|
+
"generated_tokens_count", 0)
|
|
304
|
+
|
|
297
305
|
# 计算总成本
|
|
298
306
|
for name in model_names:
|
|
299
307
|
info = model_info_map.get(name, {})
|
|
300
|
-
total_input_cost += (total_input_tokens *
|
|
301
|
-
|
|
302
|
-
|
|
308
|
+
total_input_cost += (total_input_tokens *
|
|
309
|
+
info.get("input_price", 0.0)) / 1000000
|
|
310
|
+
total_output_cost += (total_output_tokens *
|
|
311
|
+
info.get("output_price", 0.0)) / 1000000
|
|
312
|
+
|
|
303
313
|
# 四舍五入到4位小数
|
|
304
314
|
total_input_cost = round(total_input_cost, 4)
|
|
305
315
|
total_output_cost = round(total_output_cost, 4)
|
|
306
|
-
|
|
316
|
+
|
|
307
317
|
self.printer.print_in_terminal(
|
|
308
318
|
"index_update_success",
|
|
309
319
|
style="green",
|
|
@@ -340,9 +350,44 @@ class IndexManager:
|
|
|
340
350
|
"generated_tokens_cost": total_output_cost
|
|
341
351
|
}
|
|
342
352
|
|
|
353
|
+
def parse_exclude_files(self, exclude_files):
|
|
354
|
+
if not exclude_files:
|
|
355
|
+
return []
|
|
356
|
+
|
|
357
|
+
if isinstance(exclude_files, str):
|
|
358
|
+
exclude_files = [exclude_files]
|
|
359
|
+
|
|
360
|
+
exclude_patterns = []
|
|
361
|
+
for pattern in exclude_files:
|
|
362
|
+
if pattern.startswith("regex://"):
|
|
363
|
+
pattern = pattern[8:]
|
|
364
|
+
exclude_patterns.append(re.compile(pattern))
|
|
365
|
+
elif pattern.startswith("human://"):
|
|
366
|
+
pattern = pattern[8:]
|
|
367
|
+
v = (
|
|
368
|
+
self.generate_regex_pattern.with_llm(self.llm)
|
|
369
|
+
.with_extractor(self.extract_regex_pattern)
|
|
370
|
+
.run(desc=pattern)
|
|
371
|
+
)
|
|
372
|
+
if not v:
|
|
373
|
+
raise ValueError(
|
|
374
|
+
"Fail to generate regex pattern, try again.")
|
|
375
|
+
exclude_patterns.append(re.compile(v))
|
|
376
|
+
else:
|
|
377
|
+
raise ValueError(
|
|
378
|
+
"Invalid exclude_files format. Expected 'regex://<pattern>' or 'human://<description>' "
|
|
379
|
+
)
|
|
380
|
+
return exclude_patterns
|
|
381
|
+
|
|
382
|
+
def filter_exclude_files(self, file_path, exclude_patterns):
|
|
383
|
+
for pattern in exclude_patterns:
|
|
384
|
+
if pattern.search(file_path):
|
|
385
|
+
return True
|
|
386
|
+
return False
|
|
387
|
+
|
|
343
388
|
def build_index(self):
|
|
344
389
|
if os.path.exists(self.index_file):
|
|
345
|
-
with open(self.index_file, "r",encoding="utf-8") as file:
|
|
390
|
+
with open(self.index_file, "r", encoding="utf-8") as file:
|
|
346
391
|
index_data = json.load(file)
|
|
347
392
|
else:
|
|
348
393
|
index_data = {}
|
|
@@ -351,14 +396,27 @@ class IndexManager:
|
|
|
351
396
|
keys_to_remove = []
|
|
352
397
|
for file_path in index_data:
|
|
353
398
|
if not os.path.exists(file_path):
|
|
354
|
-
keys_to_remove.append(file_path)
|
|
355
|
-
|
|
399
|
+
keys_to_remove.append(file_path)
|
|
400
|
+
|
|
401
|
+
# 删除被排除的文件
|
|
402
|
+
try:
|
|
403
|
+
exclude_patterns = self.parse_exclude_files(self.args.exclude_files)
|
|
404
|
+
for file_path in index_data:
|
|
405
|
+
if self.filter_exclude_files(file_path, exclude_patterns):
|
|
406
|
+
keys_to_remove.append(file_path)
|
|
407
|
+
except Exception as e:
|
|
408
|
+
self.printer.print_in_terminal(
|
|
409
|
+
"index_exclude_files_error",
|
|
410
|
+
style="red",
|
|
411
|
+
error=str(e)
|
|
412
|
+
)
|
|
413
|
+
|
|
356
414
|
# 删除无效条目并记录日志
|
|
357
415
|
for key in set(keys_to_remove):
|
|
358
416
|
if key in index_data:
|
|
359
417
|
del index_data[key]
|
|
360
418
|
self.printer.print_in_terminal(
|
|
361
|
-
"index_file_removed",
|
|
419
|
+
"index_file_removed",
|
|
362
420
|
style="yellow",
|
|
363
421
|
file_path=key
|
|
364
422
|
)
|
|
@@ -388,7 +446,7 @@ class IndexManager:
|
|
|
388
446
|
for line in v:
|
|
389
447
|
new_v.append(line[line.find(":"):])
|
|
390
448
|
source_code = "\n".join(new_v)
|
|
391
|
-
|
|
449
|
+
|
|
392
450
|
md5 = hashlib.md5(source_code.encode("utf-8")).hexdigest()
|
|
393
451
|
if (
|
|
394
452
|
source.module_name not in index_data
|
|
@@ -397,7 +455,8 @@ class IndexManager:
|
|
|
397
455
|
wait_to_build_files.append(source)
|
|
398
456
|
|
|
399
457
|
# Remove duplicates based on module_name
|
|
400
|
-
wait_to_build_files = list(
|
|
458
|
+
wait_to_build_files = list(
|
|
459
|
+
{source.module_name: source for source in wait_to_build_files}.values())
|
|
401
460
|
|
|
402
461
|
counter = 0
|
|
403
462
|
num_files = len(wait_to_build_files)
|
|
@@ -433,16 +492,17 @@ class IndexManager:
|
|
|
433
492
|
index_data[module_name] = result
|
|
434
493
|
updated_sources.append(module_name)
|
|
435
494
|
if len(updated_sources) > 5:
|
|
436
|
-
with open(self.index_file, "w",encoding="utf-8") as file:
|
|
437
|
-
json.dump(index_data, file,
|
|
495
|
+
with open(self.index_file, "w", encoding="utf-8") as file:
|
|
496
|
+
json.dump(index_data, file,
|
|
497
|
+
ensure_ascii=False, indent=2)
|
|
438
498
|
updated_sources = []
|
|
439
|
-
|
|
499
|
+
|
|
440
500
|
# 如果 updated_sources 或 keys_to_remove 有值,则保存索引文件
|
|
441
501
|
if updated_sources or keys_to_remove:
|
|
442
|
-
with open(self.index_file, "w",encoding="utf-8") as file:
|
|
502
|
+
with open(self.index_file, "w", encoding="utf-8") as file:
|
|
443
503
|
json.dump(index_data, file, ensure_ascii=False, indent=2)
|
|
444
504
|
|
|
445
|
-
print("")
|
|
505
|
+
print("")
|
|
446
506
|
self.printer.print_in_terminal(
|
|
447
507
|
"index_file_saved",
|
|
448
508
|
style="green",
|
|
@@ -461,14 +521,14 @@ class IndexManager:
|
|
|
461
521
|
if not os.path.exists(self.index_file):
|
|
462
522
|
return []
|
|
463
523
|
|
|
464
|
-
with open(self.index_file, "r",encoding="utf-8") as file:
|
|
524
|
+
with open(self.index_file, "r", encoding="utf-8") as file:
|
|
465
525
|
return file.read()
|
|
466
526
|
|
|
467
527
|
def read_index(self) -> List[IndexItem]:
|
|
468
528
|
if not os.path.exists(self.index_file):
|
|
469
529
|
return []
|
|
470
530
|
|
|
471
|
-
with open(self.index_file, "r",encoding="utf-8") as file:
|
|
531
|
+
with open(self.index_file, "r", encoding="utf-8") as file:
|
|
472
532
|
index_data = json.load(file)
|
|
473
533
|
|
|
474
534
|
index_items = []
|
|
@@ -572,7 +632,7 @@ class IndexManager:
|
|
|
572
632
|
{file.file_path: file for file in all_results}.values())
|
|
573
633
|
return FileList(file_list=all_results)
|
|
574
634
|
|
|
575
|
-
def _query_index_with_thread(self, query, func):
|
|
635
|
+
def _query_index_with_thread(self, query, func):
|
|
576
636
|
all_results = []
|
|
577
637
|
lock = threading.Lock()
|
|
578
638
|
completed_threads = 0
|
|
@@ -582,7 +642,7 @@ class IndexManager:
|
|
|
582
642
|
nonlocal completed_threads
|
|
583
643
|
result = self._get_target_files_by_query.with_llm(
|
|
584
644
|
self.llm).with_return_type(FileList).run(chunk, query)
|
|
585
|
-
|
|
645
|
+
|
|
586
646
|
if result is not None:
|
|
587
647
|
with lock:
|
|
588
648
|
all_results.extend(result.file_list)
|
|
@@ -708,4 +768,3 @@ class IndexManager:
|
|
|
708
768
|
|
|
709
769
|
请确保结果的准确性和完整性,包括所有可能相关的文件。
|
|
710
770
|
"""
|
|
711
|
-
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Dict, List, Set, Tuple
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from loguru import logger
|
|
7
|
+
import byzerllm
|
|
8
|
+
from autocoder.common import AutoCoderArgs
|
|
9
|
+
from autocoder.common.printer import Printer
|
|
10
|
+
from typing import Union
|
|
11
|
+
import pydantic
|
|
12
|
+
from autocoder.common.result_manager import ResultManager
|
|
13
|
+
|
|
14
|
+
class ExtensionClassifyResult(pydantic.BaseModel):
|
|
15
|
+
code: List[str] = []
|
|
16
|
+
config: List[str] = []
|
|
17
|
+
data: List[str] = []
|
|
18
|
+
document: List[str] = []
|
|
19
|
+
other: List[str] = []
|
|
20
|
+
framework: List[str] = []
|
|
21
|
+
|
|
22
|
+
class ProjectTypeAnalyzer:
|
|
23
|
+
def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
|
|
24
|
+
self.args = args
|
|
25
|
+
self.llm = llm
|
|
26
|
+
self.printer = Printer()
|
|
27
|
+
self.default_exclude_dirs = [
|
|
28
|
+
".git", ".svn", ".hg", "build", "dist", "__pycache__",
|
|
29
|
+
"node_modules", ".auto-coder", ".vscode", ".idea", "venv",
|
|
30
|
+
".next", ".nuxt", ".svelte-kit", "out", "cache", "logs",
|
|
31
|
+
"temp", "tmp", "coverage", ".DS_Store", "public", "static"
|
|
32
|
+
]
|
|
33
|
+
self.extension_counts = defaultdict(int)
|
|
34
|
+
self.stats_file = Path(args.source_dir) / ".auto-coder" / "project_type_stats.json"
|
|
35
|
+
self.result_manager = ResultManager()
|
|
36
|
+
|
|
37
|
+
def traverse_project(self) -> None:
|
|
38
|
+
"""遍历项目目录,统计文件后缀"""
|
|
39
|
+
for root, dirs, files in os.walk(self.args.source_dir):
|
|
40
|
+
# 过滤掉默认排除的目录
|
|
41
|
+
dirs[:] = [d for d in dirs if d not in self.default_exclude_dirs]
|
|
42
|
+
|
|
43
|
+
for file in files:
|
|
44
|
+
_, ext = os.path.splitext(file)
|
|
45
|
+
if ext: # 只统计有后缀的文件
|
|
46
|
+
self.extension_counts[ext.lower()] += 1
|
|
47
|
+
|
|
48
|
+
def count_extensions(self) -> Dict[str, int]:
|
|
49
|
+
"""返回文件后缀统计结果"""
|
|
50
|
+
return dict(sorted(self.extension_counts.items(), key=lambda x: x[1], reverse=True))
|
|
51
|
+
|
|
52
|
+
@byzerllm.prompt()
|
|
53
|
+
def classify_extensions(self, extensions: str) -> str:
|
|
54
|
+
"""
|
|
55
|
+
根据文件后缀列表,将后缀分类为代码、配置、数据、文档等类型。
|
|
56
|
+
|
|
57
|
+
文件后缀列表:
|
|
58
|
+
{{ extensions }}
|
|
59
|
+
|
|
60
|
+
请返回如下JSON格式:
|
|
61
|
+
{
|
|
62
|
+
"code": ["后缀1", "后缀2"],
|
|
63
|
+
"config": ["后缀3", "后缀4"],
|
|
64
|
+
"data": ["后缀5", "后缀6"],
|
|
65
|
+
"document": ["后缀7", "后缀8"],
|
|
66
|
+
"other": ["后缀9", "后缀10"],
|
|
67
|
+
"framework": ["后缀11", "后缀12"]
|
|
68
|
+
}
|
|
69
|
+
"""
|
|
70
|
+
return {
|
|
71
|
+
"extensions": extensions
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
def save_stats(self) -> None:
|
|
75
|
+
"""保存统计结果到文件"""
|
|
76
|
+
stats = {
|
|
77
|
+
"extension_counts": self.extension_counts,
|
|
78
|
+
"project_type": self.detect_project_type()
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
# 确保目录存在
|
|
82
|
+
self.stats_file.parent.mkdir(parents=True, exist_ok=True)
|
|
83
|
+
|
|
84
|
+
with open(self.stats_file, "w", encoding="utf-8") as f:
|
|
85
|
+
json.dump(stats, f, indent=2)
|
|
86
|
+
|
|
87
|
+
self.printer.print_in_terminal("stats_saved", path=str(self.stats_file))
|
|
88
|
+
|
|
89
|
+
def load_stats(self) -> Dict[str, any]:
|
|
90
|
+
"""从文件加载统计结果"""
|
|
91
|
+
if not self.stats_file.exists():
|
|
92
|
+
self.printer.print_in_terminal("stats_not_found", path=str(self.stats_file))
|
|
93
|
+
return {}
|
|
94
|
+
|
|
95
|
+
with open(self.stats_file, "r", encoding="utf-8") as f:
|
|
96
|
+
return json.load(f)
|
|
97
|
+
|
|
98
|
+
def detect_project_type(self) -> str:
|
|
99
|
+
"""根据后缀统计结果推断项目类型"""
|
|
100
|
+
# 获取统计结果
|
|
101
|
+
ext_counts = self.count_extensions()
|
|
102
|
+
# 将后缀分类
|
|
103
|
+
classification = self.classify_extensions.with_llm(self.llm).with_return_type(ExtensionClassifyResult).run(json.dumps(ext_counts,ensure_ascii=False))
|
|
104
|
+
return ",".join(classification.code)
|
|
105
|
+
|
|
106
|
+
def analyze(self) -> Dict[str, any]:
|
|
107
|
+
"""执行完整的项目类型分析流程"""
|
|
108
|
+
# 遍历项目目录
|
|
109
|
+
self.traverse_project()
|
|
110
|
+
|
|
111
|
+
# 检测项目类型
|
|
112
|
+
project_type = self.detect_project_type()
|
|
113
|
+
|
|
114
|
+
self.result_manager.add_result(content=project_type, meta={
|
|
115
|
+
"action": "get_project_type",
|
|
116
|
+
"input": {
|
|
117
|
+
|
|
118
|
+
}
|
|
119
|
+
})
|
|
120
|
+
return project_type
|
|
@@ -25,8 +25,8 @@ PROVIDER_INFO_LIST = [
|
|
|
25
25
|
ProviderInfo(
|
|
26
26
|
name="volcano",
|
|
27
27
|
endpoint="https://ark.cn-beijing.volces.com/api/v3",
|
|
28
|
-
r1_model="",
|
|
29
|
-
v3_model="",
|
|
28
|
+
r1_model="deepseek-r1-250120",
|
|
29
|
+
v3_model="deepseek-v3-241226",
|
|
30
30
|
api_key="",
|
|
31
31
|
r1_input_price=2.0,
|
|
32
32
|
r1_output_price=8.0,
|
|
@@ -162,32 +162,32 @@ class ModelProviderSelector:
|
|
|
162
162
|
provider_info = provider
|
|
163
163
|
break
|
|
164
164
|
|
|
165
|
-
if result == "volcano":
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
165
|
+
# if result == "volcano":
|
|
166
|
+
# # Get R1 endpoint
|
|
167
|
+
# r1_endpoint = input_dialog(
|
|
168
|
+
# title=self.printer.get_message_from_key("model_provider_api_key_title"),
|
|
169
|
+
# text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
|
|
170
|
+
# validator=VolcanoEndpointValidator(),
|
|
171
|
+
# style=dialog_style
|
|
172
|
+
# ).run()
|
|
173
173
|
|
|
174
|
-
|
|
175
|
-
|
|
174
|
+
# if r1_endpoint is None:
|
|
175
|
+
# return None
|
|
176
176
|
|
|
177
|
-
|
|
177
|
+
# provider_info.r1_model = r1_endpoint
|
|
178
178
|
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
179
|
+
# # Get V3 endpoint
|
|
180
|
+
# v3_endpoint = input_dialog(
|
|
181
|
+
# title=self.printer.get_message_from_key("model_provider_api_key_title"),
|
|
182
|
+
# text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
|
|
183
|
+
# validator=VolcanoEndpointValidator(),
|
|
184
|
+
# style=dialog_style
|
|
185
|
+
# ).run()
|
|
186
186
|
|
|
187
|
-
|
|
188
|
-
|
|
187
|
+
# if v3_endpoint is None:
|
|
188
|
+
# return None
|
|
189
189
|
|
|
190
|
-
|
|
190
|
+
# provider_info.v3_model = v3_endpoint
|
|
191
191
|
|
|
192
192
|
# Get API key for all providers
|
|
193
193
|
api_key = input_dialog(
|
autocoder/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.1.
|
|
1
|
+
__version__ = "0.1.269"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|