auto-coder 0.1.255-py3-none-any.whl → 0.1.257-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of auto-coder might be problematic.
Files changed (30)
  1. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/METADATA +2 -2
  2. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/RECORD +30 -27
  3. autocoder/auto_coder.py +44 -50
  4. autocoder/chat_auto_coder.py +16 -17
  5. autocoder/chat_auto_coder_lang.py +1 -1
  6. autocoder/common/__init__.py +7 -0
  7. autocoder/common/auto_coder_lang.py +46 -16
  8. autocoder/common/code_auto_generate.py +45 -5
  9. autocoder/common/code_auto_generate_diff.py +45 -7
  10. autocoder/common/code_auto_generate_editblock.py +48 -4
  11. autocoder/common/code_auto_generate_strict_diff.py +46 -7
  12. autocoder/common/code_modification_ranker.py +39 -3
  13. autocoder/dispacher/actions/action.py +60 -40
  14. autocoder/dispacher/actions/plugins/action_regex_project.py +12 -6
  15. autocoder/index/entry.py +6 -4
  16. autocoder/index/filter/quick_filter.py +175 -65
  17. autocoder/index/index.py +94 -4
  18. autocoder/models.py +44 -6
  19. autocoder/privacy/__init__.py +3 -0
  20. autocoder/privacy/model_filter.py +100 -0
  21. autocoder/pyproject/__init__.py +1 -0
  22. autocoder/suffixproject/__init__.py +1 -0
  23. autocoder/tsproject/__init__.py +1 -0
  24. autocoder/utils/llms.py +27 -0
  25. autocoder/utils/model_provider_selector.py +192 -0
  26. autocoder/version.py +1 -1
  27. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/LICENSE +0 -0
  28. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/WHEEL +0 -0
  29. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/entry_points.txt +0 -0
  30. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/top_level.txt +0 -0
autocoder/dispacher/actions/plugins/action_regex_project.py CHANGED
@@ -15,6 +15,7 @@ from autocoder.utils.conversation_store import store_code_model_conversation
 from autocoder.common.printer import Printer
 import time
 from autocoder.utils.llms import get_llm_names
+from autocoder.common import SourceCodeList
 from loguru import logger
 class ActionRegexProject:
     def __init__(
@@ -36,20 +37,21 @@ class ActionRegexProject:
         pp = RegexProject(args=args, llm=self.llm)
         self.pp = pp
         pp.run()
-        source_code = pp.output()
+        source_code_list = SourceCodeList(pp.sources)
         if self.llm:
             if args.in_code_apply:
                 old_query = args.query
                 args.query = (args.context or "") + "\n\n" + args.query
-            source_code = build_index_and_filter_files(
+            source_code_list = build_index_and_filter_files(
                 llm=self.llm, args=args, sources=pp.sources
             )
             if args.in_code_apply:
                 args.query = old_query
-        self.process_content(source_code)
+        self.process_content(source_code_list)

-    def process_content(self, content: str):
+    def process_content(self, source_code_list: SourceCodeList):
         args = self.args
+        content = source_code_list.to_str()

         if args.execute and self.llm and not args.human_as_model:
             if len(content) > self.args.model_max_input_length:
@@ -78,21 +80,25 @@ class ActionRegexProject:
         generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
         if self.args.enable_multi_round_generate:
             generate_result = generate.multi_round_run(
-                query=args.query, source_content=content
+                query=args.query, source_code_list=source_code_list
             )
         else:
             generate_result = generate.single_round_run(
-                query=args.query, source_content=content
+                query=args.query, source_code_list=source_code_list
             )

         elapsed_time = time.time() - start_time
         speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
         model_names = ",".join(get_llm_names(self.llm))
+        input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
+        generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
         self.printer.print_in_terminal(
             "code_generation_complete",
             duration=elapsed_time,
             input_tokens=generate_result.metadata.get('input_tokens_count', 0),
             output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
+            input_cost=input_tokens_cost,
+            output_cost=generated_tokens_cost,
             speed=round(speed, 2),
             model_names=model_names
         )
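
The hunks above swap the raw source string for a SourceCodeList that is threaded through generation and only flattened on demand via to_str(). The class itself is not part of this diff; below is a minimal sketch of the shape the call sites appear to assume (a sources container plus a to_str() renderer). The field and method bodies are inferred from the usage above, not copied from the package:

    from typing import List

    class SourceCode:
        # Inferred stand-in; the package imports its own SourceCode type.
        def __init__(self, module_name: str, source_code: str):
            self.module_name = module_name
            self.source_code = source_code

    class SourceCodeList:
        # Accepts the list positionally, as in SourceCodeList(pp.sources) above.
        def __init__(self, sources: List[SourceCode]):
            self.sources = sources

        def to_str(self) -> str:
            # Rebuilds the flat "##File: <name>" text that process_content
            # previously received as a plain string.
            return "".join(
                f"##File: {s.module_name}\n{s.source_code}\n\n"
                for s in self.sources
            )
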
autocoder/index/entry.py CHANGED
@@ -23,10 +23,11 @@ from autocoder.index.filter.quick_filter import QuickFilter
 from autocoder.index.filter.normal_filter import NormalFilter
 from autocoder.index.index import IndexManager
 from loguru import logger
+from autocoder.common import SourceCodeList

 def build_index_and_filter_files(
     llm, args: AutoCoderArgs, sources: List[SourceCode]
-) -> str:
+) -> SourceCodeList:
     # Initialize timing and statistics
     total_start_time = time.monotonic()
     stats = {
@@ -253,7 +254,8 @@ def build_index_and_filter_files(
     for file in final_filenames:
         print(f"{file} - {final_files[file].reason}")

-    source_code = ""
+    source_code = ""
+    source_code_list = SourceCodeList(sources=[])
    depulicated_sources = set()

    for file in sources:
@@ -263,7 +265,7 @@ def build_index_and_filter_files(
             depulicated_sources.add(file.module_name)
             source_code += f"##File: {file.module_name}\n"
             source_code += f"{file.source_code}\n\n"
-
+            source_code_list.sources.append(file)
     if args.request_id and not args.skip_events:
         queue_communicate.send_event(
             request_id=args.request_id,
@@ -339,4 +341,4 @@ def build_index_and_filter_files(
         )
     )

-    return source_code
+    return source_code_list
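
With this change build_index_and_filter_files returns the structured list instead of one concatenated string, so callers keep per-file metadata and flatten only when they need the old flat text. A hedged usage sketch, assuming the SourceCodeList shape sketched earlier (variable names are illustrative):

    # Hypothetical caller of the new API.
    source_code_list = build_index_and_filter_files(llm=llm, args=args, sources=sources)
    for src in source_code_list.sources:
        print(src.module_name)            # filtered files, still structured
    content = source_code_list.to_str()   # flat text, as the old str return value
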
autocoder/index/filter/quick_filter.py CHANGED
@@ -4,21 +4,21 @@ from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
 from autocoder.common.utils_code_auto_generate import stream_chat_with_continue
 from byzerllm.utils.str2model import to_model
 from autocoder.index.types import IndexItem
-from autocoder.common import AutoCoderArgs,SourceCode
+from autocoder.common import AutoCoderArgs, SourceCode
 import byzerllm
 import time
 from autocoder.index.index import IndexManager
 from autocoder.index.types import (
     IndexItem,
-    TargetFile,
+    TargetFile,
     FileNumberList
 )
 from autocoder.rag.token_counter import count_tokens
 from autocoder.common.printer import Printer
 from concurrent.futures import ThreadPoolExecutor
-import threading
+from byzerllm import MetaHolder

-from autocoder.utils.llms import get_llm_names
+from autocoder.utils.llms import get_llm_names, get_model_info


 def get_file_path(file_path):
@@ -32,8 +32,9 @@ class QuickFilterResult(BaseModel):
     has_error: bool
     error_message: Optional[str] = None

+
 class QuickFilter():
-    def __init__(self, index_manager: IndexManager,stats:Dict[str,Any],sources:List[SourceCode]):
+    def __init__(self, index_manager: IndexManager, stats: Dict[str, Any], sources: List[SourceCode]):
         self.index_manager = index_manager
         self.args = index_manager.args
         self.stats = stats
@@ -41,72 +42,142 @@ class QuickFilter():
         self.printer = Printer()
         self.max_tokens = self.args.index_filter_model_max_input_length

-
     def big_filter(self, index_items: List[IndexItem],) -> QuickFilterResult:
         chunks = []
         current_chunk = []
-
+
         # Split index_items into multiple chunks; the first chunk should get as close to max_tokens as possible
         for item in index_items:
             # Generate the text with quick_filter_files.prompt, then count its tokens
             temp_chunk = current_chunk + [item]
-            prompt_text = self.quick_filter_files.prompt(temp_chunk, self.args.query)
-            temp_size = count_tokens(prompt_text)
+            prompt_text = self.quick_filter_files.prompt(
+                temp_chunk, self.args.query)
+            temp_size = count_tokens(prompt_text)
             # If the current chunk is empty, or adding the item stays within max_tokens, add it to the current chunk
             if not current_chunk or temp_size <= self.max_tokens:
-                current_chunk.append(item)
+                current_chunk.append(item)
             else:
                 # The current chunk is full; create a new one
                 chunks.append(current_chunk)
-                current_chunk = [item]
-
+                current_chunk = [item]
+
         if current_chunk:
             chunks.append(current_chunk)
-
-        tokens_len = count_tokens(self.quick_filter_files.prompt(index_items, self.args.query))
+
+        tokens_len = count_tokens(
+            self.quick_filter_files.prompt(index_items, self.args.query))
         self.printer.print_in_terminal(
-                "quick_filter_too_long",
-                style="yellow",
-                tokens_len=tokens_len,
-                max_tokens=self.max_tokens,
-                split_size=len(chunks)
-            )
+            "quick_filter_too_long",
+            style="yellow",
+            tokens_len=tokens_len,
+            max_tokens=self.max_tokens,
+            split_size=len(chunks)
+        )

         def process_chunk(chunk_index: int, chunk: List[IndexItem]) -> QuickFilterResult:
             try:
-                model_name = ",".join(get_llm_names(self.index_manager.index_filter_llm))
+                # Get the list of model names
+                model_names = get_llm_names(
+                    self.index_manager.index_filter_llm)
+                model_name = ",".join(model_names)
                 files: Dict[str, TargetFile] = {}
-
+
+                # Get model pricing information
+                model_info_map = {}
+                for name in model_names:
+                    # The second argument is the product mode, taken from args
+                    info = get_model_info(name, self.args.product_mode)
+                    if info:
+                        model_info_map[name] = {
+                            # Cost per million tokens
+                            "input_price": info.get("input_price", 0.0),
+                            # Cost per million tokens
+                            "output_price": info.get("output_price", 0.0)
+                        }
+
                 if chunk_index == 0:
                     # The first chunk uses streaming output
                     stream_generator = stream_chat_with_continue(
                         self.index_manager.index_filter_llm,
-                        [{"role": "user", "content": self.quick_filter_files.prompt(chunk, self.args.query)}],
+                        [{"role": "user", "content": self.quick_filter_files.prompt(
+                            chunk, self.args.query)}],
                         {}
                     )
-                    full_response, _ = stream_out(
+                    full_response, last_meta = stream_out(
                         stream_generator,
                         model_name=model_name,
-                        title=self.printer.get_message_from_key_with_format("quick_filter_title", model_name=model_name),
+                        title=self.printer.get_message_from_key_with_format(
+                            "quick_filter_title", model_name=model_name),
                         args=self.args
                     )
                     file_number_list = to_model(full_response, FileNumberList)
+
+                    # Compute the total cost
+                    total_input_cost = 0.0
+                    total_output_cost = 0.0
+
+                    for name in model_names:
+                        info = model_info_map.get(name, {})
+                        # Formula: token count * unit price / 1000000
+                        total_input_cost += (last_meta.input_tokens_count *
+                                             info.get("input_price", 0.0)) / 1000000
+                        total_output_cost += (last_meta.generated_tokens_count *
+                                              info.get("output_price", 0.0)) / 1000000
+
+                    # Round to 4 decimal places
+                    total_input_cost = round(total_input_cost, 4)
+                    total_output_cost = round(total_output_cost, 4)
+
+                    # Print token statistics and cost
+                    self.printer.print_in_terminal(
+                        "quick_filter_stats",
+                        style="blue",
+                        input_tokens=last_meta.input_tokens_count,
+                        output_tokens=last_meta.generated_tokens_count,
+                        input_cost=total_input_cost,
+                        output_cost=total_output_cost,
+                        model_names=model_name
+                    )
                 else:
                     # Remaining chunks call with_llm directly
-                    file_number_list = self.quick_filter_files.with_llm(self.index_manager.index_filter_llm).with_return_type(FileNumberList).run(chunk, self.args.query)
-
+                    meta_holder = MetaHolder()
+                    start_time = time.monotonic()
+                    file_number_list = self.quick_filter_files.with_llm(self.index_manager.index_filter_llm).with_meta(
+                        meta_holder).with_return_type(FileNumberList).run(chunk, self.args.query)
+                    end_time = time.monotonic()
+
+                    total_input_cost = 0.0
+                    total_output_cost = 0.0
+                    if meta_holder.get_meta():
+                        meta_dict = meta_holder.get_meta()
+                        total_input_cost = meta_dict.get("input_tokens_count", 0) * model_info_map.get(model_name, {}).get("input_price", 0.0) / 1000000
+                        total_output_cost = meta_dict.get("generated_tokens_count", 0) * model_info_map.get(model_name, {}).get("output_price", 0.0) / 1000000
+
+                    self.printer.print_in_terminal(
+                        "quick_filter_stats",
+                        style="blue",
+                        input_tokens=meta_dict.get("input_tokens_count", 0),
+                        output_tokens=meta_dict.get("generated_tokens_count", 0),
+                        input_cost=total_input_cost,
+                        output_cost=total_output_cost,
+                        model_names=model_name,
+                        elapsed_time=f"{end_time - start_time:.2f}"
+                    )
+
                 if file_number_list:
                     for file_number in file_number_list.file_list:
-                        file_path = get_file_path(chunk[file_number].module_name)
+                        file_path = get_file_path(
+                            chunk[file_number].module_name)
                         files[file_path] = TargetFile(
                             file_path=chunk[file_number].module_name,
-                            reason=self.printer.get_message_from_key("quick_filter_reason")
+                            reason=self.printer.get_message_from_key(
+                                "quick_filter_reason")
                         )
                 return QuickFilterResult(
                     files=files,
                     has_error=False
                 )
-
+
             except Exception as e:
                 self.printer.print_in_terminal(
                     "quick_filter_failed",
@@ -123,25 +194,25 @@ class QuickFilter():
         if chunks:
             with ThreadPoolExecutor() as executor:
                 # Submit all chunks to the thread pool and collect the results
-                futures = [executor.submit(process_chunk, i, chunk)
-                          for i, chunk in enumerate(chunks)]
-
+                futures = [executor.submit(process_chunk, i, chunk)
+                           for i, chunk in enumerate(chunks)]
+
                 # Wait for all tasks to complete and collect the results
                 for future in futures:
                     results.append(future.result())
-
+
         # Merge all results
         final_files: Dict[str, TargetFile] = {}
         has_error = False
         error_messages: List[str] = []
-
+
         for result in results:
             if result.has_error:
                 has_error = True
                 if result.error_message:
                     error_messages.append(result.error_message)
             final_files.update(result.files)
-
+
         return QuickFilterResult(
             files=final_files,
             has_error=has_error,
@@ -149,7 +220,7 @@ class QuickFilter():
         )

     @byzerllm.prompt()
-    def quick_filter_files(self,file_meta_list:List[IndexItem],query:str) -> str:
+    def quick_filter_files(self, file_meta_list: List[IndexItem], query: str) -> str:
         '''
         When the user raises a requirement, we need to find the relevant files, read them, and modify some of them.
         Now, given the index files below:
@@ -160,7 +231,7 @@ class QuickFilter():

         The index files contain the file number (the part wrapped in ##[]), the file path, file symbol information, and so on.
         Below is the user's query:
-
+
         <query>
         {{ query }}
         </query>
@@ -182,63 +253,101 @@ class QuickFilter():
         2. If the query is a piece of conversation history, the file paths mentioned in that conversation must be returned.
         3. The JSON data must not contain comments
         '''
-        file_meta_str = "\n".join([f"##[{index}]{item.module_name}\n{item.symbols}" for index,item in enumerate(file_meta_list)])
+        file_meta_str = "\n".join(
+            [f"##[{index}]{item.module_name}\n{item.symbols}" for index, item in enumerate(file_meta_list)])
         context = {
             "content": file_meta_str,
             "query": query
         }
-        return context
+        return context

     def filter(self, index_items: List[IndexItem], query: str) -> QuickFilterResult:
         final_files: Dict[str, TargetFile] = {}
-        start_time = time.monotonic()
+        start_time = time.monotonic()
+
+        prompt_str = self.quick_filter_files.prompt(index_items, query)
+
+        tokens_len = count_tokens(prompt_str)

-        prompt_str = self.quick_filter_files.prompt(index_items,query)
-
-        tokens_len = count_tokens(prompt_str)
-
         # Print current index size
         self.printer.print_in_terminal(
             "quick_filter_tokens_len",
             style="blue",
             tokens_len=tokens_len
         )
-
-        if tokens_len > self.max_tokens:
+
+        if tokens_len > self.max_tokens:
             return self.big_filter(index_items)
-
+
         try:
-            model_name = ",".join(get_llm_names(self.index_manager.index_filter_llm))
-
+            # Get the model names
+            model_names = get_llm_names(self.index_manager.index_filter_llm)
+            model_name = ",".join(model_names)
+
+            # Get model pricing information
+            model_info_map = {}
+            for name in model_names:
+                # The second argument is the product mode, taken from args
+                info = get_model_info(name, self.args.product_mode)
+                if info:
+                    model_info_map[name] = {
+                        # Cost per million tokens
+                        "input_price": info.get("input_price", 0.0),
+                        # Cost per million tokens
+                        "output_price": info.get("output_price", 0.0)
+                    }
+
             # Render the prompt template
-            query = self.quick_filter_files.prompt(index_items, self.args.query)
-
+            query = self.quick_filter_files.prompt(
+                index_items, self.args.query)
+
             # Process with streaming output
             stream_generator = stream_chat_with_continue(
                 self.index_manager.index_filter_llm,
                 [{"role": "user", "content": query}],
                 {}
             )
-
+
             # Get the full response
             full_response, last_meta = stream_out(
                 stream_generator,
                 model_name=model_name,
-                title=self.printer.get_message_from_key_with_format("quick_filter_title", model_name=model_name),
+                title=self.printer.get_message_from_key_with_format(
+                    "quick_filter_title", model_name=model_name),
                 args=self.args
-            )
+            )
             # Parse the result
             file_number_list = to_model(full_response, FileNumberList)
-            end_time = time.monotonic()
-            # Print token statistics
+            end_time = time.monotonic()
+
+            # Compute the total cost
+            total_input_cost = 0.0
+            total_output_cost = 0.0
+
+            for name in model_names:
+                info = model_info_map.get(name, {})
+                # Formula: token count * unit price / 1000000
+                total_input_cost += (last_meta.input_tokens_count *
+                                     info.get("input_price", 0.0)) / 1000000
+                total_output_cost += (last_meta.generated_tokens_count *
+                                      info.get("output_price", 0.0)) / 1000000
+
+            # Round to 4 decimal places
+            total_input_cost = round(total_input_cost, 4)
+            total_output_cost = round(total_output_cost, 4)
+
+            # Print token statistics and cost
             self.printer.print_in_terminal(
-                "quick_filter_stats",
+                "quick_filter_stats",
                 style="blue",
                 elapsed_time=f"{end_time - start_time:.2f}",
                 input_tokens=last_meta.input_tokens_count,
-                output_tokens=last_meta.generated_tokens_count
+                output_tokens=last_meta.generated_tokens_count,
+                input_cost=total_input_cost,
+                output_cost=total_output_cost,
+                model_names=model_name
             )
-
+
         except Exception as e:
             self.printer.print_in_terminal(
                 "quick_filter_failed",
@@ -250,16 +359,17 @@ class QuickFilter():
                 has_error=True,
                 error_message=str(e)
             )
-
+
         if file_number_list:
             for file_number in file_number_list.file_list:
                 final_files[get_file_path(index_items[file_number].module_name)] = TargetFile(
                     file_path=index_items[file_number].module_name,
-                    reason=self.printer.get_message_from_key("quick_filter_reason")
+                    reason=self.printer.get_message_from_key(
+                        "quick_filter_reason")
                 )
-        end_time = time.monotonic()
-        self.stats["timings"]["quick_filter"] = end_time - start_time
+        end_time = time.monotonic()
+        self.stats["timings"]["quick_filter"] = end_time - start_time
         return QuickFilterResult(
             files=final_files,
             has_error=False
-        )
+        )
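
Both filter paths above price tokens the same way: cost = token count × per-million price ÷ 1,000,000, rounded to four decimal places. Restated as a standalone helper (the function name is ours, not the package's):

    def token_cost(tokens: int, price_per_million: float) -> float:
        # Same formula as the diff: token count * unit price / 1000000,
        # rounded to 4 decimal places.
        return round(tokens * price_per_million / 1_000_000, 4)

    # Example: 30,000 input tokens at $2.00 per million tokens costs $0.06.
    assert token_cost(30_000, 2.0) == 0.06
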