auto-coder 0.1.263__py3-none-any.whl → 0.1.265__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; the registry page linked to further details.

Files changed (58)
  1. {auto_coder-0.1.263.dist-info → auto_coder-0.1.265.dist-info}/METADATA +1 -1
  2. {auto_coder-0.1.263.dist-info → auto_coder-0.1.265.dist-info}/RECORD +58 -55
  3. autocoder/agent/planner.py +4 -4
  4. autocoder/auto_coder.py +26 -21
  5. autocoder/auto_coder_server.py +7 -7
  6. autocoder/chat_auto_coder.py +203 -98
  7. autocoder/commands/auto_command.py +81 -4
  8. autocoder/commands/tools.py +48 -50
  9. autocoder/common/__init__.py +6 -1
  10. autocoder/common/auto_coder_lang.py +41 -3
  11. autocoder/common/code_auto_generate.py +3 -3
  12. autocoder/common/code_auto_generate_diff.py +12 -15
  13. autocoder/common/code_auto_generate_editblock.py +3 -3
  14. autocoder/common/code_auto_generate_strict_diff.py +3 -3
  15. autocoder/common/code_auto_merge.py +23 -3
  16. autocoder/common/code_auto_merge_diff.py +29 -4
  17. autocoder/common/code_auto_merge_editblock.py +25 -5
  18. autocoder/common/code_auto_merge_strict_diff.py +26 -6
  19. autocoder/common/code_modification_ranker.py +65 -3
  20. autocoder/common/command_completer.py +3 -0
  21. autocoder/common/command_generator.py +24 -8
  22. autocoder/common/command_templates.py +2 -2
  23. autocoder/common/conf_import_export.py +105 -0
  24. autocoder/common/conf_validator.py +7 -1
  25. autocoder/common/context_pruner.py +305 -0
  26. autocoder/common/files.py +41 -2
  27. autocoder/common/image_to_page.py +11 -11
  28. autocoder/common/index_import_export.py +38 -18
  29. autocoder/common/mcp_hub.py +3 -3
  30. autocoder/common/mcp_server.py +2 -2
  31. autocoder/common/shells.py +254 -13
  32. autocoder/common/stats_panel.py +126 -0
  33. autocoder/dispacher/actions/action.py +6 -18
  34. autocoder/dispacher/actions/copilot.py +2 -2
  35. autocoder/dispacher/actions/plugins/action_regex_project.py +1 -3
  36. autocoder/dispacher/actions/plugins/action_translate.py +1 -1
  37. autocoder/index/entry.py +8 -2
  38. autocoder/index/filter/normal_filter.py +13 -2
  39. autocoder/index/filter/quick_filter.py +127 -13
  40. autocoder/index/index.py +8 -7
  41. autocoder/models.py +2 -2
  42. autocoder/pyproject/__init__.py +5 -5
  43. autocoder/rag/cache/byzer_storage_cache.py +4 -4
  44. autocoder/rag/cache/file_monitor_cache.py +2 -2
  45. autocoder/rag/cache/simple_cache.py +4 -4
  46. autocoder/rag/long_context_rag.py +2 -2
  47. autocoder/regexproject/__init__.py +3 -2
  48. autocoder/suffixproject/__init__.py +3 -2
  49. autocoder/tsproject/__init__.py +3 -2
  50. autocoder/utils/conversation_store.py +1 -1
  51. autocoder/utils/operate_config_api.py +3 -3
  52. autocoder/utils/project_structure.py +258 -3
  53. autocoder/utils/thread_utils.py +6 -1
  54. autocoder/version.py +1 -1
  55. {auto_coder-0.1.263.dist-info → auto_coder-0.1.265.dist-info}/LICENSE +0 -0
  56. {auto_coder-0.1.263.dist-info → auto_coder-0.1.265.dist-info}/WHEEL +0 -0
  57. {auto_coder-0.1.263.dist-info → auto_coder-0.1.265.dist-info}/entry_points.txt +0 -0
  58. {auto_coder-0.1.263.dist-info → auto_coder-0.1.265.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,10 @@
1
-
2
1
  import sys
3
2
  import os
4
3
  import locale
5
4
  import subprocess
6
5
  import platform
6
+ import tempfile
7
+ import uuid
7
8
  from rich.console import Console
8
9
  from rich.panel import Panel
9
10
  from rich.text import Text
@@ -21,18 +22,170 @@ def get_terminal_name() -> str:
21
22
  else:
22
23
  return _get_unix_terminal_name()
23
24
 
25
def is_running_in_powershell() -> bool:
    """Return True if this Python process appears to be hosted by PowerShell.

    Heuristics, checked in order:

    1. Any environment variable whose name contains "POWERSHELL"
       (case-insensitive).
    2. A process named like ``powershell`` or ``pwsh`` in the parent-process
       chain. This needs the optional ``psutil`` package; the check is skipped
       when psutil is missing or the process table cannot be read.

    Deliberately NOT used (both were in earlier revisions):

    * Running ``powershell -Command $PSVersionTable``. Success only proves that
      PowerShell is installed, which holds on practically every Windows
      machine. It says nothing about who launched us, and it cost a subprocess
      with a 1 s timeout on every call.
    * Scanning ``sys.argv``. Our own arguments do not describe the parent
      shell, and any argument mentioning "powershell" caused a false positive.

    Returns:
        bool: True when a heuristic matches, False otherwise (including on
        any unexpected error).
    """
    try:
        # Heuristic 1: PowerShell-related environment variables.
        if any('POWERSHELL' in key.upper() for key in os.environ):
            return True

        # Heuristic 2: walk the ancestry, starting with the direct parent.
        # Best-effort: a missing psutil or a vanished process is not an error.
        try:
            import psutil
            ancestor = psutil.Process().parent()
            while ancestor and ancestor.pid != 1:  # PID 1 is the system init process
                name = ancestor.name().lower()
                if 'powershell' in name or 'pwsh' in name:
                    return True
                ancestor = ancestor.parent()
        except Exception:
            pass

        return False
    except Exception:
        return False
77
+
78
def is_running_in_cmd() -> bool:
    """Return True if this process looks like it is hosted by Windows CMD.

    PowerShell takes precedence: whenever ``is_running_in_powershell()`` is
    true the answer is False. Otherwise, any of the following counts as CMD:

    * ``PROMPT`` is set and no environment variable name contains
      "POWERSHELL";
    * the ``ComSpec`` variable mentions ``cmd.exe``;
    * ``cmd.exe`` appears in the name of the parent process or any ancestor
      (only when psutil is importable; lookup errors are ignored).

    NOTE(review): ComSpec normally points at cmd.exe on Windows, so in
    practice this returns True for any non-PowerShell Windows host.
    execute_shell_command relies on that as its Windows fallback (otherwise
    it writes a bash ``.sh`` script). Tighten this only together with that
    caller.

    Returns:
        bool: True for a (presumed) CMD host, False otherwise or on any error.
    """
    if is_running_in_powershell():
        return False

    try:
        environ = os.environ

        # Heuristic 1: PROMPT present without any PowerShell-flavoured variable.
        if 'PROMPT' in environ:
            powershell_vars = [name for name in environ if 'POWERSHELL' in name.upper()]
            if not any(powershell_vars):
                return True

        # Heuristic 2: the configured command interpreter.
        if 'cmd.exe' in environ.get('ComSpec', '').lower():
            return True

        # Heuristic 3: process ancestry (psutil is optional).
        try:
            import psutil
            ancestor = psutil.Process().parent()
            # The direct parent is examined first; the loop then re-checks it
            # and walks upwards until there is no parent or PID 1 is reached.
            if ancestor and 'cmd.exe' in ancestor.name().lower():
                return True
            while ancestor and ancestor.pid != 1:
                if 'cmd.exe' in ancestor.name().lower():
                    return True
                ancestor = ancestor.parent()
        except Exception:
            pass

        return False
    except Exception:
        return False
121
+
24
122
  def _get_windows_terminal_name() -> str:
25
123
  """Windows 系统终端检测"""
26
- # 检查是否在 PowerShell
27
- if 'POWERSHELL_DISTRIBUTION_CHANNEL' in os.environ:
124
+ # 检查环境变量
125
+ env = os.environ
126
+
127
+ # 首先使用新方法检查是否在 PowerShell 环境中
128
+ if is_running_in_powershell():
129
+ # 进一步区分是否在 VSCode 的 PowerShell 终端
130
+ if 'VSCODE_GIT_IPC_HANDLE' in env:
131
+ return 'vscode-powershell'
28
132
  return 'powershell'
29
133
 
134
+ # 检查是否在 CMD 环境中
135
+ if is_running_in_cmd():
136
+ # 区分是否在 VSCode 的 CMD 终端
137
+ if 'VSCODE_GIT_IPC_HANDLE' in env:
138
+ return 'vscode-cmd'
139
+ return 'cmd'
140
+
30
141
  # 检查是否在 Git Bash
31
- if 'MINGW' in platform.system():
142
+ if ('MINGW' in platform.system() or
143
+ 'MSYSTEM' in env or
144
+ any('bash.exe' in path.lower() for path in env.get('PATH', '').split(os.pathsep))):
145
+ # 区分是否在 VSCode 的 Git Bash 终端
146
+ if 'VSCODE_GIT_IPC_HANDLE' in env:
147
+ return 'vscode-git-bash'
32
148
  return 'git-bash'
33
149
 
150
+ # 检查是否在 VSCode 的集成终端
151
+ if 'VSCODE_GIT_IPC_HANDLE' in env:
152
+ if 'WT_SESSION' in env: # Windows Terminal
153
+ return 'vscode-windows-terminal'
154
+ return 'vscode-terminal'
155
+
156
+ # 检查是否在 Windows Terminal
157
+ if 'WT_SESSION' in env:
158
+ return 'windows-terminal'
159
+
160
+ # 检查是否在 Cygwin
161
+ if 'CYGWIN' in platform.system():
162
+ return 'cygwin'
163
+
164
+ # 检查 TERM 环境变量
165
+ term = env.get('TERM', '').lower()
166
+ if term:
167
+ if 'xterm' in term:
168
+ return 'xterm'
169
+ elif 'cygwin' in term:
170
+ return 'cygwin'
171
+
172
+ # 检查进程名
173
+ try:
174
+ import psutil
175
+ parent = psutil.Process().parent()
176
+ if parent:
177
+ parent_name = parent.name().lower()
178
+ if 'powershell' in parent_name:
179
+ return 'powershell'
180
+ elif 'windowsterminal' in parent_name:
181
+ return 'windows-terminal'
182
+ elif 'cmd.exe' in parent_name:
183
+ return 'cmd'
184
+ except (ImportError, Exception):
185
+ pass
186
+
34
187
  # 默认返回 cmd.exe
35
- return 'cmd.exe'
188
+ return 'cmd'
36
189
 
37
190
  def _get_unix_terminal_name() -> str:
38
191
  """Linux/Mac 系统终端检测"""
@@ -138,28 +291,109 @@ def execute_shell_command(command: str):
138
291
 
139
292
  Args:
140
293
  command (str): The shell command to execute
141
- encoding (str, optional): Override default encoding. Defaults to None.
142
294
  """
143
295
  console = Console()
144
296
  result_manager = ResultManager()
297
+ temp_file = None
145
298
  try:
146
- # Get terminal encoding
299
+ # Get terminal encoding and name
147
300
  encoding = get_terminal_encoding()
301
+ terminal_name = get_terminal_name()
302
+
303
+ # Windows系统特殊处理
304
+ if sys.platform == 'win32':
305
+ # 设置控制台代码页为 UTF-8
306
+ os.system('chcp 65001 > nul')
307
+ # 强制使用 UTF-8 编码
308
+ encoding = 'utf-8'
309
+ # 设置环境变量
310
+ os.environ['PYTHONIOENCODING'] = 'utf-8'
311
+
312
+ # Create temp script file
313
+ if sys.platform == 'win32':
314
+ if is_running_in_powershell():
315
+ # Create temp PowerShell script with UTF-8 BOM
316
+ temp_file = tempfile.NamedTemporaryFile(
317
+ mode='wb',
318
+ suffix='.ps1',
319
+ delete=False
320
+ )
321
+ # 添加 UTF-8 BOM
322
+ temp_file.write(b'\xef\xbb\xbf')
323
+ # 设置输出编码
324
+ ps_command = f'$OutputEncoding = [Console]::OutputEncoding = [Text.Encoding]::UTF8\n{command}'
325
+ temp_file.write(ps_command.encode('utf-8'))
326
+ temp_file.close()
327
+ # Execute the temp script with PowerShell
328
+ command = f'powershell.exe -NoProfile -NonInteractive -ExecutionPolicy Bypass -File "{temp_file.name}"'
329
+ elif is_running_in_cmd():
330
+ # Create temp batch script with UTF-8
331
+ temp_file = tempfile.NamedTemporaryFile(
332
+ mode='wb',
333
+ suffix='.cmd',
334
+ delete=False
335
+ )
336
+ # 添加 UTF-8 BOM
337
+ temp_file.write(b'\xef\xbb\xbf')
338
+ # 写入命令内容,确保UTF-8输出
339
+ content = f"""@echo off
340
+ chcp 65001 > nul
341
+ set PYTHONIOENCODING=utf-8
342
+ {command}
343
+ """
344
+ temp_file.write(content.encode('utf-8'))
345
+ temp_file.close()
346
+ # Execute the temp batch script
347
+ command = f'cmd.exe /c "{temp_file.name}"'
348
+ else:
349
+ # Create temp shell script for Unix-like systems
350
+ temp_file = tempfile.NamedTemporaryFile(
351
+ mode='w',
352
+ suffix='.sh',
353
+ encoding='utf-8',
354
+ delete=False
355
+ )
356
+ temp_file.write('#!/bin/bash\n' + command)
357
+ temp_file.close()
358
+ # Make the script executable
359
+ os.chmod(temp_file.name, 0o755)
360
+ command = temp_file.name
148
361
 
149
- # Start subprocess
362
+ # Start subprocess with UTF-8 encoding
363
+ startupinfo = None
364
+ if sys.platform == 'win32':
365
+ startupinfo = subprocess.STARTUPINFO()
366
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
367
+
368
+ # 创建子进程时设置环境变量
369
+ env = os.environ.copy()
370
+ env['PYTHONIOENCODING'] = 'utf-8'
371
+
150
372
  process = subprocess.Popen(
151
373
  command,
152
374
  stdout=subprocess.PIPE,
153
375
  stderr=subprocess.PIPE,
154
- shell=True
376
+ shell=True,
377
+ encoding='utf-8', # 直接指定 UTF-8 编码
378
+ errors='replace', # 处理无法解码的字符
379
+ env=env, # 传递修改后的环境变量
380
+ startupinfo=startupinfo
155
381
  )
156
382
 
157
- # Safe decoding helper
383
+ # Safe decoding helper (for binary output)
158
384
  def safe_decode(byte_stream, encoding):
385
+ if isinstance(byte_stream, str):
386
+ return byte_stream.strip()
159
387
  try:
160
- return byte_stream.decode(encoding).strip()
388
+ # 首先尝试 UTF-8
389
+ return byte_stream.decode('utf-8').strip()
161
390
  except UnicodeDecodeError:
162
- return byte_stream.decode(encoding, errors='replace').strip()
391
+ try:
392
+ # 如果失败,尝试 GBK
393
+ return byte_stream.decode('gbk').strip()
394
+ except UnicodeDecodeError:
395
+ # 最后使用替换模式
396
+ return byte_stream.decode(encoding, errors='replace').strip()
163
397
 
164
398
  output = []
165
399
  with Live(console=console, refresh_per_second=4) as live:
@@ -238,4 +472,11 @@ def execute_shell_command(command: str):
238
472
  })
239
473
  console.print(
240
474
  f"[bold red]Unexpected error:[/bold red] [yellow]{str(e)}[/yellow]"
241
- )
475
+ )
476
+ finally:
477
+ # Clean up temp file
478
+ if temp_file and os.path.exists(temp_file.name):
479
+ try:
480
+ os.unlink(temp_file.name)
481
+ except Exception:
482
+ pass
@@ -0,0 +1,126 @@
1
+
2
+
3
+
4
+
5
+
6
+
7
+
8
+
9
+ from rich.console import Console
10
+ from rich.panel import Panel
11
+ from rich.columns import Columns
12
+ from rich.text import Text
13
+ import math
14
+
15
class StatsPanel:
    """Render a compact rich panel summarising one code-generation run.

    The panel shows model name(s), duration, sampling count, input/output
    token counts, input/output cost and generation speed, each next to a
    small colour-coded bar.
    """

    # Denominator for the token bars and the token percentage display.
    TOKEN_SCALE = 65536

    def __init__(self, console: Console = None):
        """Create the renderer.

        Args:
            console: rich Console to print to; a new Console is created when
                none (or a falsy value) is given.
        """
        self.console = console if console else Console()

    @staticmethod
    def _ratio(value: float, max_value: float) -> float:
        """Return ``value / max_value`` clamped to [0, 1].

        A non-positive ``max_value`` yields 0.0 instead of raising
        ZeroDivisionError. A NaN ratio also collapses to 0.0, because
        ``max(0.0, nan)`` returns 0.0.
        """
        if max_value <= 0:
            return 0.0
        return max(0.0, min(value / max_value, 1.0))

    def _format_speed_bar(self, speed: float) -> Text:
        """Build a speed bar: red below 30 tokens/s, yellow below 60, green otherwise.

        The bar has one cell per token/s, capped at 100 cells, followed by the
        numeric speed and its level label.
        """
        if speed < 30:
            color = "red"
            level = "低"
        elif 30 <= speed < 60:
            color = "yellow"
            level = "中"
        else:
            color = "green"
            level = "高"

        bar_length = min(int(speed), 100)
        bar = Text("▮" * bar_length, style=color)
        bar.append(f" {speed:.1f} tokens/s ({level})", style="bold white")
        return bar

    def _format_progress_bar(self, value: int, max_value: int, label: str, color: str) -> Text:
        """Build a 20-cell bar for ``value / max_value``, followed by the value and a label."""
        bar_length = int(self._ratio(value, max_value) * 20)
        bar = Text("▮" * bar_length, style=color)
        bar.append(f" {value} ({label})", style="bold white")
        return bar

    def generate(
        self,
        model_names: str,
        duration: float,
        sampling_count: int,
        input_tokens: int,
        output_tokens: int,
        input_cost: float,
        output_cost: float,
        speed: float,
    ) -> None:
        """Print the statistics panel to ``self.console``.

        Args:
            model_names: model name(s) shown verbatim in the panel.
            duration: elapsed time in seconds (the bar saturates at 60 s).
            sampling_count: sampling count (the bar saturates at 100).
            input_tokens: input token count; the bar and percentage are
                relative to ``TOKEN_SCALE``.
            output_tokens: output token count, scaled like ``input_tokens``.
            input_cost: input cost (the bar saturates at 1.0).
            output_cost: output cost (the bar saturates at 1.0).
            speed: generation speed in tokens/s.
        """
        token_scale = self.TOKEN_SCALE

        # Title: icon, caption, speed and total cost.
        title = Text.assemble(
            "📊 ", ("代码生成统计", "bold cyan underline"),
            " │ ⚡", (f"{speed:.1f}t/s ", "bold green"),
            "│ 💰", (f"${input_cost + output_cost:.4f}", "bold yellow")
        )

        # Duration colour: green below 15 s, yellow below 30 s, red otherwise.
        duration_color = "green"
        if 15 <= duration < 30:
            duration_color = "yellow"
        elif duration >= 30:
            duration_color = "red"

        def get_cost_color(cost: float) -> str:
            """Green below 0.5, yellow below 1, red otherwise."""
            if cost < 0.5:
                return "green"
            elif 0.5 <= cost < 1:
                return "yellow"
            else:
                return "red"

        # Two side-by-side panels: run metrics on the left, cost/speed on the right.
        grid = [
            Panel(
                Text.assemble(
                    ("🤖 模型: ", "bold"), model_names + "\n",
                    self._format_mini_progress(duration, 60.0, duration_color),
                    (" ⏱", duration_color), f" {duration:.1f}s │ ",
                    self._format_mini_progress(sampling_count, 100, "blue"),
                    (" 🔢", "blue"), f" {sampling_count}\n",
                    ("📥", "green"), " ",
                    self._format_mini_progress(input_tokens, token_scale, "green"),
                    f" {input_tokens} ({input_tokens / token_scale * 100:.2f}%) │ ",
                    ("📤", "bright_green"), " ",
                    self._format_mini_progress(output_tokens, token_scale, "bright_green"),
                    f" {output_tokens} ({output_tokens / token_scale * 100:.2f}%)"
                ),
                border_style="cyan",
                padding=(0, 2)
            ),
            Panel(
                Text.assemble(
                    ("💵 成本: ", "bold"),
                    self._format_mini_progress(input_cost, 1.0, get_cost_color(input_cost)),
                    (" IN", get_cost_color(input_cost)), f" {input_cost:.3f}\n",
                    ("💸 ", "bold"),
                    self._format_mini_progress(output_cost, 1.0, get_cost_color(output_cost)),
                    (" OUT", get_cost_color(output_cost)), f" {output_cost:.3f}\n",
                    self._format_speed_bar(speed)
                ),
                border_style="yellow",
                padding=(0, 1)
            )
        ]

        main_panel = Panel(
            Columns(grid, equal=True, expand=True),
            title=title,
            border_style="bright_blue",
            padding=(1, 2)
        )

        self.console.print(main_panel)

    def _format_mini_progress(self, value: float, max_value: float, color: str, width: int = 10) -> Text:
        """Build a compact ``width``-cell bar (filled ▮ then empty ▯) for ``value / max_value``.

        Accepts floats. A non-positive ``max_value`` or a NaN ratio renders an
        empty bar instead of raising.
        """
        filled = "▮" * int(self._ratio(value, max_value) * width)
        empty = "▯" * (width - len(filled))
        return Text(filled + empty, style=color)
@@ -86,7 +86,7 @@ class ActionTSProject(BaseAction):
86
86
  max_iter=self.args.image_max_iter,
87
87
  )
88
88
  html_code = ""
89
- with open(html_path, "r") as f:
89
+ with open(html_path, "r",encoding="utf-8") as f:
90
90
  html_code = f.read()
91
91
 
92
92
  source_code_list.sources.append(SourceCode(
@@ -190,9 +190,7 @@ class ActionTSProject(BaseAction):
190
190
  conversations=generate_result.conversations[0],
191
191
  model=self.llm.default_model_name,
192
192
  )
193
-
194
- with open(args.target_file, "w") as file:
195
- file.write(content)
193
+
196
194
 
197
195
 
198
196
  class ActionPyScriptProject(BaseAction):
@@ -300,11 +298,7 @@ class ActionPyScriptProject(BaseAction):
300
298
  instruction=self.args.query,
301
299
  conversations=generate_result.conversations[0],
302
300
  model=self.llm.default_model_name,
303
- )
304
-
305
- end_time = time.time()
306
- with open(self.args.target_file, "w") as file:
307
- file.write(content)
301
+ )
308
302
 
309
303
 
310
304
  class ActionPyProject(BaseAction):
@@ -435,9 +429,7 @@ class ActionPyProject(BaseAction):
435
429
  instruction=self.args.query,
436
430
  conversations=generate_result.conversations[0],
437
431
  model=self.llm.default_model_name,
438
- )
439
- with open(args.target_file, "w") as file:
440
- file.write(content)
432
+ )
441
433
 
442
434
 
443
435
  class ActionSuffixProject(BaseAction):
@@ -551,9 +543,7 @@ class ActionSuffixProject(BaseAction):
551
543
  instruction=self.args.query,
552
544
  conversations=merge_result.conversations[0],
553
545
  model=self.llm.default_model_name,
554
- )
555
- with open(args.target_file, "w") as file:
556
- file.write(content)
546
+ )
557
547
  else:
558
548
  content = generate_result.contents[0]
559
549
 
@@ -563,7 +553,5 @@ class ActionSuffixProject(BaseAction):
563
553
  conversations=generate_result.conversations[0],
564
554
  model=self.llm.default_model_name,
565
555
  )
566
-
567
- with open(args.target_file, "w") as file:
568
- file.write(content)
556
+
569
557
 
@@ -343,7 +343,7 @@ class ActionCopilot:
343
343
  logger.info(
344
344
  "model is not specified and we will generate prompt to the target file"
345
345
  )
346
- with open(args.target_file, "w") as f:
346
+ with open(args.target_file, "w",encoding="utf-8") as f:
347
347
  f.write(q)
348
348
  return True
349
349
 
@@ -379,7 +379,7 @@ class ActionCopilot:
379
379
  logger.info(result)
380
380
 
381
381
  # 将结果写入文件
382
- with open(args.target_file, "w") as f:
382
+ with open(args.target_file, "w",encoding="utf-8") as f:
383
383
  f.write("=================CONVERSATION==================\n\n")
384
384
  for conversation in conversations:
385
385
  f.write(f"{conversation['role']}: {conversation['content']}\n")
@@ -146,6 +146,4 @@ class ActionRegexProject:
146
146
  conversations=generate_result.conversations[0],
147
147
  model=self.llm.default_model_name,
148
148
  )
149
-
150
- with open(args.target_file, "w") as file:
151
- file.write(content)
149
+
@@ -209,6 +209,6 @@ class ActionTranslate:
209
209
  new_filename = f"{filename}{new_file_mark}{extension}"
210
210
 
211
211
  logger.info(f"Writing to {new_filename}...")
212
- with open(new_filename, "w") as file:
212
+ with open(new_filename, "w",encoding="utf-8") as file:
213
213
  file.write(readme.content)
214
214
  return True
autocoder/index/entry.py CHANGED
@@ -24,6 +24,7 @@ from autocoder.index.filter.normal_filter import NormalFilter
24
24
  from autocoder.index.index import IndexManager
25
25
  from loguru import logger
26
26
  from autocoder.common import SourceCodeList
27
+ from autocoder.common.context_pruner import PruneContext
27
28
 
28
29
  def build_index_and_filter_files(
29
30
  llm, args: AutoCoderArgs, sources: List[SourceCode]
@@ -113,8 +114,13 @@ def build_index_and_filter_files(
113
114
  raise KeyboardInterrupt(printer.get_message_from_key_with_format("quick_filter_failed",error=quick_filter_result.error_message))
114
115
 
115
116
  # Merge quick filter results into final_files
116
- final_files.update(quick_filter_result.files)
117
-
117
+ if args.context_prune:
118
+ context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
119
+ pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
120
+ for source_file in pruned_files:
121
+ final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
122
+ else:
123
+ final_files.update(quick_filter_result.files)
118
124
 
119
125
  if not args.skip_filter_index and not args.index_filter_model:
120
126
  model_name = getattr(index_manager.llm, 'default_model_name', None)
@@ -1,4 +1,6 @@
1
- from typing import List, Union,Dict,Any
1
+ from typing import List, Union,Dict,Any,Optional
2
+
3
+ from pydantic import BaseModel
2
4
  from autocoder.index.types import IndexItem
3
5
  from autocoder.common import SourceCode, AutoCoderArgs
4
6
  import byzerllm
@@ -25,6 +27,11 @@ def get_file_path(file_path):
25
27
  return file_path.strip()[2:]
26
28
  return file_path
27
29
 
30
class NormalFilterResult(BaseModel):
    """Result returned by NormalFilter's index-based file filtering."""

    # Selected files, keyed by a string file identifier.
    files: Dict[str, TargetFile]
    # Whether the filtering step reported an error.
    has_error: bool
    # Optional human-readable error description.
    error_message: Optional[str] = None
    # Optional per-file integer positions (presumably ranking order — TODO
    # confirm against producers). Defaults to None: under pydantic v2 an
    # `Optional[...]` annotation without a default is a *required* field, and
    # NormalFilter constructs this model without passing file_positions.
    file_positions: Optional[Dict[str, int]] = None
28
35
 
29
36
  class NormalFilter():
30
37
  def __init__(self, index_manager: IndexManager,stats:Dict[str,Any],sources:List[SourceCode]):
@@ -167,4 +174,8 @@ class NormalFilter():
167
174
  # Keep all files, not just verified ones
168
175
  final_files = verified_files
169
176
 
170
- return final_files
177
+ return NormalFilterResult(
178
+ files=final_files,
179
+ has_error=False,
180
+ error_message=None
181
+ )