auto-coder 0.1.268__py3-none-any.whl → 0.1.269__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.269.dist-info}/METADATA +2 -2
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.269.dist-info}/RECORD +19 -17
- autocoder/auto_coder_runner.py +2635 -0
- autocoder/chat_auto_coder.py +54 -2630
- autocoder/commands/auto_command.py +23 -33
- autocoder/common/__init__.py +6 -2
- autocoder/common/auto_coder_lang.py +5 -1
- autocoder/common/auto_configure.py +41 -30
- autocoder/common/command_templates.py +2 -3
- autocoder/common/context_pruner.py +185 -12
- autocoder/common/conversation_pruner.py +11 -10
- autocoder/index/entry.py +42 -22
- autocoder/utils/auto_project_type.py +120 -0
- autocoder/utils/model_provider_selector.py +23 -23
- autocoder/version.py +1 -1
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.269.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.269.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.269.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.268.dist-info → auto_coder-0.1.269.dist-info}/top_level.txt +0 -0
autocoder/index/entry.py
CHANGED
@@ -58,8 +58,12 @@ def build_index_and_filter_files(
             return file_path.strip()[2:]
         return file_path

+    # file path -> TargetFile
     final_files: Dict[str, TargetFile] = {}

+    # file path -> position of the file in the file list (earlier means more relevant)
+    file_positions:Dict[str,int] = {}
+
     # Phase 1: Process REST/RAG/Search sources
     printer = Printer()
     printer.print_in_terminal("phase1_processing_sources")
@@ -102,25 +106,20 @@ def build_index_and_filter_files(
             })
         )
     )
-
+
+
     if not args.skip_filter_index and args.index_filter_model:
         model_name = getattr(index_manager.index_filter_llm, 'default_model_name', None)
        if not model_name:
            model_name = "unknown(without default model name)"
        printer.print_in_terminal("quick_filter_start", style="blue", model_name=model_name)
        quick_filter = QuickFilter(index_manager,stats,sources)
-        quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
-
-
-
-
-
-        context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
-        pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
-        for source_file in pruned_files:
-            final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
-        else:
-            final_files.update(quick_filter_result.files)
+        quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
+
+        final_files.update(quick_filter_result.files)
+
+        if quick_filter_result.file_positions:
+            file_positions.update(quick_filter_result.file_positions)

    if not args.skip_filter_index and not args.index_filter_model:
        model_name = getattr(index_manager.llm, 'default_model_name', None)
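
The rewritten branch above no longer prunes inside the quick-filter path; it simply records the filter output and, when present, the per-file positions. A minimal sketch of that data flow, where QuickFilterResult and TargetFile are assumed stand-ins mirroring only the attributes the diff actually touches:

# Hypothetical illustration of the new data flow (class shapes are assumptions).
from typing import Dict, Optional
import pydantic

class TargetFile(pydantic.BaseModel):          # stand-in for autocoder's TargetFile
    file_path: str
    reason: str = ""

class QuickFilterResult(pydantic.BaseModel):   # assumed shape, matching attribute usage in the diff
    files: Dict[str, TargetFile] = {}
    file_positions: Optional[Dict[str, int]] = None

final_files: Dict[str, TargetFile] = {}
file_positions: Dict[str, int] = {}

result = QuickFilterResult(
    files={"src/a.py": TargetFile(file_path="src/a.py", reason="matches query")},
    file_positions={"src/a.py": 0},
)
final_files.update(result.files)               # keep every file the quick filter selected
if result.file_positions:                      # remember relevance order for later pruning
    file_positions.update(result.file_positions)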
@@ -261,32 +260,53 @@ def build_index_and_filter_files(
     for file in final_filenames:
         print(f"{file} - {final_files[file].reason}")

-    source_code = ""
+    # source_code = ""
     source_code_list = SourceCodeList(sources=[])
     depulicated_sources = set()
-
+
+    ## de-duplicate first
+    temp_sources = []
     for file in sources:
         if file.module_name in final_filenames:
             if file.module_name in depulicated_sources:
                 continue
             depulicated_sources.add(file.module_name)
-            source_code += f"##File: {file.module_name}\n"
-            source_code += f"{file.source_code}\n\n"
-
+            # source_code += f"##File: {file.module_name}\n"
+            # source_code += f"{file.source_code}\n\n"
+            temp_sources.append(file)
+
+    ## if pruning is enabled, prune here; for now this only takes effect for the quick filter
+    if args.context_prune:
+        context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
+        # if file_positions is not empty, use file_positions to pick the files
+        if file_positions:
+            ## get the position list and sort by position to obtain [(pos,file_path)]
+            ## convert [(pos,file_path)] into [file_path]
+            ## use the [file_path] order to reorder temp_sources
+            ## MARK
+            # convert file_positions into a list of [(pos, file_path)]
+            position_file_pairs = [(pos, file_path) for file_path, pos in file_positions.items()]
+            # sort by position
+            position_file_pairs.sort(key=lambda x: x[0])
+            # extract the sorted list of file paths
+            sorted_file_paths = [file_path for _, file_path in position_file_pairs]
+            # reorder temp_sources according to sorted_file_paths
+            temp_sources.sort(key=lambda x: sorted_file_paths.index(x.module_name) if x.module_name in sorted_file_paths else len(sorted_file_paths))
+        pruned_files = context_pruner.handle_overflow([source.module_name for source in temp_sources], [{"role":"user","content":args.query}], args.context_prune_strategy)
+        source_code_list.sources = pruned_files
+
     if args.request_id and not args.skip_events:
         queue_communicate.send_event(
             request_id=args.request_id,
             event=CommunicateEvent(
                 event_type=CommunicateEventType.CODE_INDEX_FILTER_FILE_SELECTED.value,
                 data=json.dumps([
-                    (file
-                    for file in final_files.values()
-                    if file.file_path in depulicated_sources
+                    (file.module_name, "") for file in source_code_list.sources
                 ])
             )
         )

-    stats["final_files"] = len(
+    stats["final_files"] = len(source_code_list.sources)
     phase_end = time.monotonic()
     stats["timings"]["prepare_output"] = phase_end - phase_start

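
The ordering logic added above can be exercised in isolation. A small self-contained sketch with dummy data; SourceCode is a stand-in for autocoder's own source object, and the real PruneContext.handle_overflow call is only referenced in a comment:

# Standalone sketch of the position-based reordering introduced in this hunk.
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class SourceCode:                      # stand-in; only module_name matters for the ordering
    module_name: str

file_positions: Dict[str, int] = {"b.py": 1, "a.py": 0, "c.py": 2}
temp_sources: List[SourceCode] = [SourceCode("c.py"), SourceCode("a.py"), SourceCode("d.py"), SourceCode("b.py")]

# convert file_positions into [(pos, file_path)] and sort by position
position_file_pairs = [(pos, file_path) for file_path, pos in file_positions.items()]
position_file_pairs.sort(key=lambda x: x[0])
sorted_file_paths = [file_path for _, file_path in position_file_pairs]

# files with a known position come first (most relevant first); unknown files go to the end
temp_sources.sort(
    key=lambda s: sorted_file_paths.index(s.module_name)
    if s.module_name in sorted_file_paths
    else len(sorted_file_paths)
)
print([s.module_name for s in temp_sources])   # ['a.py', 'b.py', 'c.py', 'd.py']
# entry.py then hands this ordered list to PruneContext.handle_overflow(...)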

autocoder/utils/auto_project_type.py
ADDED
@@ -0,0 +1,120 @@
+import os
+import json
+from collections import defaultdict
+from typing import Dict, List, Set, Tuple
+from pathlib import Path
+from loguru import logger
+import byzerllm
+from autocoder.common import AutoCoderArgs
+from autocoder.common.printer import Printer
+from typing import Union
+import pydantic
+from autocoder.common.result_manager import ResultManager
+
+class ExtensionClassifyResult(pydantic.BaseModel):
+    code: List[str] = []
+    config: List[str] = []
+    data: List[str] = []
+    document: List[str] = []
+    other: List[str] = []
+    framework: List[str] = []
+
+class ProjectTypeAnalyzer:
+    def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
+        self.args = args
+        self.llm = llm
+        self.printer = Printer()
+        self.default_exclude_dirs = [
+            ".git", ".svn", ".hg", "build", "dist", "__pycache__",
+            "node_modules", ".auto-coder", ".vscode", ".idea", "venv",
+            ".next", ".nuxt", ".svelte-kit", "out", "cache", "logs",
+            "temp", "tmp", "coverage", ".DS_Store", "public", "static"
+        ]
+        self.extension_counts = defaultdict(int)
+        self.stats_file = Path(args.source_dir) / ".auto-coder" / "project_type_stats.json"
+        self.result_manager = ResultManager()
+
+    def traverse_project(self) -> None:
+        """Walk the project directory and count file extensions."""
+        for root, dirs, files in os.walk(self.args.source_dir):
+            # skip the default excluded directories
+            dirs[:] = [d for d in dirs if d not in self.default_exclude_dirs]
+
+            for file in files:
+                _, ext = os.path.splitext(file)
+                if ext:  # only count files that have an extension
+                    self.extension_counts[ext.lower()] += 1
+
+    def count_extensions(self) -> Dict[str, int]:
+        """Return the extension count statistics."""
+        return dict(sorted(self.extension_counts.items(), key=lambda x: x[1], reverse=True))
+
+    @byzerllm.prompt()
+    def classify_extensions(self, extensions: str) -> str:
+        """
+        Given the list of file extensions, classify them into code, config, data, document, etc.
+
+        File extension list:
+        {{ extensions }}
+
+        Return JSON in the following format:
+        {
+            "code": ["ext1", "ext2"],
+            "config": ["ext3", "ext4"],
+            "data": ["ext5", "ext6"],
+            "document": ["ext7", "ext8"],
+            "other": ["ext9", "ext10"],
+            "framework": ["ext11", "ext12"]
+        }
+        """
+        return {
+            "extensions": extensions
+        }
+
+    def save_stats(self) -> None:
+        """Save the statistics to a file."""
+        stats = {
+            "extension_counts": self.extension_counts,
+            "project_type": self.detect_project_type()
+        }
+
+        # make sure the directory exists
+        self.stats_file.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(self.stats_file, "w", encoding="utf-8") as f:
+            json.dump(stats, f, indent=2)
+
+        self.printer.print_in_terminal("stats_saved", path=str(self.stats_file))
+
+    def load_stats(self) -> Dict[str, any]:
+        """Load the statistics from file."""
+        if not self.stats_file.exists():
+            self.printer.print_in_terminal("stats_not_found", path=str(self.stats_file))
+            return {}
+
+        with open(self.stats_file, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    def detect_project_type(self) -> str:
+        """Infer the project type from the extension statistics."""
+        # collect the statistics
+        ext_counts = self.count_extensions()
+        # classify the extensions
+        classification = self.classify_extensions.with_llm(self.llm).with_return_type(ExtensionClassifyResult).run(json.dumps(ext_counts,ensure_ascii=False))
+        return ",".join(classification.code)
+
+    def analyze(self) -> Dict[str, any]:
+        """Run the full project type analysis flow."""
+        # walk the project directory
+        self.traverse_project()
+
+        # detect the project type
+        project_type = self.detect_project_type()
+
+        self.result_manager.add_result(content=project_type, meta={
+            "action": "get_project_type",
+            "input": {
+
+            }
+        })
+        return project_type
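
The core of the new analyzer is an extension census over the project tree; the LLM step only classifies the extensions that were collected. A standalone sketch of that census (no LLM involved; the exclude set is a shortened sample of the default list above):

# Standalone sketch of the extension census performed by ProjectTypeAnalyzer.traverse_project().
import os
from collections import defaultdict
from typing import Dict

EXCLUDE_DIRS = {".git", "node_modules", "__pycache__", "build", "dist", ".auto-coder"}

def count_extensions(source_dir: str) -> Dict[str, int]:
    counts: Dict[str, int] = defaultdict(int)
    for root, dirs, files in os.walk(source_dir):
        # prune excluded directories in place so os.walk never descends into them
        dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
        for name in files:
            _, ext = os.path.splitext(name)
            if ext:                              # only files that actually have an extension
                counts[ext.lower()] += 1
    # most frequent extension first, mirroring count_extensions() above
    return dict(sorted(counts.items(), key=lambda kv: kv[1], reverse=True))

if __name__ == "__main__":
    print(count_extensions("."))                 # e.g. {'.py': 120, '.md': 8, '.toml': 2}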

autocoder/utils/model_provider_selector.py
CHANGED
@@ -25,8 +25,8 @@ PROVIDER_INFO_LIST = [
     ProviderInfo(
         name="volcano",
         endpoint="https://ark.cn-beijing.volces.com/api/v3",
-        r1_model="",
-        v3_model="",
+        r1_model="deepseek-r1-250120",
+        v3_model="deepseek-v3-241226",
         api_key="",
         r1_input_price=2.0,
         r1_output_price=8.0,
@@ -162,32 +162,32 @@ class ModelProviderSelector:
                 provider_info = provider
                 break

-        if result == "volcano":
-
-
-
-
-
-
-
+        # if result == "volcano":
+        # # Get R1 endpoint
+        # r1_endpoint = input_dialog(
+        # title=self.printer.get_message_from_key("model_provider_api_key_title"),
+        # text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
+        # validator=VolcanoEndpointValidator(),
+        # style=dialog_style
+        # ).run()

-
-
+        # if r1_endpoint is None:
+        # return None

-
+        # provider_info.r1_model = r1_endpoint

-
-
-
-
-
-
-
+        # # Get V3 endpoint
+        # v3_endpoint = input_dialog(
+        # title=self.printer.get_message_from_key("model_provider_api_key_title"),
+        # text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
+        # validator=VolcanoEndpointValidator(),
+        # style=dialog_style
+        # ).run()

-
-
+        # if v3_endpoint is None:
+        # return None

-
+        # provider_info.v3_model = v3_endpoint

         # Get API key for all providers
         api_key = input_dialog(
autocoder/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.268"
+__version__ = "0.1.269"