auto-coder 0.1.256__py3-none-any.whl → 0.1.257__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; see the registry's advisory for this release for more details.

@@ -27,6 +27,8 @@ from loguru import logger
27
27
  import time
28
28
  from autocoder.common.printer import Printer
29
29
  from autocoder.utils.llms import get_llm_names
30
+ from autocoder.privacy.model_filter import ModelPathFilter
31
+ from autocoder.common import SourceCodeList
30
32
 
31
33
 
32
34
  class BaseAction:
@@ -55,14 +57,15 @@ class ActionTSProject(BaseAction):
55
57
  self.pp = pp
56
58
  pp.run()
57
59
 
58
- source_code = pp.output()
60
+ # source_code = pp.output()
61
+ source_code_list = SourceCodeList(pp.sources)
59
62
  if self.llm:
60
63
  if args.in_code_apply:
61
64
  old_query = args.query
62
65
  args.query = (args.context or "") + "\n\n" + args.query
63
- source_code = build_index_and_filter_files(
66
+ source_code_list = build_index_and_filter_files(
64
67
  llm=self.llm, args=args, sources=pp.sources
65
- )
68
+ )
66
69
  if args.in_code_apply:
67
70
  args.query = old_query
68
71
 
@@ -81,17 +84,21 @@ class ActionTSProject(BaseAction):
81
84
  html_path=html_path,
82
85
  max_iter=self.args.image_max_iter,
83
86
  )
84
-
87
+ html_code = ""
85
88
  with open(html_path, "r") as f:
86
89
  html_code = f.read()
87
- source_code = f"##File: {html_path}\n{html_code}\n\n" + source_code
90
+
91
+ source_code_list.sources.append(SourceCode(
92
+ module_name=html_path,
93
+ source_code=html_code,
94
+ tag="IMAGE"))
88
95
 
89
- self.process_content(source_code)
96
+ self.process_content(source_code_list)
90
97
  return True
91
98
 
92
- def process_content(self, content: str):
99
+ def process_content(self, source_code_list: SourceCodeList):
93
100
  args = self.args
94
-
101
+ content = source_code_list.to_str()
95
102
  if args.execute and self.llm and not args.human_as_model:
96
103
  content_length = self._get_content_length(content)
97
104
  if content_length > self.args.model_max_input_length:
@@ -116,13 +123,14 @@ class ActionTSProject(BaseAction):
116
123
  )
117
124
  else:
118
125
  generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
126
+
119
127
  if self.args.enable_multi_round_generate:
120
128
  generate_result = generate.multi_round_run(
121
- query=args.query, source_content=content
129
+ query=args.query, source_code_list=source_code_list
122
130
  )
123
131
  else:
124
132
  generate_result = generate.single_round_run(
125
- query=args.query, source_content=content
133
+ query=args.query, source_code_list=source_code_list
126
134
  )
127
135
  elapsed_time = time.time() - start_time
128
136
  speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
@@ -191,11 +199,12 @@ class ActionPyScriptProject(BaseAction):
191
199
  pp = Level1PyProject(
192
200
  script_path=args.script_path, package_name=args.package_name
193
201
  )
194
- content = pp.run()
195
- self.process_content(content)
202
+ pp.run()
203
+ source_code_list = SourceCodeList(pp.sources)
204
+ self.process_content(source_code_list)
196
205
  return True
197
206
 
198
- def process_content(self, content: str):
207
+ def process_content(self, source_code_list: SourceCodeList):
199
208
  args = self.args
200
209
  if args.execute:
201
210
  self.printer.print_in_terminal("code_generation_start")
@@ -216,11 +225,11 @@ class ActionPyScriptProject(BaseAction):
216
225
  generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
217
226
  if self.args.enable_multi_round_generate:
218
227
  generate_result = generate.multi_round_run(
219
- query=args.query, source_content=content
228
+ query=args.query, source_code_list=source_code_list
220
229
  )
221
230
  else:
222
231
  generate_result = generate.single_round_run(
223
- query=args.query, source_content=content
232
+ query=args.query, source_code_list=source_code_list
224
233
  )
225
234
 
226
235
  elapsed_time = time.time() - start_time
@@ -293,24 +302,24 @@ class ActionPyProject(BaseAction):
293
302
  pp = PyProject(args=self.args, llm=self.llm)
294
303
  self.pp = pp
295
304
  pp.run(packages=args.py_packages.split(",") if args.py_packages else [])
296
- source_code = pp.output()
305
+ source_code_list = SourceCodeList(pp.sources)
297
306
 
298
307
  if self.llm:
299
308
  old_query = args.query
300
309
  if args.in_code_apply:
301
310
  args.query = (args.context or "") + "\n\n" + args.query
302
- source_code = build_index_and_filter_files(
311
+ source_code_list = build_index_and_filter_files(
303
312
  llm=self.llm, args=args, sources=pp.sources
304
313
  )
305
314
  if args.in_code_apply:
306
315
  args.query = old_query
307
316
 
308
- self.process_content(source_code)
317
+ self.process_content(source_code_list)
309
318
  return True
310
319
 
311
- def process_content(self, content: str):
320
+ def process_content(self, source_code_list: SourceCodeList):
312
321
  args = self.args
313
-
322
+ content = source_code_list.to_str()
314
323
  if args.execute and self.llm and not args.human_as_model:
315
324
  content_length = self._get_content_length(content)
316
325
  if content_length > self.args.model_max_input_length:
@@ -342,11 +351,11 @@ class ActionPyProject(BaseAction):
342
351
 
343
352
  if self.args.enable_multi_round_generate:
344
353
  generate_result = generate.multi_round_run(
345
- query=args.query, source_content=content
354
+ query=args.query, source_code_list=source_code_list
346
355
  )
347
356
  else:
348
357
  generate_result = generate.single_round_run(
349
- query=args.query, source_content=content
358
+ query=args.query, source_code_list=source_code_list
350
359
  )
351
360
  elapsed_time = time.time() - start_time
352
361
  speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
@@ -414,20 +423,21 @@ class ActionSuffixProject(BaseAction):
414
423
  pp = SuffixProject(args=args, llm=self.llm)
415
424
  self.pp = pp
416
425
  pp.run()
417
- source_code = pp.output()
426
+ source_code_list = SourceCodeList(pp.sources)
418
427
  if self.llm:
419
428
  if args.in_code_apply:
420
429
  old_query = args.query
421
430
  args.query = (args.context or "") + "\n\n" + args.query
422
- source_code = build_index_and_filter_files(
431
+ source_code_list = build_index_and_filter_files(
423
432
  llm=self.llm, args=args, sources=pp.sources
424
433
  )
425
434
  if args.in_code_apply:
426
435
  args.query = old_query
427
- self.process_content(source_code)
436
+ self.process_content(source_code_list)
428
437
 
429
- def process_content(self, content: str):
438
+ def process_content(self, source_code_list: SourceCodeList):
430
439
  args = self.args
440
+ content = source_code_list.to_str()
431
441
 
432
442
  if args.execute and self.llm and not args.human_as_model:
433
443
  content_length = self._get_content_length(content)
@@ -455,11 +465,11 @@ class ActionSuffixProject(BaseAction):
455
465
  generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
456
466
  if self.args.enable_multi_round_generate:
457
467
  generate_result = generate.multi_round_run(
458
- query=args.query, source_content=content
468
+ query=args.query, source_code_list=source_code_list
459
469
  )
460
470
  else:
461
471
  generate_result = generate.single_round_run(
462
- query=args.query, source_content=content
472
+ query=args.query, source_code_list=source_code_list
463
473
  )
464
474
 
465
475
  elapsed_time = time.time() - start_time
@@ -15,6 +15,7 @@ from autocoder.utils.conversation_store import store_code_model_conversation
15
15
  from autocoder.common.printer import Printer
16
16
  import time
17
17
  from autocoder.utils.llms import get_llm_names
18
+ from autocoder.common import SourceCodeList
18
19
  from loguru import logger
19
20
  class ActionRegexProject:
20
21
  def __init__(
@@ -36,20 +37,21 @@ class ActionRegexProject:
36
37
  pp = RegexProject(args=args, llm=self.llm)
37
38
  self.pp = pp
38
39
  pp.run()
39
- source_code = pp.output()
40
+ source_code_list = SourceCodeList(pp.sources)
40
41
  if self.llm:
41
42
  if args.in_code_apply:
42
43
  old_query = args.query
43
44
  args.query = (args.context or "") + "\n\n" + args.query
44
- source_code = build_index_and_filter_files(
45
+ source_code_list = build_index_and_filter_files(
45
46
  llm=self.llm, args=args, sources=pp.sources
46
47
  )
47
48
  if args.in_code_apply:
48
49
  args.query = old_query
49
- self.process_content(source_code)
50
+ self.process_content(source_code_list)
50
51
 
51
- def process_content(self, content: str):
52
+ def process_content(self, source_code_list: SourceCodeList):
52
53
  args = self.args
54
+ content = source_code_list.to_str()
53
55
 
54
56
  if args.execute and self.llm and not args.human_as_model:
55
57
  if len(content) > self.args.model_max_input_length:
@@ -78,11 +80,11 @@ class ActionRegexProject:
78
80
  generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
79
81
  if self.args.enable_multi_round_generate:
80
82
  generate_result = generate.multi_round_run(
81
- query=args.query, source_content=content
83
+ query=args.query, source_code_list=source_code_list
82
84
  )
83
85
  else:
84
86
  generate_result = generate.single_round_run(
85
- query=args.query, source_content=content
87
+ query=args.query, source_code_list=source_code_list
86
88
  )
87
89
 
88
90
  elapsed_time = time.time() - start_time
autocoder/index/entry.py CHANGED
@@ -23,10 +23,11 @@ from autocoder.index.filter.quick_filter import QuickFilter
23
23
  from autocoder.index.filter.normal_filter import NormalFilter
24
24
  from autocoder.index.index import IndexManager
25
25
  from loguru import logger
26
+ from autocoder.common import SourceCodeList
26
27
 
27
28
  def build_index_and_filter_files(
28
29
  llm, args: AutoCoderArgs, sources: List[SourceCode]
29
- ) -> str:
30
+ ) -> SourceCodeList:
30
31
  # Initialize timing and statistics
31
32
  total_start_time = time.monotonic()
32
33
  stats = {
@@ -253,7 +254,8 @@ def build_index_and_filter_files(
253
254
  for file in final_filenames:
254
255
  print(f"{file} - {final_files[file].reason}")
255
256
 
256
- source_code = ""
257
+ source_code = ""
258
+ source_code_list = SourceCodeList(sources=[])
257
259
  depulicated_sources = set()
258
260
 
259
261
  for file in sources:
@@ -263,7 +265,7 @@ def build_index_and_filter_files(
263
265
  depulicated_sources.add(file.module_name)
264
266
  source_code += f"##File: {file.module_name}\n"
265
267
  source_code += f"{file.source_code}\n\n"
266
-
268
+ source_code_list.sources.append(file)
267
269
  if args.request_id and not args.skip_events:
268
270
  queue_communicate.send_event(
269
271
  request_id=args.request_id,
@@ -339,4 +341,4 @@ def build_index_and_filter_files(
339
341
  )
340
342
  )
341
343
 
342
- return source_code
344
+ return source_code_list
autocoder/index/index.py CHANGED
@@ -9,6 +9,7 @@ from autocoder.index.symbols_utils import (
9
9
  SymbolType,
10
10
  symbols_info_to_str,
11
11
  )
12
+ from autocoder.privacy.model_filter import ModelPathFilter
12
13
  from concurrent.futures import ThreadPoolExecutor, as_completed
13
14
  import threading
14
15
 
@@ -17,6 +18,7 @@ import hashlib
17
18
 
18
19
  from autocoder.common.printer import Printer
19
20
  from autocoder.common.auto_coder_lang import get_message
21
+ from autocoder.utils.llms import get_llm_names, get_model_info
20
22
  from autocoder.index.types import (
21
23
  IndexItem,
22
24
  TargetFile,
@@ -30,6 +32,9 @@ class IndexManager:
30
32
  ):
31
33
  self.sources = sources
32
34
  self.source_dir = args.source_dir
35
+ # Initialize model filter for index_llm and index_filter_llm
36
+ self.index_model_filter = None
37
+ self.index_filter_model_filter = None
33
38
  self.anti_quota_limit = (
34
39
  args.index_model_anti_quota_limit or args.anti_quota_limit
35
40
  )
@@ -46,6 +51,12 @@ class IndexManager:
46
51
  self.index_filter_llm = llm
47
52
 
48
53
  self.llm = llm
54
+
55
+ # Initialize model filters
56
+ if self.index_llm:
57
+ self.index_model_filter = ModelPathFilter.from_model_object(self.index_llm, args)
58
+ if self.index_filter_llm:
59
+ self.index_filter_model_filter = ModelPathFilter.from_model_object(self.index_filter_llm, args)
49
60
  self.args = args
50
61
  self.max_input_length = (
51
62
  args.index_model_max_input_length or args.model_max_input_length
@@ -194,6 +205,17 @@ class IndexManager:
194
205
  ext = os.path.splitext(file_path)[1].lower()
195
206
  if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:
196
207
  return True
208
+
209
+ # Check model filter restrictions
210
+ if self.index_model_filter and not self.index_model_filter.is_accessible(file_path):
211
+ self.printer.print_in_terminal(
212
+ "index_file_filtered",
213
+ style="yellow",
214
+ file_path=file_path,
215
+ model_name=",".join(get_llm_names(self.index_llm))
216
+ )
217
+ return True
218
+
197
219
  return False
198
220
 
199
221
  def build_index_for_single_source(self, source: SourceCode):
@@ -212,8 +234,29 @@ class IndexManager:
212
234
  model_name = ",".join(get_llm_names(self.index_llm))
213
235
 
214
236
  try:
237
+ # 获取模型名称列表
238
+ model_names = get_llm_names(self.index_llm)
239
+ model_name = ",".join(model_names)
240
+
241
+ # 获取模型价格信息
242
+ model_info_map = {}
243
+ for name in model_names:
244
+ info = get_model_info(name, self.args.product_mode)
245
+ if info:
246
+ model_info_map[name] = {
247
+ "input_price": info.get("input_price", 0.0),
248
+ "output_price": info.get("output_price", 0.0)
249
+ }
250
+
215
251
  start_time = time.monotonic()
216
252
  source_code = source.source_code
253
+
254
+ # 统计token和成本
255
+ total_input_tokens = 0
256
+ total_output_tokens = 0
257
+ total_input_cost = 0.0
258
+ total_output_cost = 0.0
259
+
217
260
  if len(source.source_code) > self.max_input_length:
218
261
  self.printer.print_in_terminal(
219
262
  "index_file_too_large",
@@ -227,15 +270,38 @@ class IndexManager:
227
270
  )
228
271
  symbols = []
229
272
  for chunk in chunks:
273
+ meta_holder = byzerllm.MetaHolder()
230
274
  chunk_symbols = self.get_all_file_symbols.with_llm(
231
- self.index_llm).run(source.module_name, chunk)
275
+ self.index_llm).with_meta(meta_holder).run(source.module_name, chunk)
232
276
  time.sleep(self.anti_quota_limit)
233
277
  symbols.append(chunk_symbols)
278
+
279
+ if meta_holder.get_meta():
280
+ meta_dict = meta_holder.get_meta()
281
+ total_input_tokens += meta_dict.get("input_tokens_count", 0)
282
+ total_output_tokens += meta_dict.get("generated_tokens_count", 0)
283
+
234
284
  symbols = "\n".join(symbols)
235
285
  else:
286
+ meta_holder = byzerllm.MetaHolder()
236
287
  symbols = self.get_all_file_symbols.with_llm(
237
- self.index_llm).run(source.module_name, source_code)
288
+ self.index_llm).with_meta(meta_holder).run(source.module_name, source_code)
238
289
  time.sleep(self.anti_quota_limit)
290
+
291
+ if meta_holder.get_meta():
292
+ meta_dict = meta_holder.get_meta()
293
+ total_input_tokens += meta_dict.get("input_tokens_count", 0)
294
+ total_output_tokens += meta_dict.get("generated_tokens_count", 0)
295
+
296
+ # 计算总成本
297
+ for name in model_names:
298
+ info = model_info_map.get(name, {})
299
+ total_input_cost += (total_input_tokens * info.get("input_price", 0.0)) / 1000000
300
+ total_output_cost += (total_output_tokens * info.get("output_price", 0.0)) / 1000000
301
+
302
+ # 四舍五入到4位小数
303
+ total_input_cost = round(total_input_cost, 4)
304
+ total_output_cost = round(total_output_cost, 4)
239
305
 
240
306
  self.printer.print_in_terminal(
241
307
  "index_update_success",
@@ -243,7 +309,11 @@ class IndexManager:
243
309
  file_path=file_path,
244
310
  md5=md5,
245
311
  duration=time.monotonic() - start_time,
246
- model_name=model_name
312
+ model_name=model_name,
313
+ input_tokens=total_input_tokens,
314
+ output_tokens=total_output_tokens,
315
+ input_cost=total_input_cost,
316
+ output_cost=total_output_cost
247
317
  )
248
318
 
249
319
  except Exception as e:
@@ -263,6 +333,10 @@ class IndexManager:
263
333
  "symbols": symbols,
264
334
  "last_modified": os.path.getmtime(file_path),
265
335
  "md5": md5,
336
+ "input_tokens_count": total_input_tokens,
337
+ "generated_tokens_count": total_output_tokens,
338
+ "input_tokens_cost": total_input_cost,
339
+ "generated_tokens_cost": total_output_cost
266
340
  }
267
341
 
268
342
  def build_index(self):
@@ -290,6 +364,11 @@ class IndexManager:
290
364
 
291
365
  updated_sources = []
292
366
 
367
+ total_input_tokens = 0
368
+ total_output_tokens = 0
369
+ total_input_cost = 0.0
370
+ total_output_cost = 0.0
371
+
293
372
  with ThreadPoolExecutor(max_workers=self.args.index_build_workers) as executor:
294
373
 
295
374
  wait_to_build_files = []
@@ -346,6 +425,10 @@ class IndexManager:
346
425
  num_files=num_files
347
426
  )
348
427
  module_name = result["module_name"]
428
+ total_input_tokens += result["input_tokens_count"]
429
+ total_output_tokens += result["generated_tokens_count"]
430
+ total_input_cost += result["input_tokens_cost"]
431
+ total_output_cost += result["generated_tokens_cost"]
349
432
  index_data[module_name] = result
350
433
  updated_sources.append(module_name)
351
434
  if len(updated_sources) > 5:
@@ -357,12 +440,19 @@ class IndexManager:
357
440
  if updated_sources or keys_to_remove:
358
441
  with open(self.index_file, "w") as file:
359
442
  json.dump(index_data, file, ensure_ascii=False, indent=2)
443
+
444
+ print("")
360
445
  self.printer.print_in_terminal(
361
446
  "index_file_saved",
362
447
  style="green",
363
448
  updated_files=len(updated_sources),
364
- removed_files=len(keys_to_remove)
449
+ removed_files=len(keys_to_remove),
450
+ input_tokens=total_input_tokens,
451
+ output_tokens=total_output_tokens,
452
+ input_cost=total_input_cost,
453
+ output_cost=total_output_cost
365
454
  )
455
+ print("")
366
456
 
367
457
  return index_data
368
458
 
autocoder/models.py CHANGED
@@ -110,6 +110,20 @@ def save_models(models: List[Dict]) -> None:
110
110
  json.dump(models, f, indent=2, ensure_ascii=False)
111
111
 
112
112
 
113
def add_and_activate_models(models: List[Dict]) -> None:
    """
    Register models and apply the API keys that accompany them.

    Every model in ``models`` whose ``"name"`` is not already present in the
    stored model list (as returned by ``load_models``) is appended, and the
    combined list is persisted with ``save_models``. Duplicate names within
    ``models`` itself are added only once (first occurrence wins).

    Afterwards, every input model that carries an ``"api_key"`` entry —
    including models that were already registered — is passed to
    ``update_model_with_api_key``. (Presumably this is what "activates" the
    model; see that helper for the exact effect.)

    :param models: model description dicts; each must have a ``"name"`` key
    :raises KeyError: if a model dict has no ``"name"`` key
    """
    existing_models = load_models()
    # Set lookup instead of rebuilding a list of names on every iteration.
    known_names = {m["name"] for m in existing_models}
    for model in models:
        if model["name"] not in known_names:
            existing_models.append(model)
            known_names.add(model["name"])
    save_models(existing_models)

    for model in models:
        if "api_key" in model:
            update_model_with_api_key(model["name"], model["api_key"])
126
+
113
127
  def get_model_by_name(name: str) -> Dict:
114
128
  """
115
129
  根据模型名称查找模型
@@ -0,0 +1,3 @@
1
+ from .model_filter import ModelPathFilter
2
+
3
+ __all__ = ["ModelPathFilter"]
@@ -0,0 +1,100 @@
1
+ import re
2
+ import yaml
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional
5
+ from autocoder.common import AutoCoderArgs
6
+ from autocoder.utils import llms as llm_utils
7
+
8
+
9
class ModelPathFilter:
    """Per-model path filter driven by regex "forbidden path" rules.

    Rules for ``model_name`` are read from the ``model_filters`` section of a
    YAML config file (``args.model_filter_path`` when set, otherwise
    ``<args.source_dir>/.model_filters.yml``) and combined with the caller's
    ``default_forbidden`` patterns. A path is inaccessible when any rule
    matches anywhere in it (``re.search``, i.e. unanchored).

    Config shape read by ``_load_rules``::

        model_filters:
          <model_name>:
            forbidden_paths:
              - <regex>
              - <regex>
    """

    def __init__(self,
                 model_name: str,
                 args: "AutoCoderArgs",
                 default_forbidden: Optional[List[str]] = None):
        """
        Model path filter.

        :param model_name: name of the model in use (the key looked up under
            ``model_filters`` in the config file)
        :param args: auto-coder arguments; ``model_filter_path`` and
            ``source_dir`` are used to locate the config file
        :param default_forbidden: default forbidden-path regexes, applied in
            addition to the configured ones (and on their own when the config
            file does not exist)
        """
        self.model_name = model_name
        if args.model_filter_path:
            self.config_path = Path(args.model_filter_path)
        else:
            self.config_path = Path(args.source_dir, ".model_filters.yml")
        # Copy so later mutation of the caller's list cannot alter our rules.
        self.default_forbidden = list(default_forbidden or [])
        self._rules_cache: Dict[str, List[re.Pattern]] = {}
        self._load_rules()

    def _load_rules(self):
        """Load the configured rules plus the defaults and compile them.

        The default rules are always cached, even when the config file is
        missing. This keeps ``has_rules`` accurate, keeps ``add_temp_rule``
        from shadowing the defaults, and avoids recompiling on every check.
        """
        config_rules: List[str] = []
        if self.config_path.exists():
            with open(self.config_path, 'r', encoding="utf-8") as f:
                # safe_load returns None for an empty document.
                config = yaml.safe_load(f) or {}
            # Any level of the mapping may be written but left empty (null).
            model_filters = config.get('model_filters') or {}
            model_rules = model_filters.get(self.model_name) or {}
            config_rules = list(model_rules.get('forbidden_paths') or [])

        self._rules_cache[self.model_name] = [
            re.compile(rule) for rule in config_rules + self.default_forbidden
        ]

    def is_accessible(self, file_path: str) -> bool:
        """
        Check whether a file path is allowed by the rules.

        :param file_path: path to check; an empty/None path is always allowed
        :return: True if access is allowed, False if any rule matches
        """
        if not file_path:
            return True

        patterns = self._rules_cache.get(self.model_name, [])
        path_str = str(file_path)  # also accept os.PathLike inputs
        return not any(pattern.search(path_str) for pattern in patterns)

    def add_temp_rule(self, rule: str):
        """
        Add an in-memory rule. It is discarded by ``reload_rules``.

        :param rule: regular-expression rule
        """
        self._rules_cache.setdefault(self.model_name, []).append(re.compile(rule))

    def reload_rules(self):
        """Reload the rules from the config file, dropping temporary rules."""
        self._rules_cache.clear()
        self._load_rules()

    def has_rules(self):
        """Return True if any rule (configured, default or temporary) is active."""
        return bool(self._rules_cache.get(self.model_name))

    @classmethod
    def from_model_object(cls,
                          llm_obj,
                          args: "AutoCoderArgs",
                          default_forbidden: Optional[List[str]] = None):
        """
        Create a filter from an LLM object.

        The model name is the comma-joined list returned by
        ``llm_utils.get_llm_names(llm_obj)``.

        :param llm_obj: ByzerLLM instance or similar object
        :param args: auto-coder arguments
        :param default_forbidden: default forbidden-path regexes
        :raises ValueError: if no model name can be derived from ``llm_obj``
        """
        model_name = ",".join(llm_utils.get_llm_names(llm_obj))
        if not model_name:
            raise ValueError(
                f"Could not determine a model name from LLM object {llm_obj!r}"
            )

        return cls(
            model_name=model_name,
            args=args,
            default_forbidden=default_forbidden
        )