auto-coder 0.1.268__py3-none-any.whl → 0.1.269__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

autocoder/index/entry.py CHANGED
@@ -58,8 +58,12 @@ def build_index_and_filter_files(
58
58
  return file_path.strip()[2:]
59
59
  return file_path
60
60
 
61
+ # 文件路径 -> TargetFile
61
62
  final_files: Dict[str, TargetFile] = {}
62
63
 
64
+ # 文件路径 -> 文件在文件列表中的位置(越前面表示越相关)
65
+ file_positions:Dict[str,int] = {}
66
+
63
67
  # Phase 1: Process REST/RAG/Search sources
64
68
  printer = Printer()
65
69
  printer.print_in_terminal("phase1_processing_sources")
@@ -102,25 +106,20 @@ def build_index_and_filter_files(
102
106
  })
103
107
  )
104
108
  )
105
-
109
+
110
+
106
111
  if not args.skip_filter_index and args.index_filter_model:
107
112
  model_name = getattr(index_manager.index_filter_llm, 'default_model_name', None)
108
113
  if not model_name:
109
114
  model_name = "unknown(without default model name)"
110
115
  printer.print_in_terminal("quick_filter_start", style="blue", model_name=model_name)
111
116
  quick_filter = QuickFilter(index_manager,stats,sources)
112
- quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
113
- # if quick_filter_result.has_error:
114
- # raise KeyboardInterrupt(printer.get_message_from_key_with_format("quick_filter_failed",error=quick_filter_result.error_message))
115
-
116
- # Merge quick filter results into final_files
117
- if args.context_prune:
118
- context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
119
- pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
120
- for source_file in pruned_files:
121
- final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
122
- else:
123
- final_files.update(quick_filter_result.files)
117
+ quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
118
+
119
+ final_files.update(quick_filter_result.files)
120
+
121
+ if quick_filter_result.file_positions:
122
+ file_positions.update(quick_filter_result.file_positions)
124
123
 
125
124
  if not args.skip_filter_index and not args.index_filter_model:
126
125
  model_name = getattr(index_manager.llm, 'default_model_name', None)
@@ -261,32 +260,53 @@ def build_index_and_filter_files(
261
260
  for file in final_filenames:
262
261
  print(f"{file} - {final_files[file].reason}")
263
262
 
264
- source_code = ""
263
+ # source_code = ""
265
264
  source_code_list = SourceCodeList(sources=[])
266
265
  depulicated_sources = set()
267
-
266
+
267
+ ## 先去重
268
+ temp_sources = []
268
269
  for file in sources:
269
270
  if file.module_name in final_filenames:
270
271
  if file.module_name in depulicated_sources:
271
272
  continue
272
273
  depulicated_sources.add(file.module_name)
273
- source_code += f"##File: {file.module_name}\n"
274
- source_code += f"{file.source_code}\n\n"
275
- source_code_list.sources.append(file)
274
+ # source_code += f"##File: {file.module_name}\n"
275
+ # source_code += f"{file.source_code}\n\n"
276
+ temp_sources.append(file)
277
+
278
+ ## 开启了裁剪,则需要做裁剪,不过目前只针对 quick filter 生效
279
+ if args.context_prune:
280
+ context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
281
+ # 如果 file_positions 不为空,则通过 file_positions 来获取文件
282
+ if file_positions:
283
+ ## 拿到位置列表,然后根据位置排序 得到 [(pos,file_path)]
284
+ ## 将 [(pos,file_path)] 转换为 [file_path]
285
+ ## 通过 [file_path] 顺序调整 temp_sources 的顺序
286
+ ## MARK
287
+ # 将 file_positions 转换为 [(pos, file_path)] 的列表
288
+ position_file_pairs = [(pos, file_path) for file_path, pos in file_positions.items()]
289
+ # 按位置排序
290
+ position_file_pairs.sort(key=lambda x: x[0])
291
+ # 提取排序后的文件路径列表
292
+ sorted_file_paths = [file_path for _, file_path in position_file_pairs]
293
+ # 根据 sorted_file_paths 重新排序 temp_sources
294
+ temp_sources.sort(key=lambda x: sorted_file_paths.index(x.module_name) if x.module_name in sorted_file_paths else len(sorted_file_paths))
295
+ pruned_files = context_pruner.handle_overflow([source.module_name for source in temp_sources], [{"role":"user","content":args.query}], args.context_prune_strategy)
296
+ source_code_list.sources = pruned_files
297
+
276
298
  if args.request_id and not args.skip_events:
277
299
  queue_communicate.send_event(
278
300
  request_id=args.request_id,
279
301
  event=CommunicateEvent(
280
302
  event_type=CommunicateEventType.CODE_INDEX_FILTER_FILE_SELECTED.value,
281
303
  data=json.dumps([
282
- (file["file_path"], file.reason)
283
- for file in final_files.values()
284
- if file.file_path in depulicated_sources
304
+ (file.module_name, "") for file in source_code_list.sources
285
305
  ])
286
306
  )
287
307
  )
288
308
 
289
- stats["final_files"] = len(depulicated_sources)
309
+ stats["final_files"] = len(source_code_list.sources)
290
310
  phase_end = time.monotonic()
291
311
  stats["timings"]["prepare_output"] = phase_end - phase_start
292
312
 
@@ -0,0 +1,120 @@
1
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple, Union

import byzerllm
import pydantic
from loguru import logger

from autocoder.common import AutoCoderArgs
from autocoder.common.printer import Printer
from autocoder.common.result_manager import ResultManager
13
+
14
class ExtensionClassifyResult(pydantic.BaseModel):
    """Classification of file extensions into broad project-role categories.

    Each field holds the list of extension strings (e.g. ``".py"``) that the
    LLM-based classifier assigned to that category.
    """

    code: List[str] = pydantic.Field(default_factory=list)
    config: List[str] = pydantic.Field(default_factory=list)
    data: List[str] = pydantic.Field(default_factory=list)
    document: List[str] = pydantic.Field(default_factory=list)
    other: List[str] = pydantic.Field(default_factory=list)
    framework: List[str] = pydantic.Field(default_factory=list)
21
+
22
class ProjectTypeAnalyzer:
    """Infer a project's type from file-extension statistics.

    Walks ``args.source_dir`` counting file extensions (skipping common
    VCS / build / dependency directories), asks the LLM to classify the
    extensions, and reports the "code" extensions as the project type.
    """

    def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
        self.args = args
        self.llm = llm
        self.printer = Printer()
        # Directories that are never part of the project's own sources.
        self.default_exclude_dirs = [
            ".git", ".svn", ".hg", "build", "dist", "__pycache__",
            "node_modules", ".auto-coder", ".vscode", ".idea", "venv",
            ".next", ".nuxt", ".svelte-kit", "out", "cache", "logs",
            "temp", "tmp", "coverage", ".DS_Store", "public", "static"
        ]
        # Lower-cased extension (with leading dot) -> occurrence count.
        self.extension_counts: Dict[str, int] = defaultdict(int)
        self.stats_file = Path(args.source_dir) / ".auto-coder" / "project_type_stats.json"
        self.result_manager = ResultManager()

    def traverse_project(self) -> None:
        """Walk the project directory tree and tally file extensions."""
        for root, dirs, files in os.walk(self.args.source_dir):
            # Prune excluded directories in place so os.walk skips them.
            dirs[:] = [d for d in dirs if d not in self.default_exclude_dirs]

            for file in files:
                _, ext = os.path.splitext(file)
                if ext:  # only count files that actually have an extension
                    self.extension_counts[ext.lower()] += 1

    def count_extensions(self) -> Dict[str, int]:
        """Return extension counts sorted by frequency, most common first."""
        return dict(sorted(self.extension_counts.items(), key=lambda x: x[1], reverse=True))

    @byzerllm.prompt()
    def classify_extensions(self, extensions: str) -> str:
        """
        根据文件后缀列表,将后缀分类为代码、配置、数据、文档等类型。

        文件后缀列表:
        {{ extensions }}

        请返回如下JSON格式:
        {
            "code": ["后缀1", "后缀2"],
            "config": ["后缀3", "后缀4"],
            "data": ["后缀5", "后缀6"],
            "document": ["后缀7", "后缀8"],
            "other": ["后缀9", "后缀10"],
            "framework": ["后缀11", "后缀12"]
        }
        """
        # NOTE: the docstring above is the runtime LLM prompt template;
        # this dict supplies its template variables.
        return {
            "extensions": extensions
        }

    def save_stats(self) -> None:
        """Persist the extension counts and detected project type to disk."""
        stats = {
            "extension_counts": self.extension_counts,
            "project_type": self.detect_project_type()
        }

        # Make sure the .auto-coder directory exists before writing.
        self.stats_file.parent.mkdir(parents=True, exist_ok=True)

        # ensure_ascii=False keeps any non-ASCII data readable on disk,
        # matching the serialization used in detect_project_type().
        with open(self.stats_file, "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        self.printer.print_in_terminal("stats_saved", path=str(self.stats_file))

    def load_stats(self) -> Dict[str, Any]:
        """Load previously saved stats; return an empty dict when none exist."""
        if not self.stats_file.exists():
            self.printer.print_in_terminal("stats_not_found", path=str(self.stats_file))
            return {}

        with open(self.stats_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def detect_project_type(self) -> str:
        """Classify the counted extensions via the LLM.

        Returns a comma-separated string of code-file extensions
        (e.g. ``".py,.ts"``).
        """
        ext_counts = self.count_extensions()
        classification = self.classify_extensions.with_llm(self.llm).with_return_type(
            ExtensionClassifyResult).run(json.dumps(ext_counts, ensure_ascii=False))
        return ",".join(classification.code)

    def analyze(self) -> str:
        """Run the full analysis: traverse, detect, record, and return the type.

        Returns the comma-separated project-type string produced by
        detect_project_type(); the original ``Dict[str, any]`` annotation
        was wrong (lowercase ``any`` is the builtin function, and the value
        returned is a str).
        """
        self.traverse_project()

        project_type = self.detect_project_type()

        # Record the outcome so other components can retrieve it later.
        self.result_manager.add_result(content=project_type, meta={
            "action": "get_project_type",
            "input": {}
        })
        return project_type
@@ -25,8 +25,8 @@ PROVIDER_INFO_LIST = [
25
25
  ProviderInfo(
26
26
  name="volcano",
27
27
  endpoint="https://ark.cn-beijing.volces.com/api/v3",
28
- r1_model="",
29
- v3_model="",
28
+ r1_model="deepseek-r1-250120",
29
+ v3_model="deepseek-v3-241226",
30
30
  api_key="",
31
31
  r1_input_price=2.0,
32
32
  r1_output_price=8.0,
@@ -162,32 +162,32 @@ class ModelProviderSelector:
162
162
  provider_info = provider
163
163
  break
164
164
 
165
- if result == "volcano":
166
- # Get R1 endpoint
167
- r1_endpoint = input_dialog(
168
- title=self.printer.get_message_from_key("model_provider_api_key_title"),
169
- text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
170
- validator=VolcanoEndpointValidator(),
171
- style=dialog_style
172
- ).run()
165
+ # if result == "volcano":
166
+ # # Get R1 endpoint
167
+ # r1_endpoint = input_dialog(
168
+ # title=self.printer.get_message_from_key("model_provider_api_key_title"),
169
+ # text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
170
+ # validator=VolcanoEndpointValidator(),
171
+ # style=dialog_style
172
+ # ).run()
173
173
 
174
- if r1_endpoint is None:
175
- return None
174
+ # if r1_endpoint is None:
175
+ # return None
176
176
 
177
- provider_info.r1_model = r1_endpoint
177
+ # provider_info.r1_model = r1_endpoint
178
178
 
179
- # Get V3 endpoint
180
- v3_endpoint = input_dialog(
181
- title=self.printer.get_message_from_key("model_provider_api_key_title"),
182
- text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
183
- validator=VolcanoEndpointValidator(),
184
- style=dialog_style
185
- ).run()
179
+ # # Get V3 endpoint
180
+ # v3_endpoint = input_dialog(
181
+ # title=self.printer.get_message_from_key("model_provider_api_key_title"),
182
+ # text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
183
+ # validator=VolcanoEndpointValidator(),
184
+ # style=dialog_style
185
+ # ).run()
186
186
 
187
- if v3_endpoint is None:
188
- return None
187
+ # if v3_endpoint is None:
188
+ # return None
189
189
 
190
- provider_info.v3_model = v3_endpoint
190
+ # provider_info.v3_model = v3_endpoint
191
191
 
192
192
  # Get API key for all providers
193
193
  api_key = input_dialog(
autocoder/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.268"
1
+ __version__ = "0.1.269"