auto-coder 0.1.268__py3-none-any.whl → 0.1.269__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

@@ -21,6 +21,8 @@ from loguru import logger
21
21
  from autocoder.utils import llms as llms_utils
22
22
  from autocoder.rag.token_counter import count_tokens
23
23
  from autocoder.common.global_cancel import global_cancel
24
+ from autocoder.common.auto_configure import config_readme
25
+ from autocoder.utils.auto_project_type import ProjectTypeAnalyzer
24
26
 
25
27
  class CommandMessage(BaseModel):
26
28
  role: str
@@ -155,7 +157,8 @@ class CommandAutoTuner:
155
157
  self.printer = Printer()
156
158
  self.memory_config = memory_config
157
159
  self.command_config = command_config
158
- self.tools = AutoCommandTools(args=args, llm=self.llm)
160
+ self.tools = AutoCommandTools(args=args, llm=self.llm)
161
+ self.project_type_analyzer = ProjectTypeAnalyzer(args=args, llm=self.llm)
159
162
 
160
163
  def get_conversations(self) -> List[CommandMessage]:
161
164
  """Get conversation history from memory file"""
@@ -440,7 +443,7 @@ class CommandAutoTuner:
440
443
  safe_zone=self.args.conversation_prune_safe_zone_tokens
441
444
  )
442
445
  from autocoder.common.conversation_pruner import ConversationPruner
443
- pruner = ConversationPruner(self.llm)
446
+ pruner = ConversationPruner(self.args, self.llm)
444
447
  conversations = pruner.prune_conversations(conversations)
445
448
 
446
449
  title = printer.get_message_from_key("auto_command_analyzing")
@@ -646,34 +649,7 @@ class CommandAutoTuner:
646
649
 
647
650
  常见的一些配置选项示例:
648
651
 
649
- # 配置项说明
650
- ## auto_merge: 代码合并方式,可选值为editblock、diff、wholefile.
651
- - editblock: 生成 SEARCH/REPLACE 块,然后根据 SEARCH块到对应的源码查找,如果相似度阈值大于 editblock_similarity, 那么则将
652
- 找到的代码块替换为 REPLACE 块。大部分情况都推荐使用 editblock。
653
- - wholefile: 重新生成整个文件,然后替换原来的文件。对于重构场景,推荐使用 wholefile。
654
- - diff: 生成标准 git diff 格式,适用于简单的代码修改。
655
-
656
- ## editblock_similarity: editblock相似度阈值
657
- - editblock相似度阈值,取值范围为0-1,默认值为0.9。如果设置的太低,虽然能合并进去,但是会引入错误。推荐不要修改该值。
658
-
659
- ## generate_times_same_model: 相同模型生成次数,也叫采样数
660
- 当进行生成代码时,大模型会对同一个需求生成多份代码,然后会使用 generate_rerank_model 模型对多份代码进行重排序,
661
- 然后选择得分最高的代码。一般次数越多,最终得到正确的代码概率越高。默认值为1,推荐设置为3。但是设置值越多,可能速度就越慢,消耗的token也越多。
662
- 当用户提到,帮我采样数设置为3, 那么你就设置该参数即可。
663
-
664
- ## skip_filter_index: 是否跳过索引过滤
665
- 是否跳过根据用户的query 自动查找上下文。推荐设置为 false
666
-
667
- ## skip_build_index: 是否跳过索引构建
668
- 是否自动构建索引。推荐设置为 false。注意,如果该值设置为 true, 那么 skip_filter_index 设置不会生效。
669
-
670
- ## enable_global_memory: 是否开启全局记忆
671
- 是否开启全局记忆。
672
-
673
- ## rank_times_same_model: 相同模型重排序次数
674
- 默认值为1. 如果 generate_times_same_model 参数设置大于1,那么 coding 函数会自动对多份代码进行重排序。
675
- rank_times_same_model 表示重拍的次数,次数越多,选择到最好的代码的可能性越高,但是也会显著增加消耗的token和时间。
676
- 建议保持默认,要修改也建议不要超过3。
652
+ {{ config_readme }}
677
653
 
678
654
  比如你想开启索引,则可以执行:
679
655
 
@@ -1190,10 +1166,26 @@ class CommandAutoTuner:
1190
1166
  exclude_files(query="/drop regex://.*/package-lock\.json")
1191
1167
  </usage>
1192
1168
  </command>
1169
+
1170
+ <command>
1171
+ <name>get_project_type</name>
1172
+ <description>获取项目类型。</description>
1173
+ <usage>
1174
+ 该命令获取项目类型。
1175
+
1176
+ 使用例子:
1177
+ get_project_type()
1178
+
1179
+ 此时会返回诸如 "ts,py,java,go,js,ts" 这样的字符串,表示项目类型。
1180
+ </usage>
1181
+ </command>
1193
1182
  </commands>
1194
1183
 
1195
1184
 
1196
1185
  '''
1186
+ return {
1187
+ "config_readme": config_readme.prompt()
1188
+ }
1197
1189
 
1198
1190
  def execute_auto_command(self, command: str, parameters: Dict[str, Any]) -> None:
1199
1191
  """
@@ -1232,9 +1224,7 @@ class CommandAutoTuner:
1232
1224
  "get_project_related_files": self.tools.get_project_related_files,
1233
1225
  "ask_user":self.tools.ask_user,
1234
1226
  "read_file_with_keyword_ranges": self.tools.read_file_with_keyword_ranges,
1235
-
1236
-
1237
-
1227
+ "get_project_type": self.project_type_analyzer.analyze,
1238
1228
  }
1239
1229
 
1240
1230
  if command not in command_map:
@@ -376,12 +376,16 @@ class AutoCoderArgs(pydantic.BaseModel):
376
376
  conversation_prune_group_size: Optional[int] = 4
377
377
  conversation_prune_strategy: Optional[str] = "summarize"
378
378
 
379
- context_prune_strategy: Optional[str] = "score"
379
+ context_prune_strategy: Optional[str] = "extract"
380
380
  context_prune: Optional[bool] = True
381
+ context_prune_sliding_window_size: Optional[int] = 1000
382
+ context_prune_sliding_window_overlap: Optional[int] = 100
381
383
 
382
384
  auto_command_max_iterations: Optional[int] = 10
383
385
 
384
- skip_commit: Optional[bool] = False
386
+ skip_commit: Optional[bool] = False
387
+
388
+ enable_beta: Optional[bool] = False
385
389
 
386
390
  class Config:
387
391
  protected_namespaces = ()
@@ -161,9 +161,13 @@ MESSAGES = {
161
161
  "index_import_success": "Index imported successfully: {{path}}",
162
162
  "edits_title": "edits",
163
163
  "diff_blocks_title":"diff blocks",
164
- "index_exclude_files_error": "index filter exclude files fail: {{ error }}"
164
+ "index_exclude_files_error": "index filter exclude files fail: {{ error }}",
165
+ "file_sliding_window_processing": "File {{ file_path }} is too large ({{ tokens }} tokens), processing with sliding window...",
166
+ "file_snippet_processing": "Processing file {{ file_path }} with code snippet extraction..."
165
167
  },
166
168
  "zh": {
169
+ "file_sliding_window_processing": "文件 {{ file_path }} 过大 ({{ tokens }} tokens),正在使用滑动窗口处理...",
170
+ "file_snippet_processing": "正在对文件 {{ file_path }} 进行代码片段提取...",
167
171
  "file_scored_message": "文件评分: {{file_path}} - 分数: {{score}}",
168
172
  "invalid_file_pattern": "无效的文件模式: {{file_pattern}}. 例如: regex://.*/package-lock\\.json",
169
173
  "conf_not_found": "未找到配置文件: {{path}}",
@@ -119,7 +119,45 @@ class AutoConfigRequest(BaseModel):
119
119
 
120
120
  class AutoConfigResponse(BaseModel):
121
121
  configs: List[Dict[str, Any]] = Field(default_factory=list)
122
- reasoning: str = ""
122
+ reasoning: str = ""
123
+
124
+
125
+ @byzerllm.prompt()
126
+ def config_readme() -> str:
127
+ """
128
+ # 配置项说明
129
+ ## auto_merge: 代码合并方式,可选值为editblock、diff、wholefile.
130
+ - editblock: 生成 SEARCH/REPLACE 块,然后根据 SEARCH块到对应的源码查找,如果相似度阈值大于 editblock_similarity, 那么则将
131
+ 找到的代码块替换为 REPLACE 块。大部分情况都推荐使用 editblock。
132
+ - wholefile: 重新生成整个文件,然后替换原来的文件。对于重构场景,推荐使用 wholefile。
133
+ - diff: 生成标准 git diff 格式,适用于简单的代码修改。
134
+
135
+ ## editblock_similarity: editblock相似度阈值
136
+ - editblock相似度阈值,取值范围为0-1,默认值为0.9。如果设置的太低,虽然能合并进去,但是会引入错误。推荐不要修改该值。
137
+
138
+ ## generate_times_same_model: 相同模型生成次数
139
+ 当进行生成代码时,大模型会对同一个需求生成多份代码,然后会使用 generate_rerank_model 模型对多份代码进行重排序,
140
+ 然后选择得分最高的代码。一般次数越多,最终得到正确的代码概率越高。默认值为1,推荐设置为3。但是设置值越多,可能速度就越慢,消耗的token也越多。
141
+
142
+ ## skip_filter_index: 是否跳过索引过滤
143
+ 是否跳过根据用户的query 自动查找上下文。推荐设置为 false
144
+
145
+ ## skip_build_index: 是否跳过索引构建
146
+ 是否自动构建索引。推荐设置为 false。注意,如果该值设置为 true, 那么 skip_filter_index 设置不会生效。
147
+
148
+ ## rank_times_same_model: 相同模型重排序次数
149
+ 默认值为1. 如果 generate_times_same_model 参数设置大于1,那么 coding 函数会自动对多份代码进行重排序。
150
+ rank_times_same_model 表示重拍的次数,次数越多,选择到最好的代码的可能性越高,但是也会显著增加消耗的token和时间。
151
+ 建议保持默认,要修改也建议不要超过3。
152
+
153
+ ## project_type: 项目类型
154
+ 项目类型通常为如下三种选择:
155
+ 1. ts
156
+ 2. py
157
+ 3. 代码文件后缀名列表(比如.java,.py,.go,.js,.ts),多个按逗号分割
158
+
159
+ 推荐使用 3 选项,因为项目类型通常为多种后缀名混合。
160
+ """
123
161
 
124
162
  class ConfigAutoTuner:
125
163
  def __init__(self,args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM], memory_config: MemoryConfig):
@@ -135,34 +173,7 @@ class ConfigAutoTuner:
135
173
  self.memory_config.configure(conf, skip_print)
136
174
 
137
175
 
138
- @byzerllm.prompt()
139
- def config_readme(self) -> str:
140
- """
141
- # 配置项说明
142
- ## auto_merge: 代码合并方式,可选值为editblock、diff、wholefile.
143
- - editblock: 生成 SEARCH/REPLACE 块,然后根据 SEARCH块到对应的源码查找,如果相似度阈值大于 editblock_similarity, 那么则将
144
- 找到的代码块替换为 REPLACE 块。大部分情况都推荐使用 editblock。
145
- - wholefile: 重新生成整个文件,然后替换原来的文件。对于重构场景,推荐使用 wholefile。
146
- - diff: 生成标准 git diff 格式,适用于简单的代码修改。
147
-
148
- ## editblock_similarity: editblock相似度阈值
149
- - editblock相似度阈值,取值范围为0-1,默认值为0.9。如果设置的太低,虽然能合并进去,但是会引入错误。推荐不要修改该值。
150
-
151
- ## generate_times_same_model: 相同模型生成次数
152
- 当进行生成代码时,大模型会对同一个需求生成多份代码,然后会使用 generate_rerank_model 模型对多份代码进行重排序,
153
- 然后选择得分最高的代码。一般次数越多,最终得到正确的代码概率越高。默认值为1,推荐设置为3。但是设置值越多,可能速度就越慢,消耗的token也越多。
154
-
155
- ## skip_filter_index: 是否跳过索引过滤
156
- 是否跳过根据用户的query 自动查找上下文。推荐设置为 false
157
-
158
- ## skip_build_index: 是否跳过索引构建
159
- 是否自动构建索引。推荐设置为 false。注意,如果该值设置为 true, 那么 skip_filter_index 设置不会生效。
160
-
161
- ## rank_times_same_model: 相同模型重排序次数
162
- 默认值为1. 如果 generate_times_same_model 参数设置大于1,那么 coding 函数会自动对多份代码进行重排序。
163
- rank_times_same_model 表示重拍的次数,次数越多,选择到最好的代码的可能性越高,但是也会显著增加消耗的token和时间。
164
- 建议保持默认,要修改也建议不要超过3。
165
- """
176
+
166
177
 
167
178
  def command_readme(self) -> str:
168
179
  """
@@ -212,7 +223,7 @@ class ConfigAutoTuner:
212
223
  "query": request.query,
213
224
  "current_conf": json.dumps(self.memory_config.memory["conf"], indent=2),
214
225
  "last_execution_stat": "",
215
- "config_readme": self.config_readme.prompt()
226
+ "config_readme": config_readme.prompt()
216
227
  }
217
228
 
218
229
  def tune(self, request: AutoConfigRequest) -> 'AutoConfigResponse':
@@ -174,9 +174,8 @@ def base_base(source_dir:str,project_type:str)->str:
174
174
  source_dir: {{ source_dir }}
175
175
  target_file: {{ target_file }}
176
176
 
177
- model: v3_chat
178
- model_max_input_length: 100000
179
- model_max_input_length: 120000
177
+ model: v3_chat
178
+ model_max_input_length: 60000
180
179
  enable_multi_round_generate: false
181
180
  index_filter_workers: 100
182
181
  index_build_workers: 100
@@ -1,4 +1,5 @@
1
1
  from typing import List, Dict, Any, Union
2
+ from typing import Tuple
2
3
  from pathlib import Path
3
4
  import json
4
5
  from loguru import logger
@@ -19,6 +20,100 @@ class PruneContext:
19
20
  self.llm = llm
20
21
  self.printer = Printer()
21
22
 
23
+ def _split_content_with_sliding_window(self, content: str, window_size=100, overlap=20) -> List[Tuple[int, int, str]]:
24
+ """使用滑动窗口分割大文件内容,返回包含行号信息的文本块
25
+
26
+ Args:
27
+ content: 要分割的文件内容
28
+ window_size: 每个窗口包含的行数
29
+ overlap: 相邻窗口的重叠行数
30
+
31
+ Returns:
32
+ List[Tuple[int, int, str]]: 返回元组列表,每个元组包含:
33
+ - 起始行号(从1开始),在原始文件的绝对行号
34
+ - 结束行号,在原始文件的绝对行号
35
+ - 带行号的内容文本
36
+ """
37
+ # 按行分割内容
38
+ lines = content.splitlines()
39
+ chunks = []
40
+ start = 0
41
+
42
+ while start < len(lines):
43
+ # 计算当前窗口的结束位置
44
+ end = min(start + window_size, len(lines))
45
+
46
+ # 计算实际的起始位置(考虑重叠)
47
+ actual_start = max(0, start - overlap)
48
+
49
+ # 提取当前窗口的行
50
+ chunk_lines = lines[actual_start:end]
51
+
52
+ # 为每一行添加行号
53
+ # 行号从actual_start+1开始,保持与原文件的绝对行号一致
54
+ chunk_content = "\n".join([
55
+ f"{i+1} {line}" for i, line in enumerate(chunk_lines, start=actual_start)
56
+ ])
57
+
58
+ # 保存分块信息:(起始行号, 结束行号, 带行号的内容)
59
+ # 行号从1开始计数
60
+ chunks.append((actual_start + 1, end, chunk_content))
61
+
62
+ # 移动到下一个窗口的起始位置
63
+ # 减去overlap确保窗口重叠
64
+ start += (window_size - overlap)
65
+
66
+ return chunks
67
+
68
+ def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
69
+ """合并重叠或相邻的代码片段
70
+
71
+ Args:
72
+ snippets: 代码片段列表,每个片段是包含start_line和end_line的字典
73
+
74
+ Returns:
75
+ List[dict]: 合并后的代码片段列表
76
+
77
+ 示例:
78
+ 输入: [
79
+ {"start_line": 1, "end_line": 5},
80
+ {"start_line": 4, "end_line": 8},
81
+ {"start_line": 10, "end_line": 12}
82
+ ]
83
+ 输出: [
84
+ {"start_line": 1, "end_line": 8},
85
+ {"start_line": 10, "end_line": 12}
86
+ ]
87
+ """
88
+ if not snippets:
89
+ return []
90
+
91
+ # 按起始行排序
92
+ sorted_snippets = sorted(snippets, key=lambda x: x["start_line"])
93
+
94
+ merged = [sorted_snippets[0]]
95
+
96
+ for current in sorted_snippets[1:]:
97
+ last = merged[-1]
98
+
99
+ # 判断是否需要合并:
100
+ # 1. 如果当前片段的起始行小于等于上一个片段的结束行+1
101
+ # 2. +1是为了合并相邻的片段,比如1-5和6-8应该合并为1-8
102
+ if current["start_line"] <= last["end_line"] + 1:
103
+ # 合并区间:
104
+ # - 起始行取两者最小值
105
+ # - 结束行取两者最大值
106
+ merged[-1] = {
107
+ "start_line": min(last["start_line"], current["start_line"]),
108
+ "end_line": max(last["end_line"], current["end_line"])
109
+ }
110
+ else:
111
+ # 如果不重叠且不相邻,则作为新片段添加
112
+ merged.append(current)
113
+
114
+ return merged
115
+
116
+
22
117
  def _delete_overflow_files(self, file_paths: List[str]) -> List[SourceCode]:
23
118
  """直接删除超出 token 限制的文件"""
24
119
  total_tokens = 0
@@ -40,6 +135,8 @@ class PruneContext:
40
135
  selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=token_count))
41
136
 
42
137
  return selected_files
138
+
139
+
43
140
 
44
141
  def _extract_code_snippets(self, file_paths: List[str], conversations: List[Dict[str, str]]) -> List[SourceCode]:
45
142
  """抽取关键代码片段策略"""
@@ -48,7 +145,7 @@ class PruneContext:
48
145
  full_file_tokens = int(self.max_tokens * 0.8)
49
146
 
50
147
  @byzerllm.prompt()
51
- def extract_code_snippets(conversations: List[Dict[str, str]], content: str) -> str:
148
+ def extract_code_snippets(conversations: List[Dict[str, str]], content: str, is_partial_content: bool = False) -> str:
52
149
  """
53
150
  根据提供的代码文件和对话历史提取相关代码片段。
54
151
 
@@ -111,6 +208,13 @@ class PruneContext:
111
208
  {{ content }}
112
209
  </code_file>
113
210
 
211
+ <% if is_partial_content: %>
212
+ <partial_content_process_note>
213
+ 当前处理的是文件的局部内容(行号{start_line}-{end_line}),
214
+ 请仅基于当前可见内容判断相关性,返回标注的行号区间。
215
+ </partial_content_process_note>
216
+ <% endif %>
217
+
114
218
  2. 对话历史:
115
219
  <conversation_history>
116
220
  {% for msg in conversations %}
@@ -131,15 +235,17 @@ class PruneContext:
131
235
  4. 如果没有相关代码段,返回空数组[]。
132
236
 
133
237
  输出格式:
134
- 严格的JSON数组,不包含其他文字或解释。
135
-
238
+ 严格的JSON数组,不包含其他文字或解释。
239
+
136
240
  ```json
137
241
  [
138
242
  {"start_line": 第一个代码段的起始行号, "end_line": 第一个代码段的结束行号},
139
243
  {"start_line": 第二个代码段的起始行号, "end_line": 第二个代码段的结束行号}
140
244
  ]
141
- ```
245
+ ```
246
+
142
247
  """
248
+
143
249
 
144
250
  for file_path in file_paths:
145
251
  try:
@@ -152,21 +258,67 @@ class PruneContext:
152
258
  selected_files.append(SourceCode(module_name=file_path,source_code=content,tokens=tokens))
153
259
  token_count += tokens
154
260
  continue
155
-
261
+
262
+ ## 如果单个文件太大,那么先按滑动窗口分割,然后对窗口抽取代码片段
263
+ if tokens > self.max_tokens:
264
+ self.printer.print_in_terminal("file_sliding_window_processing", file_path=file_path, tokens=tokens)
265
+ chunks = self._split_content_with_sliding_window(content,
266
+ self.args.context_prune_sliding_window_size,
267
+ self.args.context_prune_sliding_window_overlap)
268
+ all_snippets = []
269
+ for chunk_start, chunk_end, chunk_content in chunks:
270
+ extracted = extract_code_snippets.with_llm(self.llm).run(
271
+ conversations=conversations,
272
+ content=chunk_content,
273
+ is_partial_content=True
274
+ )
275
+ if extracted:
276
+ json_str = extract_code(extracted)[0][1]
277
+ snippets = json.loads(json_str)
278
+
279
+ # 获取到的本来就是在原始文件里的绝对行号
280
+ # 后续在构建代码片段内容时,会为了适配数组操作修改行号,这里无需处理
281
+ adjusted_snippets = [{
282
+ "start_line": snippet["start_line"],
283
+ "end_line": snippet["end_line"]
284
+ } for snippet in snippets]
285
+ all_snippets.extend(adjusted_snippets)
286
+ merged_snippets = self._merge_overlapping_snippets(all_snippets)
287
+ content_snippets = self._build_snippet_content(file_path, content, merged_snippets)
288
+ snippet_tokens = count_tokens(content_snippets)
289
+ if token_count + snippet_tokens <= self.max_tokens:
290
+ selected_files.append(SourceCode(module_name=file_path,source_code=content_snippets,tokens=snippet_tokens))
291
+ token_count += snippet_tokens
292
+ continue
293
+ else:
294
+ break
295
+
156
296
  # 抽取关键片段
297
+ lines = content.splitlines()
298
+ new_content = ""
299
+
300
+ ## 将文件内容按行编号
301
+ for index,line in enumerate(lines):
302
+ new_content += f"{index+1} {line}\n"
303
+
304
+ ## 抽取代码片段
305
+ self.printer.print_in_terminal("file_snippet_processing", file_path=file_path)
157
306
  extracted = extract_code_snippets.with_llm(self.llm).run(
158
307
  conversations=conversations,
159
- content=content
308
+ content=new_content
160
309
  )
161
310
 
311
+ ## 构建代码片段内容
162
312
  if extracted:
163
313
  json_str = extract_code(extracted)[0][1]
164
314
  snippets = json.loads(json_str)
165
- new_content = self._build_snippet_content(file_path, content, snippets)
315
+ content_snippets = self._build_snippet_content(file_path, content, snippets)
166
316
 
167
- snippet_tokens = count_tokens(new_content)
317
+ snippet_tokens = count_tokens(content_snippets)
168
318
  if token_count + snippet_tokens <= self.max_tokens:
169
- selected_files.append(SourceCode(module_name=file_path,source_code=new_content,tokens=snippet_tokens))
319
+ selected_files.append(SourceCode(module_name=file_path,
320
+ source_code=content_snippets,
321
+ tokens=snippet_tokens))
170
322
  token_count += snippet_tokens
171
323
  else:
172
324
  break
@@ -175,10 +327,32 @@ class PruneContext:
175
327
  continue
176
328
 
177
329
  return selected_files
330
+
331
+
332
+ def _merge_overlapping_snippets(self, snippets: List[dict]) -> List[dict]:
333
+ if not snippets:
334
+ return []
335
+
336
+ # 按起始行排序
337
+ sorted_snippets = sorted(snippets, key=lambda x: x["start_line"])
338
+
339
+ merged = [sorted_snippets[0]]
340
+ for current in sorted_snippets[1:]:
341
+ last = merged[-1]
342
+ if current["start_line"] <= last["end_line"] + 1: # 允许1行间隔
343
+ # 合并区间
344
+ merged[-1] = {
345
+ "start_line": min(last["start_line"], current["start_line"]),
346
+ "end_line": max(last["end_line"], current["end_line"])
347
+ }
348
+ else:
349
+ merged.append(current)
350
+
351
+ return merged
178
352
 
179
353
  def _build_snippet_content(self, file_path: str, full_content: str, snippets: List[dict]) -> str:
180
354
  """构建包含代码片段的文件内容"""
181
- lines = full_content.split("\n")
355
+ lines = full_content.splitlines()
182
356
  header = f"Snippets:\n"
183
357
 
184
358
  content = []
@@ -214,7 +388,7 @@ class PruneContext:
214
388
  return self._extract_code_snippets(file_paths, conversations)
215
389
  else:
216
390
  raise ValueError(f"无效策略: {strategy}. 可选值: delete/extract/score")
217
-
391
+
218
392
  def _count_tokens(self, file_paths: List[str]) -> int:
219
393
  """计算文件总token数"""
220
394
  total_tokens = 0
@@ -312,4 +486,3 @@ class PruneContext:
312
486
  break
313
487
 
314
488
  return selected_files
315
-
@@ -3,8 +3,9 @@ import json
3
3
  from pydantic import BaseModel
4
4
  import byzerllm
5
5
  from autocoder.common.printer import Printer
6
- from autocoder.utils.llms import count_tokens
6
+ from autocoder.rag.token_counter import count_tokens
7
7
  from loguru import logger
8
+ from autocoder.common import AutoCoderArgs
8
9
 
9
10
  class PruneStrategy(BaseModel):
10
11
  name: str
@@ -12,25 +13,25 @@ class PruneStrategy(BaseModel):
12
13
  config: Dict[str, Any] = {"safe_zone_tokens": 0, "group_size": 4}
13
14
 
14
15
  class ConversationPruner:
15
- def __init__(self, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM],
16
- safe_zone_tokens: int = 500, group_size: int = 4):
16
+ def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
17
+ self.args = args
17
18
  self.llm = llm
18
19
  self.printer = Printer()
19
20
  self.strategies = {
20
21
  "summarize": PruneStrategy(
21
22
  name="summarize",
22
23
  description="对早期对话进行分组摘要,保留关键信息",
23
- config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
24
+ config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
24
25
  ),
25
26
  "truncate": PruneStrategy(
26
27
  name="truncate",
27
28
  description="分组截断最早的部分对话",
28
- config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
29
+ config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
29
30
  ),
30
31
  "hybrid": PruneStrategy(
31
32
  name="hybrid",
32
33
  description="先尝试分组摘要,如果仍超限则分组截断",
33
- config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
34
+ config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
34
35
  )
35
36
  }
36
37
 
@@ -57,7 +58,7 @@ class ConversationPruner:
57
58
  if strategy.name == "summarize":
58
59
  return self._summarize_prune(conversations, strategy.config)
59
60
  elif strategy.name == "truncate":
60
- return self._truncate_prune.with_llm(self.llm).run(conversations)
61
+ return self._truncate_prune(conversations)
61
62
  elif strategy.name == "hybrid":
62
63
  pruned = self._summarize_prune(conversations, strategy.config)
63
64
  if count_tokens(json.dumps(pruned, ensure_ascii=False)) > self.args.conversation_prune_safe_zone_tokens:
@@ -80,8 +81,8 @@ class ConversationPruner:
80
81
  break
81
82
 
82
83
  # 找到要处理的对话组
83
- early_conversations = processed_conversations[:-group_size]
84
- recent_conversations = processed_conversations[-group_size:]
84
+ early_conversations = processed_conversations[:group_size]
85
+ recent_conversations = processed_conversations[group_size:]
85
86
 
86
87
  if not early_conversations:
87
88
  break
@@ -90,7 +91,7 @@ class ConversationPruner:
90
91
  group_summary = self._generate_summary.with_llm(self.llm).run(early_conversations[-group_size:])
91
92
 
92
93
  # 更新对话历史
93
- processed_conversations = early_conversations[:-group_size] + [
94
+ processed_conversations = [
94
95
  {"role": "user", "content": f"历史对话摘要:\n{group_summary}"},
95
96
  {"role": "assistant", "content": f"收到"}
96
97
  ] + recent_conversations