auto-coder 0.1.262__py3-none-any.whl → 0.1.264__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; the original page links to more details.

@@ -387,15 +387,35 @@ class CodeAutoMergeDiff:
387
387
  def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
388
388
  if len(generate_result.contents) == 1:
389
389
  return generate_result
390
+
391
+ merge_results = []
392
+ for content,conversations in zip(generate_result.contents,generate_result.conversations):
393
+ merge_result = self._merge_code_without_effect(content)
394
+ merge_results.append(merge_result)
390
395
 
396
+ # If all merge results are None, return first one
397
+ if all(len(result.failed_blocks) != 0 for result in merge_results):
398
+ self.printer.print_in_terminal("all_merge_results_failed")
399
+ return CodeGenerateResult(contents=[generate_result.contents[0]], conversations=[generate_result.conversations[0]])
400
+
401
+ # If only one merge result is not None, return that one
402
+ not_none_indices = [i for i, result in enumerate(merge_results) if len(result.failed_blocks) == 0]
403
+ if len(not_none_indices) == 1:
404
+ idx = not_none_indices[0]
405
+ self.printer.print_in_terminal("only_one_merge_result_success")
406
+ return CodeGenerateResult(contents=[generate_result.contents[idx]], conversations=[generate_result.conversations[idx]])
407
+
408
+ # 最后,如果有多个,那么根据质量排序再返回
391
409
  ranker = CodeModificationRanker(self.llm, self.args)
392
- ranked_result = ranker.rank_modifications(generate_result)
393
- # Filter out contents with failed blocks
410
+ ranked_result = ranker.rank_modifications(generate_result,merge_results)
411
+
412
+ ## 得到的结果,再做一次合并,第一个通过的返回 , 返回做合并有点重复低效,未来修改。
394
413
  for content,conversations in zip(ranked_result.contents,ranked_result.conversations):
395
414
  merge_result = self._merge_code_without_effect(content)
396
415
  if not merge_result.failed_blocks:
397
416
  return CodeGenerateResult(contents=[content], conversations=[conversations])
398
- # If all have failed blocks, return the first one
417
+
418
+ # 最后保底,但实际不会出现
399
419
  return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])
400
420
 
401
421
  @byzerllm.prompt(render="jinja2")
@@ -440,6 +460,11 @@ class CodeAutoMergeDiff:
440
460
  errors = []
441
461
  for path, hunk in uniq:
442
462
  full_path = self.abs_root_path(path)
463
+
464
+ if not os.path.exists(full_path):
465
+ with open(full_path, "w",encoding="utf-8") as f:
466
+ f.write("")
467
+
443
468
  content = FileUtils.read_file(full_path)
444
469
 
445
470
  original, _ = hunk_to_before_after(hunk)
@@ -164,15 +164,35 @@ class CodeAutoMergeEditBlock:
164
164
  def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
165
165
  if len(generate_result.contents) == 1:
166
166
  return generate_result
167
-
167
+
168
+ merge_results = []
169
+ for content,conversations in zip(generate_result.contents,generate_result.conversations):
170
+ merge_result = self._merge_code_without_effect(content)
171
+ merge_results.append(merge_result)
172
+
173
+ # If all merge results are None, return first one
174
+ if all(len(result.failed_blocks) != 0 for result in merge_results):
175
+ self.printer.print_in_terminal("all_merge_results_failed")
176
+ return CodeGenerateResult(contents=[generate_result.contents[0]], conversations=[generate_result.conversations[0]])
177
+
178
+ # If only one merge result is not None, return that one
179
+ not_none_indices = [i for i, result in enumerate(merge_results) if len(result.failed_blocks) == 0]
180
+ if len(not_none_indices) == 1:
181
+ idx = not_none_indices[0]
182
+ self.printer.print_in_terminal("only_one_merge_result_success")
183
+ return CodeGenerateResult(contents=[generate_result.contents[idx]], conversations=[generate_result.conversations[idx]])
184
+
185
+ # 最后,如果有多个,那么根据质量排序再返回
168
186
  ranker = CodeModificationRanker(self.llm, self.args)
169
- ranked_result = ranker.rank_modifications(generate_result)
170
- # Filter out contents with failed blocks
187
+ ranked_result = ranker.rank_modifications(generate_result,merge_results)
188
+
189
+ ## 得到的结果,再做一次合并,第一个通过的返回 , 返回做合并有点重复低效,未来修改。
171
190
  for content,conversations in zip(ranked_result.contents,ranked_result.conversations):
172
191
  merge_result = self._merge_code_without_effect(content)
173
192
  if not merge_result.failed_blocks:
174
193
  return CodeGenerateResult(contents=[content], conversations=[conversations])
175
- # If all have failed blocks, return the first one
194
+
195
+ # 最后保底,但实际不会出现
176
196
  return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])
177
197
 
178
198
  @byzerllm.prompt()
@@ -138,15 +138,35 @@ class CodeAutoMergeStrictDiff:
138
138
  def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
139
139
  if len(generate_result.contents) == 1:
140
140
  return generate_result
141
+
142
+ merge_results = []
143
+ for content,conversations in zip(generate_result.contents,generate_result.conversations):
144
+ merge_result = self._merge_code_without_effect(content)
145
+ merge_results.append(merge_result)
146
+
147
+ # If all merge results are None, return first one
148
+ if all(len(result.failed_blocks) != 0 for result in merge_results):
149
+ self.printer.print_in_terminal("all_merge_results_failed")
150
+ return CodeGenerateResult(contents=[generate_result.contents[0]], conversations=[generate_result.conversations[0]])
151
+
152
+ # If only one merge result is not None, return that one
153
+ not_none_indices = [i for i, result in enumerate(merge_results) if len(result.failed_blocks) == 0]
154
+ if len(not_none_indices) == 1:
155
+ idx = not_none_indices[0]
156
+ self.printer.print_in_terminal("only_one_merge_result_success")
157
+ return CodeGenerateResult(contents=[generate_result.contents[idx]], conversations=[generate_result.conversations[idx]])
141
158
 
159
+ # 最后,如果有多个,那么根据质量排序再返回
142
160
  ranker = CodeModificationRanker(self.llm, self.args)
143
- ranked_result = ranker.rank_modifications(generate_result)
144
- # Filter out contents with failed blocks
161
+ ranked_result = ranker.rank_modifications(generate_result,merge_results)
162
+
163
+ ## 得到的结果,再做一次合并,第一个通过的返回 , 返回做合并有点重复低效,未来修改。
145
164
  for content,conversations in zip(ranked_result.contents,ranked_result.conversations):
146
165
  merge_result = self._merge_code_without_effect(content)
147
166
  if not merge_result.failed_blocks:
148
167
  return CodeGenerateResult(contents=[content], conversations=[conversations])
149
- # If all have failed blocks, return the first one
168
+
169
+ # 最后保底,但实际不会出现
150
170
  return CodeGenerateResult(contents=[ranked_result.contents[0]], conversations=[ranked_result.conversations[0]])
151
171
 
152
172
 
@@ -9,6 +9,8 @@ import traceback
9
9
  from autocoder.common.utils_code_auto_generate import chat_with_continue
10
10
  from byzerllm.utils.str2model import to_model
11
11
  from autocoder.utils.llms import get_llm_names, get_model_info
12
+ from autocoder.common.types import CodeGenerateResult, MergeCodeWithoutEffect
13
+ import os
12
14
 
13
15
  class RankResult(BaseModel):
14
16
  rank_result: List[int]
@@ -51,12 +53,67 @@ class CodeModificationRanker:
51
53
  }
52
54
  ```
53
55
 
54
- 注意:
56
+ 注意:
55
57
  1. id 为 edit_block 的 id,按质量从高到低排序,并且 id 必须是数字
56
58
  2. 只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹
57
59
  '''
58
60
 
59
- def rank_modifications(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
61
# Prompt template for file-level ranking. The docstring below is the prompt
# text itself (it is rendered, not read as documentation), so it must stay
# verbatim. The template reads `s` (the generate result, for the edit
# requirement) and `changes` (one rendered before/after listing per merge
# result, supplied by the returned dict).
@byzerllm.prompt()
def _rank_modifications_with_merge_result(self, s: CodeGenerateResult,merge_results: List[MergeCodeWithoutEffect]) -> str:
    '''
    对一组代码修改进行质量评估并排序。

    下面是修改需求:

    <edit_requirement>
    {{ s.conversations[0][-2]["content"] }}
    </edit_requirement>

    下面是相应的代码修改,如果Before 为空,那么表示是新增文件,如果After 为空,那么表示是删除文件,如果Before 和 After 都不为空,那么表示是修改文件:
    {% for change in changes %}
    <edit_file id="{{ loop.index0 }}">
    {{change}}
    </edit_file>
    {% endfor %}

    请输出如下格式的评估结果,只包含 JSON 数据:

    ```json
    {
    "rank_result": [id1, id2, id3]
    }
    ```

    注意:
    1. 像python的缩进,前端诸如 reacjs,vue 的标签闭合匹配,这些很重要,需要在排序中作为重点考虑对象之一。
    1. id 为 edit_file 的 id,按质量从高到低排序,并且 id 必须是数字
    2. 只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹
    '''
    changes = []
    for merge_result in merge_results:
        # Collect the pieces in a list and join once. The accumulator has its
        # own name so it no longer shadows the `s` parameter that the
        # template reads.
        parts = []
        for file_path, content in merge_result.success_blocks:
            parts.append(f"##File: {file_path}\n\n")
            parts.append("##Before: \n\n")
            # A file that does not exist yet has an empty "Before" section.
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    parts.append(f.read())
            parts.append(f"##File: {file_path}\n\n")
            parts.append("##After: \n\n")
            parts.append(content)
        changes.append("".join(parts))
    return {
        "changes": changes
    }
115
+
116
+ def rank_modifications(self, generate_result: CodeGenerateResult, merge_result: List[MergeCodeWithoutEffect]) -> CodeGenerateResult:
60
117
  import time
61
118
  from collections import defaultdict
62
119
 
@@ -69,8 +126,13 @@ class CodeModificationRanker:
69
126
 
70
127
  rank_times = self.args.rank_times_same_model
71
128
  total_tasks = len(self.llms) * rank_times
129
+ if self.args.rank_strategy == "block":
130
+ query = self._rank_modifications.prompt(generate_result)
131
+ elif self.args.rank_strategy == "file":
132
+ query = self._rank_modifications_with_merge_result.prompt(generate_result, merge_result)
133
+ else:
134
+ raise Exception(f"Invalid rank strategy: {self.args.rank_strategy}")
72
135
 
73
- query = self._rank_modifications.prompt(generate_result)
74
136
  input_tokens_count = 0
75
137
  generated_tokens_count = 0
76
138
  try:
@@ -132,6 +132,12 @@ class ConfigValidator:
132
132
  "type": str,
133
133
  "default": "v3_chat",
134
134
  "description": "提交信息生成模型名称"
135
+ },
136
+ "rank_strategy": {
137
+ "type": str,
138
+ "allowed": ["block", "file"],
139
+ "default": "block",
140
+ "description": "排序策略(block/file)"
135
141
  }
136
142
  }
137
143
 
@@ -0,0 +1,305 @@
1
+ from typing import List, Dict, Any, Union
2
+ from pathlib import Path
3
+ import json
4
+ from loguru import logger
5
+ from autocoder.rag.token_counter import count_tokens
6
+ from autocoder.common import AutoCoderArgs,SourceCode
7
+ from byzerllm.utils.client.code_utils import extract_code
8
+ from autocoder.index.types import VerifyFileRelevance
9
+ import byzerllm
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+
12
class PruneContext:
    """Reduce a list of source files so their total token count fits a budget.

    ``handle_overflow`` is the entry point. When the files already fit within
    ``max_tokens`` they are returned untouched; otherwise one of three
    strategies is applied:

    * ``"delete"``  - keep files in input order until the budget is hit;
    * ``"extract"`` - keep whole files while they fit in 80% of the budget,
      then ask the LLM for relevant line ranges of the remaining files;
    * ``"score"``   - ask the LLM for a relevance score per file and keep the
      highest-scoring files that fit.
    """

    def __init__(self, max_tokens: int, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
        """Store the token budget, the run arguments and the LLM client."""
        self.max_tokens = max_tokens
        self.args = args
        self.llm = llm

    def _delete_overflow_files(self, file_paths: List[str]) -> List[SourceCode]:
        """Keep files in input order until the next one would exceed the budget.

        Files that cannot be read or tokenized are logged and skipped.
        Iteration stops at the first file that does not fit.
        """
        total_tokens = 0
        selected_files = []
        for file_path in file_paths:
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                token_count = count_tokens(content)
            except Exception as e:
                # Skip the file: there is no valid content to attach to it.
                logger.error(f"Failed to read file {file_path}: {e}")
                continue

            if total_tokens + token_count > self.max_tokens:
                break
            total_tokens += token_count
            selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=token_count))

        return selected_files

    def _extract_code_snippets(self, file_paths: List[str], conversations: List[Dict[str, str]]) -> List[SourceCode]:
        """Keep whole files while they fit, then LLM-selected snippets of the rest.

        Whole files are accepted while the running total stays within 80% of
        ``max_tokens``. For any later file the LLM is asked for relevant line
        ranges; the resulting snippet text is accepted while the running
        total stays within ``max_tokens``. Iteration stops at the first
        snippet that does not fit. Files whose processing raises are logged
        and skipped.
        """
        token_count = 0
        selected_files = []
        full_file_tokens = int(self.max_tokens * 0.8)

        # The docstring below is the prompt template sent to the LLM, not
        # documentation; keep its text as-is.
        @byzerllm.prompt()
        def extract_code_snippets(conversations: List[Dict[str, str]], content: str) -> str:
            """
            根据提供的代码文件和对话历史提取相关代码片段。

            处理示例:
            <examples>
            1. 代码文件:
            <code_file>
            1 def add(a, b):
            2 return a + b
            3 def sub(a, b):
            4 return a - b
            </code_file>
            <conversation_history>
            <user>: 如何实现加法?
            </conversation_history>

            输出:
            ```json
            [
            {"start_line": 1, "end_line": 2}
            ]
            ```

            2. 代码文件:
            <code_file>
            1 class User:
            2 def __init__(self, name):
            3 self.name = name
            4 def greet(self):
            5 return f"Hello, {self.name}"
            </code_file>
            <conversation_history>
            <user>: 如何创建一个User对象?
            </conversation_history>

            输出:
            ```json
            [
            {"start_line": 1, "end_line": 3}
            ]
            ```

            3. 代码文件:
            <code_file>
            1 def foo():
            2 pass
            </code_file>
            <conversation_history>
            <user>: 如何实现减法?
            </conversation_history>

            输出:
            ```json
            []
            ```
            </examples>

            输入:
            1. 代码文件内容:
            <code_file>
            {{ content }}
            </code_file>

            2. 对话历史:
            <conversation_history>
            {% for msg in conversations %}
            <{{ msg.role }}>: {{ msg.content }}
            {% endfor %}
            </conversation_history>

            任务:
            1. 分析最后一个用户问题及其上下文。
            2. 在代码文件中找出与问题相关的一个或多个重要代码段。
            3. 对每个相关代码段,确定其起始行号(start_line)和结束行号(end_line)。
            4. 代码段数量不超过4个。

            输出要求:
            1. 返回一个JSON数组,每个元素包含"start_line"和"end_line"。
            2. start_line和end_line必须是整数,表示代码文件中的行号。
            3. 行号从1开始计数。
            4. 如果没有相关代码段,返回空数组[]。

            输出格式:
            严格的JSON数组,不包含其他文字或解释。

            ```json
            [
            {"start_line": 第一个代码段的起始行号, "end_line": 第一个代码段的结束行号},
            {"start_line": 第二个代码段的起始行号, "end_line": 第二个代码段的结束行号}
            ]
            ```
            """

        for file_path in file_paths:
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()

                # Whole files take priority while they fit in the 80% share.
                tokens = count_tokens(content)
                if token_count + tokens <= full_file_tokens:
                    selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=tokens))
                    token_count += tokens
                    continue

                # Otherwise ask the LLM which line ranges matter.
                extracted = extract_code_snippets.with_llm(self.llm).run(
                    conversations=conversations,
                    content=content
                )

                if extracted:
                    json_str = extract_code(extracted)[0][1]
                    snippets = json.loads(json_str)
                    new_content = self._build_snippet_content(file_path, content, snippets)

                    snippet_tokens = count_tokens(new_content)
                    if token_count + snippet_tokens <= self.max_tokens:
                        selected_files.append(SourceCode(module_name=file_path,source_code=new_content,tokens=snippet_tokens))
                        token_count += snippet_tokens
                    else:
                        break
            except Exception as e:
                # Best effort: a file that fails is logged and skipped.
                logger.error(f"Failed to process {file_path}: {e}")

        return selected_files

    def _build_snippet_content(self, file_path: str, full_content: str, snippets: List[dict]) -> str:
        """Render the requested line ranges of ``full_content`` as one string.

        Each snippet dict supplies 1-based, inclusive ``start_line`` and
        ``end_line``; both are clamped to the file's bounds. Every range is
        preceded by a ``# Lines a-b (reason)`` marker, where ``reason`` is the
        snippet's optional ``"reason"`` entry. ``file_path`` is not used.
        """
        lines = full_content.split("\n")
        header = "Snippets:\n"

        content = []
        for snippet in snippets:
            start = max(0, snippet["start_line"] - 1)
            end = min(len(lines), snippet["end_line"])
            content.append(f"# Lines {start+1}-{end} ({snippet.get('reason','')})")
            content.extend(lines[start:end])

        return header + "\n".join(content)

    def handle_overflow(
        self,
        file_paths: List[str],
        conversations: List[Dict[str, str]],
        strategy: str = "score"
    ) -> List[SourceCode]:
        """Return the files to keep so the total stays within ``max_tokens``.

        :param file_paths: paths of the files to consider
        :param conversations: conversation context passed to the LLM-based strategies
        :param strategy: pruning strategy (delete/extract/score)
        :raises ValueError: if the files exceed the budget and ``strategy`` is unknown
        """
        total_tokens, sources = self._count_tokens(file_paths)
        # Nothing to prune: every readable file already fits.
        if total_tokens <= self.max_tokens:
            return sources

        if strategy == "score":
            return self._score_and_filter_files(file_paths, conversations)
        elif strategy == "delete":
            return self._delete_overflow_files(file_paths)
        elif strategy == "extract":
            return self._extract_code_snippets(file_paths, conversations)
        else:
            raise ValueError(f"无效策略: {strategy}. 可选值: delete/extract/score")

    def _count_tokens(self, file_paths: List[str]) -> tuple:
        """Read every file and total their token counts.

        Returns a ``(total_tokens, sources)`` tuple, where ``sources`` holds
        one ``SourceCode`` per readable file. Unreadable files are logged and
        contribute neither tokens nor a source.
        """
        total_tokens = 0
        sources = []
        for file_path in file_paths:
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Tokenize once and reuse the count for both outputs.
                tokens = count_tokens(content)
            except Exception as e:
                logger.error(f"Failed to read file {file_path}: {e}")
                continue
            sources.append(SourceCode(module_name=file_path,source_code=content,tokens=tokens))
            total_tokens += tokens
        return total_tokens, sources

    def _score_and_filter_files(self, file_paths: List[str], conversations: List[Dict[str, str]]) -> List[SourceCode]:
        """Score files for relevance and keep the best ones that fit the budget.

        Files are scored concurrently in a thread pool, sorted by score from
        high to low, and appended until the next file would exceed
        ``max_tokens``. Files whose scoring fails are logged and left out.
        """
        selected_files = []
        total_tokens = 0
        scored_files = []

        # The docstring below is the prompt template sent to the LLM, not
        # documentation; keep its text as-is.
        @byzerllm.prompt()
        def verify_file_relevance(file_content: str, conversations: List[Dict[str, str]]) -> str:
            """
            请验证下面的文件内容是否与用户对话相关:

            文件内容:
            {{ file_content }}

            历史对话:
            <conversation_history>
            {% for msg in conversations %}
            <{{ msg.role }}>: {{ msg.content }}
            {% endfor %}
            </conversation_history>

            相关是指,需要依赖这个文件提供上下文,或者需要修改这个文件才能解决用户的问题。
            请给出相应的可能性分数:0-10,并结合用户问题,理由控制在50字以内。格式如下:

            ```json
            {
            "relevant_score": 0-10,
            "reason": "这是相关的原因(不超过10个中文字符)..."
            }
            ```
            """

        def _score_file(file_path: str) -> dict:
            # Returns the scoring record for one file, or None on failure.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                tokens = count_tokens(content)
                result = verify_file_relevance.with_llm(self.llm).with_return_type(VerifyFileRelevance).run(
                    file_content=content,
                    conversations=conversations
                )
                return {
                    "file_path": file_path,
                    "score": result.relevant_score,
                    "tokens": tokens,
                    "content": content
                }
            except Exception as e:
                logger.error(f"Failed to score file {file_path}: {e}")
                return None

        # Step 1: score all files in parallel.
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(_score_file, file_path) for file_path in file_paths]
            for future in as_completed(futures):
                result = future.result()
                # A failed scoring yields None; check before subscripting.
                if result is None:
                    continue
                print(f"score file {result['file_path']} {result['score']}")
                scored_files.append(result)

        # Step 2: highest score first.
        scored_files.sort(key=lambda x: x["score"], reverse=True)

        # Step 3: take files until the next one would exceed the budget.
        for file_info in scored_files:
            if total_tokens + file_info["tokens"] > self.max_tokens:
                break
            selected_files.append(SourceCode(
                module_name=file_info["file_path"],
                source_code=file_info["content"],
                tokens=file_info["tokens"]
            ))
            total_tokens += file_info["tokens"]

        return selected_files
305
+
autocoder/index/entry.py CHANGED
@@ -24,6 +24,7 @@ from autocoder.index.filter.normal_filter import NormalFilter
24
24
  from autocoder.index.index import IndexManager
25
25
  from loguru import logger
26
26
  from autocoder.common import SourceCodeList
27
+ from autocoder.common.context_pruner import PruneContext
27
28
 
28
29
  def build_index_and_filter_files(
29
30
  llm, args: AutoCoderArgs, sources: List[SourceCode]
@@ -113,8 +114,13 @@ def build_index_and_filter_files(
113
114
  raise KeyboardInterrupt(printer.get_message_from_key_with_format("quick_filter_failed",error=quick_filter_result.error_message))
114
115
 
115
116
  # Merge quick filter results into final_files
116
- final_files.update(quick_filter_result.files)
117
-
117
+ if args.context_prune:
118
+ context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
119
+ pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
120
+ for source_file in pruned_files:
121
+ final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
122
+ else:
123
+ final_files.update(quick_filter_result.files)
118
124
 
119
125
  if not args.skip_filter_index and not args.index_filter_model:
120
126
  model_name = getattr(index_manager.llm, 'default_model_name', None)
@@ -1,4 +1,6 @@
1
- from typing import List, Union,Dict,Any
1
+ from typing import List, Union,Dict,Any,Optional
2
+
3
+ from pydantic import BaseModel
2
4
  from autocoder.index.types import IndexItem
3
5
  from autocoder.common import SourceCode, AutoCoderArgs
4
6
  import byzerllm
@@ -25,6 +27,11 @@ def get_file_path(file_path):
25
27
  return file_path.strip()[2:]
26
28
  return file_path
27
29
 
30
class NormalFilterResult(BaseModel):
    """Result object returned by ``NormalFilter``."""

    # Selected files, keyed by file path.
    files: Dict[str, TargetFile]
    # Whether filtering ended with an error.
    has_error: bool
    # Error description; None when there is no error.
    error_message: Optional[str] = None
    # Explicit default: without it Pydantic v2 treats this Optional field as
    # required, and the filter builds the result without passing it.
    # NOTE(review): presumably maps file path to a position/rank - confirm.
    file_positions: Optional[Dict[str, int]] = None
28
35
 
29
36
  class NormalFilter():
30
37
  def __init__(self, index_manager: IndexManager,stats:Dict[str,Any],sources:List[SourceCode]):
@@ -167,4 +174,8 @@ class NormalFilter():
167
174
  # Keep all files, not just verified ones
168
175
  final_files = verified_files
169
176
 
170
- return final_files
177
+ return NormalFilterResult(
178
+ files=final_files,
179
+ has_error=False,
180
+ error_message=None
181
+ )