auto-coder 0.1.263__py3-none-any.whl → 0.1.264__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -20,6 +20,8 @@ from byzerllm import MetaHolder
 
 from autocoder.utils.llms import get_llm_names, get_model_info
 from loguru import logger
+from byzerllm.utils.client.code_utils import extract_code
+import json
 
 
 def get_file_path(file_path):
@@ -32,6 +34,15 @@ class QuickFilterResult(BaseModel):
     files: Dict[str, TargetFile]
     has_error: bool
     error_message: Optional[str] = None
+    file_positions: Optional[Dict[str, int]] = {}
+
+    def get_sorted_file_positions(self) -> List[str]:
+        """
+        Return the file paths sorted by their position value
+        """
+        if not self.file_positions:
+            return []
+        return [file_path for file_path, _ in sorted(self.file_positions.items(), key=lambda x: x[1])]
 
 
 class QuickFilter():
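
The new file_positions field records the rank the filter assigned to each selected file, and get_sorted_file_positions recovers that ranking. A minimal standalone sketch of the behavior, with a simplified model and hypothetical paths:

```python
from typing import Dict, List, Optional
from pydantic import BaseModel

class Result(BaseModel):
    # Simplified stand-in for QuickFilterResult (other fields omitted)
    file_positions: Optional[Dict[str, int]] = {}

    def get_sorted_file_positions(self) -> List[str]:
        if not self.file_positions:
            return []
        # Sort entries by their position value, keep only the paths
        return [p for p, _ in sorted(self.file_positions.items(), key=lambda x: x[1])]

r = Result(file_positions={"b.py": 2, "a.py": 0, "c.py": 1})
assert r.get_sorted_file_positions() == ["a.py", "c.py", "b.py"]
```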
@@ -82,6 +93,7 @@ class QuickFilter():
             self.index_manager.index_filter_llm)
         model_name = ",".join(model_names)
         files: Dict[str, TargetFile] = {}
+        file_positions: Dict[str, int] = {}
 
         # Get the model pricing info
         model_info_map = {}
@@ -166,7 +178,7 @@ class QuickFilter():
                 )
 
             if file_number_list:
-                for file_number in file_number_list.file_list:
+                for index, file_number in enumerate(file_number_list.file_list):
                     if file_number < 0 or file_number >= len(chunk):
                         self.printer.print_in_terminal(
                             "invalid_file_number",
@@ -182,9 +194,11 @@ class QuickFilter():
                         reason=self.printer.get_message_from_key(
                             "quick_filter_reason")
                     )
+                    file_positions[file_path] = index
             return QuickFilterResult(
                 files=files,
-                has_error=False
+                has_error=False,
+                file_positions=file_positions
             )
 
         except Exception as e:
@@ -212,6 +226,7 @@ class QuickFilter():
 
         # Merge all results
         final_files: Dict[str, TargetFile] = {}
+        final_file_positions: Dict[str, int] = {}
         has_error = False
         error_messages: List[str] = []
 
@@ -222,16 +237,57 @@ class QuickFilter():
                 error_messages.append(result.error_message)
             final_files.update(result.files)
 
+
+        for result in results:
+            if result.has_error:
+                has_error = True
+                if result.error_message:
+                    error_messages.append(result.error_message)
+        ## Interleave the file_positions of the individual results
+        # e.g. the first is {file_path_1_0: 0, file_path_1_1: 1, file_path_1_2: 2}
+        # the second is {file_path_2_0: 0, file_path_2_1: 1}
+        # the third is {file_path_3_0: 0, file_path_3_1: 1, file_path_3_2: 2, file_path_3_3: 3}
+        # Grouping rule: all position-0 files form one group, renumbered 0,1,2;
+        # all position-1 files form the next group, renumbered 3,4,5; and so on.
+        # {file_path_1_0: 0, file_path_2_0: 1, file_path_3_0: 2, file_path_1_1: 3, file_path_2_1: 4, file_path_3_1: 5}
+        #
+        # Get the maximum position value across all results
+        max_position = max([max(pos.values()) for pos in [result.file_positions for result in results if result.file_positions]] + [0])
+
+        # Build a map from each position to the file paths that hold it
+        position_map = {}
+        for result in results:
+            if result.file_positions:
+                for file_path, position in result.file_positions.items():
+                    if position not in position_map:
+                        position_map[position] = []
+                    position_map[position].append(file_path)
+
+        # Renumber the file paths
+        new_file_positions = {}
+        current_index = 0
+        for position in range(max_position + 1):
+            if position in position_map:
+                for file_path in position_map[position]:
+                    new_file_positions[file_path] = current_index
+                    current_index += 1
+
+        # Update final_file_positions
+        final_file_positions.update(new_file_positions)
+
         return QuickFilterResult(
             files=final_files,
             has_error=has_error,
             error_message="\n".join(error_messages) if error_messages else None
         )
+
 
     @byzerllm.prompt()
     def quick_filter_files(self, file_meta_list: List[IndexItem], query: str) -> str:
         '''
-        When the user raises a requirement, we need to find the relevant files, read them, and modify some of them.
+        When the user raises a requirement, we need to find two kinds of source files:
+        1. Files that must be modified to satisfy the requirement, which we call edited_files
+        2. Extra files needed as references to complete those modifications, which we call reference_files
+
         Now, given the index files below:
 
         <index>
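
The merge loop above interleaves the per-chunk rankings round-robin style: every chunk's rank-0 pick comes first, then every rank-1 pick, and so on, exactly as the worked comment describes. A standalone sketch of that interleaving, with hypothetical file names:

```python
# Interleave several per-chunk rankings: group by original position,
# then renumber the groups in order (a sketch of the merge rule).
rankings = [
    {"a0.py": 0, "a1.py": 1, "a2.py": 2},
    {"b0.py": 0, "b1.py": 1},
    {"c0.py": 0, "c1.py": 1, "c2.py": 2, "c3.py": 3},
]

position_map = {}
for ranking in rankings:
    for path, pos in ranking.items():
        position_map.setdefault(pos, []).append(path)

merged = {}
for pos in sorted(position_map):
    for path in position_map[pos]:
        merged[path] = len(merged)  # next contiguous index

# merged == {"a0.py": 0, "b0.py": 1, "c0.py": 2,
#            "a1.py": 3, "b1.py": 4, "c1.py": 5, ...}
```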
@@ -258,12 +314,13 @@ class QuickFilter():
             }
             ```
 
-            Special notes
-            1. If the user's query @mentions files or uses @@ symbols, the mentioned files or symbols must be returned; also try to find the files they depend on through index information such as import statements, then decide whether those files are needed for the subsequent coding.
-            2. If the query is a conversation history, file paths mentioned in the conversation must be returned.
-            3. Think: if you had to modify the code to satisfy this requirement, which files would you want to read and which would you modify, according to the index files? Return those files.
-            4. If the user requirement is empty, simply return an empty list.
-            5. The returned JSON data must not contain comments
+            Special notes:
+            1. If the user's query contains @file or @@ symbols, the mentioned files or symbols must be returned.
+            2. Based on the requirement and the files found via @file or @@ symbols, infer the edited_files that need to be modified, then try to find the files they depend on through index information such as import statements; those become the reference_files.
+            3. The file numbers in file_list are ordered as: the @ or @@ files first, then edited_files, then reference_files. Note: judge from the requirement whether each reference_file is really needed and filter out irrelevant ones, so that the number of returned files does not balloon.
+            4. If the query is a conversation history, file paths mentioned in the conversation must be returned.
+            5. If the user requirement is empty, simply return an empty list.
+            6. The returned JSON data must not contain comments
             '''
 
         file_meta_str = "\n".join(
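
Note 3 pins the ordering of file_list. Under that rule, and assuming the response schema implied by the surrounding code (a JSON object whose file_list holds index numbers, read via file_number_list.file_list), a reply might look like this hypothetical example:

```python
# Hypothetical reply for a query that @mentions one file, following note 3:
# the @mentioned file's index first, then edited_files, then the
# reference_files that survived relevance filtering.
reply = {
    "file_list": [12, 3, 7]
    # 12 -> the @mentioned file, 3 -> an edited_file, 7 -> a reference_file
}
```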
@@ -273,9 +330,58 @@ class QuickFilter():
             "query": query
         }
         return context
+
+
+    def _extract_code_snippets_from_overflow_files(self, validated_file_numbers: List[int], index_items: List[IndexItem], conversations: List[Dict[str, str]]):
+        token_count = 0
+        selected_files = []
+        selected_file_contents = []
+        full_file_tokens = int(self.max_tokens * 0.8)
+        for file_number in validated_file_numbers:
+            file_path = get_file_path(index_items[file_number].module_name)
+            with open(file_path, "r", encoding="utf-8") as f:
+                content = f.read()
+            tokens = count_tokens(content)
+            if token_count + tokens <= full_file_tokens:
+                selected_files.append(file_number)
+                selected_file_contents.append(content)
+                token_count += tokens
+            else:
+                # Extract code snippets from the files that overflow the budget
+                try:
+                    extracted_info = (
+                        self.extract_code_snippets_from_files.options(
+                            {"llm_config": {"max_length": 100}}
+                        )
+                        .with_llm(self.index_manager.index_filter_llm)
+                        .run(conversations, [content])
+                    )
+                    json_str = extract_code(extracted_info)[0][1]
+                    json_objs = json.loads(json_str)
+
+                    new_content = ""
+
+                    if json_objs:
+                        for json_obj in json_objs:
+                            start_line = json_obj["start_line"] - 1
+                            end_line = json_obj["end_line"]
+                            chunk = "\n".join(content.split("\n")[start_line:end_line])
+                            new_content += chunk + "\n"
+
+                        token_count += count_tokens(new_content)
+                        if token_count >= self.max_tokens:
+                            break
+                        else:
+                            selected_files.append(file_number)
+                            selected_file_contents.append(new_content)
+                except Exception as e:
+                    logger.error(f"Failed to extract code snippets from {file_path}: {e}")
+        return selected_files
+
 
     def filter(self, index_items: List[IndexItem], query: str) -> QuickFilterResult:
         final_files: Dict[str, TargetFile] = {}
+        final_file_positions: Dict[str, int] = {}
         start_time = time.monotonic()
 
         prompt_str = self.quick_filter_files.prompt(index_items, query)
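
_extract_code_snippets_from_overflow_files keeps whole files until roughly 80% of the token budget is spent, then asks the LLM for line-range snippets of the remainder. A self-contained sketch of that budgeting skeleton, with a toy tokenizer and a plain string slice standing in for the LLM snippet call:

```python
# Token-budgeted selection skeleton (toy stand-ins, no LLM involved):
# keep whole files under ~80% of the budget, summarize the overflow.
MAX_TOKENS = 1000

def count_tokens(text: str) -> int:
    return len(text.split())  # toy stand-in for the real tokenizer

def select(files: dict) -> list:
    budget = int(MAX_TOKENS * 0.8)
    used, picked = 0, []
    for path, content in files.items():
        tokens = count_tokens(content)
        if used + tokens <= budget:
            picked.append((path, content))   # whole file still fits
            used += tokens
        else:
            snippet = content[:200]          # stand-in for LLM snippet extraction
            used += count_tokens(snippet)
            if used >= MAX_TOKENS:
                break                        # hard budget reached
            picked.append((path, snippet))
    return picked

print(select({"a.py": "x " * 700, "b.py": "y " * 700}))
```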
@@ -385,6 +491,7 @@ class QuickFilter():
             )
 
         if file_number_list:
+            validated_file_numbers = []
             for file_number in file_number_list.file_list:
                 if file_number < 0 or file_number >= len(index_items):
                     self.printer.print_in_terminal(
@@ -394,14 +501,21 @@ class QuickFilter():
                         total_files=len(index_items)
                     )
                     continue
-                final_files[get_file_path(index_items[file_number].module_name)] = TargetFile(
+                validated_file_numbers.append(file_number)
+
+            # Add the validated files to final_files
+            for index, file_number in enumerate(validated_file_numbers):
+                file_path = get_file_path(index_items[file_number].module_name)
+                final_files[file_path] = TargetFile(
                     file_path=index_items[file_number].module_name,
-                    reason=self.printer.get_message_from_key(
-                        "quick_filter_reason")
+                    reason=self.printer.get_message_from_key("quick_filter_reason")
                 )
+                final_file_positions[file_path] = index
+
         end_time = time.monotonic()
         self.stats["timings"]["quick_filter"] = end_time - start_time
         return QuickFilterResult(
             files=final_files,
-            has_error=False
+            has_error=False,
+            file_positions=final_file_positions
         )
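
filter() now validates all file numbers first and only then assigns positions, so the ranking stays contiguous even when invalid numbers are skipped. A compact sketch of the two-pass pattern, with hypothetical values:

```python
# Two-pass pattern: validate first, then assign contiguous positions,
# so skipped (invalid) numbers leave no gaps in the ranking.
raw_numbers = [0, 99, 2, -1, 1]   # hypothetical LLM output; 99 and -1 invalid
total = 3                         # hypothetical number of index items

validated = [n for n in raw_numbers if 0 <= n < total]
positions = {n: i for i, n in enumerate(validated)}
assert positions == {0: 0, 2: 1, 1: 2}
```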
autocoder/index/index.py CHANGED
@@ -26,6 +26,7 @@ from autocoder.index.types import (
 )
 from autocoder.common.global_cancel import global_cancel
 from autocoder.utils.llms import get_llm_names
+from autocoder.rag.token_counter import count_tokens
 class IndexManager:
     def __init__(
         self, llm: byzerllm.ByzerLLM, sources: List[SourceCode], args: AutoCoderArgs
@@ -257,13 +258,13 @@
         total_input_cost = 0.0
         total_output_cost = 0.0
 
-        if len(source.source_code) > self.max_input_length:
+        if count_tokens(source.source_code) > self.args.conversation_prune_safe_zone_tokens:
             self.printer.print_in_terminal(
                 "index_file_too_large",
                 style="yellow",
                 file_path=source.module_name,
                 file_size=len(source.source_code),
-                max_length=self.max_input_length
+                max_length=self.args.conversation_prune_safe_zone_tokens
             )
             chunks = self.split_text_into_chunks(
                 source_code, self.max_input_length - 1000
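
The guard now measures tokens instead of raw character length, which tracks what the model context actually costs. A toy illustration of how far the two can diverge, with a whitespace tokenizer standing in for count_tokens:

```python
# Character length vs. token count can diverge sharply;
# gating on tokens reflects the real LLM context cost.
def count_tokens(text: str) -> int:
    return len(text.split())  # toy stand-in for the real tokenizer

source = "x = 1\n" * 5000        # 30000 characters of trivial code
print(len(source))               # 30000 -> would trip a char-based limit
print(count_tokens(source))      # 15000 -> the budget that actually matters
```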
@@ -1,9 +1,264 @@
+from collections import defaultdict
+import os
+import re
+from dataclasses import dataclass
+from typing import List, Pattern, Dict, Any, Set, Union
+from concurrent.futures import ThreadPoolExecutor
+import byzerllm
+from pydantic import BaseModel
+from rich.tree import Tree
+from rich.console import Console
+from loguru import logger
 from autocoder.pyproject import PyProject
 from autocoder.tsproject import TSProject
 from autocoder.suffixproject import SuffixProject
 from autocoder.common import AutoCoderArgs
-import byzerllm
-from typing import Union
+
+@dataclass
+class AnalysisConfig:
+    exclude_dirs: List[str] = None
+    exclude_file_patterns: List[Pattern] = None
+    exclude_extensions: List[str] = None
+    max_depth: int = -1
+    show_hidden: bool = False
+    parallel_processing: bool = True
+
+class ExtentionResult(BaseModel):
+    code: List[str] = []
+    config: List[str] = []
+    data: List[str] = []
+    document: List[str] = []
+    other: List[str] = []
+
+class EnhancedFileAnalyzer:
+    DEFAULT_EXCLUDE_DIRS = [".git", "node_modules", "__pycache__", "venv"]
+    DEFAULT_EXCLUDE_EXTS = [".log", ".tmp", ".bak", ".swp"]
+
+    def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM], config: AnalysisConfig = None):
+        self.directory = os.path.abspath(args.source_dir)
+        self.config = config or self.default_config()
+        self.llm = llm
+        self.console = Console()
+        self.file_filter = EnhancedFileFilter(self.config)
+
+    @classmethod
+    def default_config(cls) -> AnalysisConfig:
+        return AnalysisConfig(
+            exclude_dirs=cls.DEFAULT_EXCLUDE_DIRS,
+            exclude_file_patterns=[re.compile(r'~$')],  # exclude temp files by default
+            exclude_extensions=cls.DEFAULT_EXCLUDE_EXTS
+        )
+
+    def analyze(self) -> Dict[str, Any]:
+        """Run the full analysis pipeline"""
+        return {
+            "structure": self.get_tree_structure(),
+            "extensions": self.analyze_extensions(),
+            "stats": self.get_directory_stats()
+        }
+
+    def get_tree_structure(self) -> Dict:
+        """Get the optimized tree structure"""
+        tree = {}
+        if self.config.parallel_processing:
+            return self._parallel_tree_build()
+        return self._sequential_tree_build()
+
+    def _sequential_tree_build(self) -> Dict:
+        """Build the directory tree single-threaded"""
+        tree = {}
+        for root, dirs, files in os.walk(self.directory):
+            dirs[:] = [d for d in dirs if not self.file_filter.should_ignore(d, True)]
+            relative_path = os.path.relpath(root, self.directory)
+            current = tree
+            for part in relative_path.split(os.sep):
+                current = current.setdefault(part, {})
+            current.update({f: None for f in files if not self.file_filter.should_ignore(f, False)})
+        return tree
+
+    def _parallel_tree_build(self) -> Dict:
+        """Build the directory tree in parallel"""
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+        import threading
+
+        tree = {}
+        tree_lock = threading.Lock()
+
+        def process_directory(root: str, dirs: List[str], files: List[str]) -> Dict:
+            local_tree = {}
+            relative_path = os.path.relpath(root, self.directory)
+            current = local_tree
+            for part in relative_path.split(os.sep):
+                current = current.setdefault(part, {})
+            current.update({f: None for f in files if not self.file_filter.should_ignore(f, False)})
+            return local_tree
+
+        with ThreadPoolExecutor() as executor:
+            futures = []
+            for root, dirs, files in os.walk(self.directory):
+                dirs[:] = [d for d in dirs if not self.file_filter.should_ignore(d, True)]
+                futures.append(executor.submit(process_directory, root, dirs, files))
+
+            for future in as_completed(futures):
+                try:
+                    local_tree = future.result()
+                    with tree_lock:
+                        self._merge_trees(tree, local_tree)
+                except Exception as e:
+                    logger.error(f"Error processing directory: {e}")
+
+        return tree
+
+    def _merge_trees(self, base_tree: Dict, new_tree: Dict) -> None:
+        """Recursively merge two directory trees"""
+        for key, value in new_tree.items():
+            if key in base_tree:
+                if isinstance(value, dict) and isinstance(base_tree[key], dict):
+                    self._merge_trees(base_tree[key], value)
+            else:
+                base_tree[key] = value
+
+    def analyze_extensions(self) -> Dict:
+        """Enhanced extension analysis"""
+        from collections import defaultdict
+        extensions = self._collect_extensions()
+        if self.llm:
+            return self._llm_enhanced_analysis.with_llm(self.llm).run(extensions)
+        return self._basic_analysis(extensions)
+
+    def _collect_extensions(self) -> Set[str]:
+        """Collect file extensions, with filtering"""
+        extensions = set()
+        for root, dirs, files in os.walk(self.directory):
+            dirs[:] = [d for d in dirs if not self.file_filter.should_ignore(d, True)]
+            for file in files:
+                if self.file_filter.should_ignore(file, False):
+                    continue
+                ext = os.path.splitext(file)[1].lower()
+                if ext:  # skip files without an extension
+                    extensions.add(ext)
+        return extensions
+
+    @byzerllm.prompt()
+    def _llm_enhanced_analysis(self, extensions: List[str]) -> Dict:
+        """LLM-enhanced analysis"""
+        '''
+        Classify the following file extensions according to these rules:
+
+        1. Code files: files containing compilable code with syntactic structure
+        2. Config files: files containing parameter settings or environment configuration
+        3. Data files: files containing structured or unstructured data
+        4. Document files: files containing documentation, notes, or instructions
+        5. Other files: files that cannot be clearly classified
+
+        File extension list:
+        {{ extensions | join(', ') }}
+
+        Return JSON in the following format:
+        {
+            "code": ["ext1", "ext2"],
+            "config": ["ext3", "ext4"],
+            "data": ["ext5", "ext6"],
+            "document": ["ext7", "ext8"],
+            "other": ["ext9", "ext10"]
+        }
+        '''
+        return {
+            "extensions": extensions
+        }
+
+    def _basic_analysis(self, extensions: Set[str]) -> Dict:
+        """Rule-based basic analysis"""
+        CODE_EXTS = {'.py', '.js', '.ts', '.java', '.c', '.cpp'}
+        CONFIG_EXTS = {'.yml', '.yaml', '.json', '.toml', '.ini'}
+
+        return {
+            "code": [ext for ext in extensions if ext in CODE_EXTS],
+            "config": [ext for ext in extensions if ext in CONFIG_EXTS],
+            "unknown": [ext for ext in extensions if ext not in CODE_EXTS | CONFIG_EXTS]
+        }
+
+    def get_directory_stats(self) -> Dict:
+        """Get directory statistics"""
+        stats = {
+            'total_files': 0,
+            'total_dirs': 0,
+            'by_extension': defaultdict(int),
+            'file_types': {
+                'code': 0,
+                'config': 0,
+                'data': 0,
+                'document': 0,
+                'other': 0
+            }
+        }
+        for root, dirs, files in os.walk(self.directory):
+            dirs[:] = [d for d in dirs if not self.file_filter.should_ignore(d, True)]
+            stats['total_dirs'] += len(dirs)
+            for file in files:
+                if self.file_filter.should_ignore(file, False):
+                    continue
+                stats['total_files'] += 1
+                ext = os.path.splitext(file)[1].lower()
+                stats['by_extension'][ext] += 1
+
+                # Classify by extension
+                if ext in ['.py', '.js', '.ts', '.java', '.c', '.cpp']:
+                    stats['file_types']['code'] += 1
+                elif ext in ['.yml', '.yaml', '.json', '.toml', '.ini']:
+                    stats['file_types']['config'] += 1
+                else:
+                    stats['file_types']['other'] += 1
+        return stats
+
+    def interactive_display(self):
+        """Interactive visual display"""
+        tree = self.build_interactive_tree(self.directory, self.config)
+        self.console.print(tree)
+        self.console.print("\n[bold]Statistical Summary:[/]")
+        stats = self.get_directory_stats()
+
+        from rich.table import Table
+        table = Table(title="Directory Statistics", show_header=True, header_style="bold magenta")
+        table.add_column("Metric", style="cyan")
+        table.add_column("Value", style="green")
+
+        table.add_row("Total Files", str(stats['total_files']))
+        table.add_row("Total Directories", str(stats['total_dirs']))
+        table.add_row("Code Files", str(stats['file_types']['code']))
+        table.add_row("Config Files", str(stats['file_types']['config']))
+        self.console.print(table)
+
+class EnhancedFileFilter:
+    """Enhanced file filter"""
+    def __init__(self, config: AnalysisConfig):
+        self.config = config
+
+    def should_ignore(self, path: str, is_dir: bool) -> bool:
+        """Decide whether a path should be ignored"""
+        base_name = os.path.basename(path)
+
+        # Hidden files
+        if not self.config.show_hidden and base_name.startswith('.'):
+            return True
+
+        # Excluded directories
+        if is_dir and base_name in self.config.exclude_dirs:
+            return True
+
+        # Excluded file extensions
+        if not is_dir:
+            ext = os.path.splitext(path)[1].lower()
+            if ext in self.config.exclude_extensions:
+                return True
+
+        # Regex pattern exclusion
+        full_path = os.path.abspath(path)
+        for pattern in self.config.exclude_file_patterns:
+            if pattern.search(full_path):
+                return True
+
+        return False
 
 def get_project_structure(args:AutoCoderArgs, llm:Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
     if args.project_type == "ts":
@@ -12,4 +267,4 @@ def get_project_structure(args:AutoCoderArgs, llm:Union[byzerllm.ByzerLLM, byzer
         pp = PyProject(args=args, llm=llm)
     else:
         pp = SuffixProject(args=args, llm=llm, file_filter=None)
-    return pp.get_tree_like_directory_structure()
+    return pp.get_tree_like_directory_structure()
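
The new EnhancedFileAnalyzer bundles tree building, extension classification, and directory stats behind a single analyze() call. A minimal usage sketch, assuming AutoCoderArgs can be constructed with just source_dir, and passing llm=None so analyze_extensions() falls back to the rule-based _basic_analysis path:

```python
# Minimal usage sketch (assumptions: AutoCoderArgs(source_dir=...) is a valid
# constructor call, and llm=None routes extension analysis to the rule-based path).
from autocoder.common import AutoCoderArgs

args = AutoCoderArgs(source_dir=".")
analyzer = EnhancedFileAnalyzer(args=args, llm=None)

report = analyzer.analyze()
print(report["stats"]["total_files"])     # file count after filtering
print(report["extensions"].get("code"))   # e.g. ['.py', '.ts']
```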
@@ -176,7 +176,12 @@ def run_in_raw_thread():
     exception = []
     def worker():
         try:
-            # global_cancel.reset()
+            # If the cancel flag is already set here, it is likely a leftover from a user interrupt that has not been released
+            # Wait five seconds, then force-release it
+            if global_cancel.requested:
+                time.sleep(5)
+                global_cancel.reset()
+
             ret = func(*args, **kwargs)
             result.append(ret)
             global_cancel.reset()
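
Instead of resetting unconditionally, the worker now gives a stale cancel flag a five-second grace period before force-releasing it. A standalone sketch of the pattern with a toy cancel token that mirrors the requested/reset() interface used above (the real global_cancel API is assumed from this usage):

```python
import time
import threading

class CancelToken:
    # Toy stand-in for global_cancel: a flag with requested and reset()
    def __init__(self):
        self._event = threading.Event()

    @property
    def requested(self) -> bool:
        return self._event.is_set()

    def request(self):
        self._event.set()

    def reset(self):
        self._event.clear()

cancel = CancelToken()
cancel.request()              # stale flag left over from a previous interrupt

def worker(task):
    if cancel.requested:      # grace period before force-releasing
        time.sleep(5)
        cancel.reset()
    return task()

print(worker(lambda: "done"))
```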
autocoder/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.1.263"
+__version__ = "0.1.264"