auto-coder 0.1.267__py3-none-any.whl → 0.1.269__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; see the registry's advisory page for this release for more details.

@@ -3,8 +3,9 @@ import json
3
3
  from pydantic import BaseModel
4
4
  import byzerllm
5
5
  from autocoder.common.printer import Printer
6
- from autocoder.utils.llms import count_tokens
6
+ from autocoder.rag.token_counter import count_tokens
7
7
  from loguru import logger
8
+ from autocoder.common import AutoCoderArgs
8
9
 
9
10
  class PruneStrategy(BaseModel):
10
11
  name: str
@@ -12,25 +13,25 @@ class PruneStrategy(BaseModel):
12
13
  config: Dict[str, Any] = {"safe_zone_tokens": 0, "group_size": 4}
13
14
 
14
15
  class ConversationPruner:
15
- def __init__(self, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM],
16
- safe_zone_tokens: int = 500, group_size: int = 4):
16
+ def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
17
+ self.args = args
17
18
  self.llm = llm
18
19
  self.printer = Printer()
19
20
  self.strategies = {
20
21
  "summarize": PruneStrategy(
21
22
  name="summarize",
22
23
  description="对早期对话进行分组摘要,保留关键信息",
23
- config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
24
+ config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
24
25
  ),
25
26
  "truncate": PruneStrategy(
26
27
  name="truncate",
27
28
  description="分组截断最早的部分对话",
28
- config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
29
+ config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
29
30
  ),
30
31
  "hybrid": PruneStrategy(
31
32
  name="hybrid",
32
33
  description="先尝试分组摘要,如果仍超限则分组截断",
33
- config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
34
+ config={"safe_zone_tokens": self.args.conversation_prune_safe_zone_tokens, "group_size": self.args.conversation_prune_group_size}
34
35
  )
35
36
  }
36
37
 
@@ -57,7 +58,7 @@ class ConversationPruner:
57
58
  if strategy.name == "summarize":
58
59
  return self._summarize_prune(conversations, strategy.config)
59
60
  elif strategy.name == "truncate":
60
- return self._truncate_prune.with_llm(self.llm).run(conversations)
61
+ return self._truncate_prune(conversations)
61
62
  elif strategy.name == "hybrid":
62
63
  pruned = self._summarize_prune(conversations, strategy.config)
63
64
  if count_tokens(json.dumps(pruned, ensure_ascii=False)) > self.args.conversation_prune_safe_zone_tokens:
@@ -80,8 +81,8 @@ class ConversationPruner:
80
81
  break
81
82
 
82
83
  # 找到要处理的对话组
83
- early_conversations = processed_conversations[:-group_size]
84
- recent_conversations = processed_conversations[-group_size:]
84
+ early_conversations = processed_conversations[:group_size]
85
+ recent_conversations = processed_conversations[group_size:]
85
86
 
86
87
  if not early_conversations:
87
88
  break
@@ -90,7 +91,7 @@ class ConversationPruner:
90
91
  group_summary = self._generate_summary.with_llm(self.llm).run(early_conversations[-group_size:])
91
92
 
92
93
  # 更新对话历史
93
- processed_conversations = early_conversations[:-group_size] + [
94
+ processed_conversations = [
94
95
  {"role": "user", "content": f"历史对话摘要:\n{group_summary}"},
95
96
  {"role": "assistant", "content": f"收到"}
96
97
  ] + recent_conversations
autocoder/index/entry.py CHANGED
@@ -58,8 +58,12 @@ def build_index_and_filter_files(
58
58
  return file_path.strip()[2:]
59
59
  return file_path
60
60
 
61
+ # 文件路径 -> TargetFile
61
62
  final_files: Dict[str, TargetFile] = {}
62
63
 
64
+ # 文件路径 -> 文件在文件列表中的位置(越前面表示越相关)
65
+ file_positions:Dict[str,int] = {}
66
+
63
67
  # Phase 1: Process REST/RAG/Search sources
64
68
  printer = Printer()
65
69
  printer.print_in_terminal("phase1_processing_sources")
@@ -102,25 +106,20 @@ def build_index_and_filter_files(
102
106
  })
103
107
  )
104
108
  )
105
-
109
+
110
+
106
111
  if not args.skip_filter_index and args.index_filter_model:
107
112
  model_name = getattr(index_manager.index_filter_llm, 'default_model_name', None)
108
113
  if not model_name:
109
114
  model_name = "unknown(without default model name)"
110
115
  printer.print_in_terminal("quick_filter_start", style="blue", model_name=model_name)
111
116
  quick_filter = QuickFilter(index_manager,stats,sources)
112
- quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
113
- # if quick_filter_result.has_error:
114
- # raise KeyboardInterrupt(printer.get_message_from_key_with_format("quick_filter_failed",error=quick_filter_result.error_message))
115
-
116
- # Merge quick filter results into final_files
117
- if args.context_prune:
118
- context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
119
- pruned_files = context_pruner.handle_overflow(quick_filter_result.files, [{"role":"user","content":args.query}], args.context_prune_strategy)
120
- for source_file in pruned_files:
121
- final_files[source_file.module_name] = quick_filter_result.files[source_file.module_name]
122
- else:
123
- final_files.update(quick_filter_result.files)
117
+ quick_filter_result = quick_filter.filter(index_manager.read_index(),args.query)
118
+
119
+ final_files.update(quick_filter_result.files)
120
+
121
+ if quick_filter_result.file_positions:
122
+ file_positions.update(quick_filter_result.file_positions)
124
123
 
125
124
  if not args.skip_filter_index and not args.index_filter_model:
126
125
  model_name = getattr(index_manager.llm, 'default_model_name', None)
@@ -261,32 +260,53 @@ def build_index_and_filter_files(
261
260
  for file in final_filenames:
262
261
  print(f"{file} - {final_files[file].reason}")
263
262
 
264
- source_code = ""
263
+ # source_code = ""
265
264
  source_code_list = SourceCodeList(sources=[])
266
265
  depulicated_sources = set()
267
-
266
+
267
+ ## 先去重
268
+ temp_sources = []
268
269
  for file in sources:
269
270
  if file.module_name in final_filenames:
270
271
  if file.module_name in depulicated_sources:
271
272
  continue
272
273
  depulicated_sources.add(file.module_name)
273
- source_code += f"##File: {file.module_name}\n"
274
- source_code += f"{file.source_code}\n\n"
275
- source_code_list.sources.append(file)
274
+ # source_code += f"##File: {file.module_name}\n"
275
+ # source_code += f"{file.source_code}\n\n"
276
+ temp_sources.append(file)
277
+
278
+ ## 开启了裁剪,则需要做裁剪,不过目前只针对 quick filter 生效
279
+ if args.context_prune:
280
+ context_pruner = PruneContext(max_tokens=args.conversation_prune_safe_zone_tokens, args=args, llm=llm)
281
+ # 如果 file_positions 不为空,则通过 file_positions 来获取文件
282
+ if file_positions:
283
+ ## 拿到位置列表,然后根据位置排序 得到 [(pos,file_path)]
284
+ ## 将 [(pos,file_path)] 转换为 [file_path]
285
+ ## 通过 [file_path] 顺序调整 temp_sources 的顺序
286
+ ## MARK
287
+ # 将 file_positions 转换为 [(pos, file_path)] 的列表
288
+ position_file_pairs = [(pos, file_path) for file_path, pos in file_positions.items()]
289
+ # 按位置排序
290
+ position_file_pairs.sort(key=lambda x: x[0])
291
+ # 提取排序后的文件路径列表
292
+ sorted_file_paths = [file_path for _, file_path in position_file_pairs]
293
+ # 根据 sorted_file_paths 重新排序 temp_sources
294
+ temp_sources.sort(key=lambda x: sorted_file_paths.index(x.module_name) if x.module_name in sorted_file_paths else len(sorted_file_paths))
295
+ pruned_files = context_pruner.handle_overflow([source.module_name for source in temp_sources], [{"role":"user","content":args.query}], args.context_prune_strategy)
296
+ source_code_list.sources = pruned_files
297
+
276
298
  if args.request_id and not args.skip_events:
277
299
  queue_communicate.send_event(
278
300
  request_id=args.request_id,
279
301
  event=CommunicateEvent(
280
302
  event_type=CommunicateEventType.CODE_INDEX_FILTER_FILE_SELECTED.value,
281
303
  data=json.dumps([
282
- (file["file_path"], file.reason)
283
- for file in final_files.values()
284
- if file.file_path in depulicated_sources
304
+ (file.module_name, "") for file in source_code_list.sources
285
305
  ])
286
306
  )
287
307
  )
288
308
 
289
- stats["final_files"] = len(depulicated_sources)
309
+ stats["final_files"] = len(source_code_list.sources)
290
310
  phase_end = time.monotonic()
291
311
  stats["timings"]["prepare_output"] = phase_end - phase_start
292
312
 
autocoder/index/index.py CHANGED
@@ -12,6 +12,7 @@ from autocoder.index.symbols_utils import (
12
12
  from autocoder.privacy.model_filter import ModelPathFilter
13
13
  from concurrent.futures import ThreadPoolExecutor, as_completed
14
14
  import threading
15
+ import re
15
16
 
16
17
  import byzerllm
17
18
  import hashlib
@@ -27,6 +28,8 @@ from autocoder.index.types import (
27
28
  from autocoder.common.global_cancel import global_cancel
28
29
  from autocoder.utils.llms import get_llm_names
29
30
  from autocoder.rag.token_counter import count_tokens
31
+
32
+
30
33
  class IndexManager:
31
34
  def __init__(
32
35
  self, llm: byzerllm.ByzerLLM, sources: List[SourceCode], args: AutoCoderArgs
@@ -52,12 +55,14 @@ class IndexManager:
52
55
  self.index_filter_llm = llm
53
56
 
54
57
  self.llm = llm
55
-
58
+
56
59
  # Initialize model filters
57
60
  if self.index_llm:
58
- self.index_model_filter = ModelPathFilter.from_model_object(self.index_llm, args)
61
+ self.index_model_filter = ModelPathFilter.from_model_object(
62
+ self.index_llm, args)
59
63
  if self.index_filter_llm:
60
- self.index_filter_model_filter = ModelPathFilter.from_model_object(self.index_filter_llm, args)
64
+ self.index_filter_model_filter = ModelPathFilter.from_model_object(
65
+ self.index_filter_llm, args)
61
66
  self.args = args
62
67
  self.max_input_length = (
63
68
  args.index_model_max_input_length or args.model_max_input_length
@@ -68,7 +73,6 @@ class IndexManager:
68
73
  if not os.path.exists(self.index_dir):
69
74
  os.makedirs(self.index_dir)
70
75
 
71
-
72
76
  @byzerllm.prompt()
73
77
  def verify_file_relevance(self, file_content: str, query: str) -> str:
74
78
  """
@@ -201,12 +205,12 @@ class IndexManager:
201
205
  if current_chunk:
202
206
  chunks.append("\n".join(current_chunk))
203
207
  return chunks
204
-
208
+
205
209
  def should_skip(self, file_path: str):
206
210
  ext = os.path.splitext(file_path)[1].lower()
207
211
  if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:
208
212
  return True
209
-
213
+
210
214
  # Check model filter restrictions
211
215
  if self.index_model_filter and not self.index_model_filter.is_accessible(file_path):
212
216
  self.printer.print_in_terminal(
@@ -216,10 +220,10 @@ class IndexManager:
216
220
  model_name=",".join(get_llm_names(self.index_llm))
217
221
  )
218
222
  return True
219
-
223
+
220
224
  return False
221
225
 
222
- def build_index_for_single_source(self, source: SourceCode):
226
+ def build_index_for_single_source(self, source: SourceCode):
223
227
  if global_cancel.requested:
224
228
  return None
225
229
 
@@ -251,13 +255,13 @@ class IndexManager:
251
255
 
252
256
  start_time = time.monotonic()
253
257
  source_code = source.source_code
254
-
258
+
255
259
  # 统计token和成本
256
260
  total_input_tokens = 0
257
261
  total_output_tokens = 0
258
262
  total_input_cost = 0.0
259
263
  total_output_cost = 0.0
260
-
264
+
261
265
  if count_tokens(source.source_code) > self.args.conversation_prune_safe_zone_tokens:
262
266
  self.printer.print_in_terminal(
263
267
  "index_file_too_large",
@@ -276,34 +280,40 @@ class IndexManager:
276
280
  self.index_llm).with_meta(meta_holder).run(source.module_name, chunk)
277
281
  time.sleep(self.anti_quota_limit)
278
282
  symbols.append(chunk_symbols)
279
-
283
+
280
284
  if meta_holder.get_meta():
281
285
  meta_dict = meta_holder.get_meta()
282
- total_input_tokens += meta_dict.get("input_tokens_count", 0)
283
- total_output_tokens += meta_dict.get("generated_tokens_count", 0)
284
-
286
+ total_input_tokens += meta_dict.get(
287
+ "input_tokens_count", 0)
288
+ total_output_tokens += meta_dict.get(
289
+ "generated_tokens_count", 0)
290
+
285
291
  symbols = "\n".join(symbols)
286
292
  else:
287
293
  meta_holder = byzerllm.MetaHolder()
288
294
  symbols = self.get_all_file_symbols.with_llm(
289
295
  self.index_llm).with_meta(meta_holder).run(source.module_name, source_code)
290
296
  time.sleep(self.anti_quota_limit)
291
-
297
+
292
298
  if meta_holder.get_meta():
293
299
  meta_dict = meta_holder.get_meta()
294
- total_input_tokens += meta_dict.get("input_tokens_count", 0)
295
- total_output_tokens += meta_dict.get("generated_tokens_count", 0)
296
-
300
+ total_input_tokens += meta_dict.get(
301
+ "input_tokens_count", 0)
302
+ total_output_tokens += meta_dict.get(
303
+ "generated_tokens_count", 0)
304
+
297
305
  # 计算总成本
298
306
  for name in model_names:
299
307
  info = model_info_map.get(name, {})
300
- total_input_cost += (total_input_tokens * info.get("input_price", 0.0)) / 1000000
301
- total_output_cost += (total_output_tokens * info.get("output_price", 0.0)) / 1000000
302
-
308
+ total_input_cost += (total_input_tokens *
309
+ info.get("input_price", 0.0)) / 1000000
310
+ total_output_cost += (total_output_tokens *
311
+ info.get("output_price", 0.0)) / 1000000
312
+
303
313
  # 四舍五入到4位小数
304
314
  total_input_cost = round(total_input_cost, 4)
305
315
  total_output_cost = round(total_output_cost, 4)
306
-
316
+
307
317
  self.printer.print_in_terminal(
308
318
  "index_update_success",
309
319
  style="green",
@@ -340,9 +350,44 @@ class IndexManager:
340
350
  "generated_tokens_cost": total_output_cost
341
351
  }
342
352
 
353
+ def parse_exclude_files(self, exclude_files):
354
+ if not exclude_files:
355
+ return []
356
+
357
+ if isinstance(exclude_files, str):
358
+ exclude_files = [exclude_files]
359
+
360
+ exclude_patterns = []
361
+ for pattern in exclude_files:
362
+ if pattern.startswith("regex://"):
363
+ pattern = pattern[8:]
364
+ exclude_patterns.append(re.compile(pattern))
365
+ elif pattern.startswith("human://"):
366
+ pattern = pattern[8:]
367
+ v = (
368
+ self.generate_regex_pattern.with_llm(self.llm)
369
+ .with_extractor(self.extract_regex_pattern)
370
+ .run(desc=pattern)
371
+ )
372
+ if not v:
373
+ raise ValueError(
374
+ "Fail to generate regex pattern, try again.")
375
+ exclude_patterns.append(re.compile(v))
376
+ else:
377
+ raise ValueError(
378
+ "Invalid exclude_files format. Expected 'regex://<pattern>' or 'human://<description>' "
379
+ )
380
+ return exclude_patterns
381
+
382
+ def filter_exclude_files(self, file_path, exclude_patterns):
383
+ for pattern in exclude_patterns:
384
+ if pattern.search(file_path):
385
+ return True
386
+ return False
387
+
343
388
  def build_index(self):
344
389
  if os.path.exists(self.index_file):
345
- with open(self.index_file, "r",encoding="utf-8") as file:
390
+ with open(self.index_file, "r", encoding="utf-8") as file:
346
391
  index_data = json.load(file)
347
392
  else:
348
393
  index_data = {}
@@ -351,14 +396,27 @@ class IndexManager:
351
396
  keys_to_remove = []
352
397
  for file_path in index_data:
353
398
  if not os.path.exists(file_path):
354
- keys_to_remove.append(file_path)
355
-
399
+ keys_to_remove.append(file_path)
400
+
401
+ # 删除被排除的文件
402
+ try:
403
+ exclude_patterns = self.parse_exclude_files(self.args.exclude_files)
404
+ for file_path in index_data:
405
+ if self.filter_exclude_files(file_path, exclude_patterns):
406
+ keys_to_remove.append(file_path)
407
+ except Exception as e:
408
+ self.printer.print_in_terminal(
409
+ "index_exclude_files_error",
410
+ style="red",
411
+ error=str(e)
412
+ )
413
+
356
414
  # 删除无效条目并记录日志
357
415
  for key in set(keys_to_remove):
358
416
  if key in index_data:
359
417
  del index_data[key]
360
418
  self.printer.print_in_terminal(
361
- "index_file_removed",
419
+ "index_file_removed",
362
420
  style="yellow",
363
421
  file_path=key
364
422
  )
@@ -388,7 +446,7 @@ class IndexManager:
388
446
  for line in v:
389
447
  new_v.append(line[line.find(":"):])
390
448
  source_code = "\n".join(new_v)
391
-
449
+
392
450
  md5 = hashlib.md5(source_code.encode("utf-8")).hexdigest()
393
451
  if (
394
452
  source.module_name not in index_data
@@ -397,7 +455,8 @@ class IndexManager:
397
455
  wait_to_build_files.append(source)
398
456
 
399
457
  # Remove duplicates based on module_name
400
- wait_to_build_files = list({source.module_name: source for source in wait_to_build_files}.values())
458
+ wait_to_build_files = list(
459
+ {source.module_name: source for source in wait_to_build_files}.values())
401
460
 
402
461
  counter = 0
403
462
  num_files = len(wait_to_build_files)
@@ -433,16 +492,17 @@ class IndexManager:
433
492
  index_data[module_name] = result
434
493
  updated_sources.append(module_name)
435
494
  if len(updated_sources) > 5:
436
- with open(self.index_file, "w",encoding="utf-8") as file:
437
- json.dump(index_data, file, ensure_ascii=False, indent=2)
495
+ with open(self.index_file, "w", encoding="utf-8") as file:
496
+ json.dump(index_data, file,
497
+ ensure_ascii=False, indent=2)
438
498
  updated_sources = []
439
-
499
+
440
500
  # 如果 updated_sources 或 keys_to_remove 有值,则保存索引文件
441
501
  if updated_sources or keys_to_remove:
442
- with open(self.index_file, "w",encoding="utf-8") as file:
502
+ with open(self.index_file, "w", encoding="utf-8") as file:
443
503
  json.dump(index_data, file, ensure_ascii=False, indent=2)
444
504
 
445
- print("")
505
+ print("")
446
506
  self.printer.print_in_terminal(
447
507
  "index_file_saved",
448
508
  style="green",
@@ -461,14 +521,14 @@ class IndexManager:
461
521
  if not os.path.exists(self.index_file):
462
522
  return []
463
523
 
464
- with open(self.index_file, "r",encoding="utf-8") as file:
524
+ with open(self.index_file, "r", encoding="utf-8") as file:
465
525
  return file.read()
466
526
 
467
527
  def read_index(self) -> List[IndexItem]:
468
528
  if not os.path.exists(self.index_file):
469
529
  return []
470
530
 
471
- with open(self.index_file, "r",encoding="utf-8") as file:
531
+ with open(self.index_file, "r", encoding="utf-8") as file:
472
532
  index_data = json.load(file)
473
533
 
474
534
  index_items = []
@@ -572,7 +632,7 @@ class IndexManager:
572
632
  {file.file_path: file for file in all_results}.values())
573
633
  return FileList(file_list=all_results)
574
634
 
575
- def _query_index_with_thread(self, query, func):
635
+ def _query_index_with_thread(self, query, func):
576
636
  all_results = []
577
637
  lock = threading.Lock()
578
638
  completed_threads = 0
@@ -582,7 +642,7 @@ class IndexManager:
582
642
  nonlocal completed_threads
583
643
  result = self._get_target_files_by_query.with_llm(
584
644
  self.llm).with_return_type(FileList).run(chunk, query)
585
-
645
+
586
646
  if result is not None:
587
647
  with lock:
588
648
  all_results.extend(result.file_list)
@@ -708,4 +768,3 @@ class IndexManager:
708
768
 
709
769
  请确保结果的准确性和完整性,包括所有可能相关的文件。
710
770
  """
711
-
@@ -0,0 +1,120 @@
1
+ import os
2
+ import json
3
+ from collections import defaultdict
4
+ from typing import Dict, List, Set, Tuple
5
+ from pathlib import Path
6
+ from loguru import logger
7
+ import byzerllm
8
+ from autocoder.common import AutoCoderArgs
9
+ from autocoder.common.printer import Printer
10
+ from typing import Union
11
+ import pydantic
12
+ from autocoder.common.result_manager import ResultManager
13
+
14
+ class ExtensionClassifyResult(pydantic.BaseModel):
15
+ code: List[str] = []
16
+ config: List[str] = []
17
+ data: List[str] = []
18
+ document: List[str] = []
19
+ other: List[str] = []
20
+ framework: List[str] = []
21
+
22
+ class ProjectTypeAnalyzer:
23
+ def __init__(self, args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
24
+ self.args = args
25
+ self.llm = llm
26
+ self.printer = Printer()
27
+ self.default_exclude_dirs = [
28
+ ".git", ".svn", ".hg", "build", "dist", "__pycache__",
29
+ "node_modules", ".auto-coder", ".vscode", ".idea", "venv",
30
+ ".next", ".nuxt", ".svelte-kit", "out", "cache", "logs",
31
+ "temp", "tmp", "coverage", ".DS_Store", "public", "static"
32
+ ]
33
+ self.extension_counts = defaultdict(int)
34
+ self.stats_file = Path(args.source_dir) / ".auto-coder" / "project_type_stats.json"
35
+ self.result_manager = ResultManager()
36
+
37
+ def traverse_project(self) -> None:
38
+ """遍历项目目录,统计文件后缀"""
39
+ for root, dirs, files in os.walk(self.args.source_dir):
40
+ # 过滤掉默认排除的目录
41
+ dirs[:] = [d for d in dirs if d not in self.default_exclude_dirs]
42
+
43
+ for file in files:
44
+ _, ext = os.path.splitext(file)
45
+ if ext: # 只统计有后缀的文件
46
+ self.extension_counts[ext.lower()] += 1
47
+
48
+ def count_extensions(self) -> Dict[str, int]:
49
+ """返回文件后缀统计结果"""
50
+ return dict(sorted(self.extension_counts.items(), key=lambda x: x[1], reverse=True))
51
+
52
+ @byzerllm.prompt()
53
+ def classify_extensions(self, extensions: str) -> str:
54
+ """
55
+ 根据文件后缀列表,将后缀分类为代码、配置、数据、文档等类型。
56
+
57
+ 文件后缀列表:
58
+ {{ extensions }}
59
+
60
+ 请返回如下JSON格式:
61
+ {
62
+ "code": ["后缀1", "后缀2"],
63
+ "config": ["后缀3", "后缀4"],
64
+ "data": ["后缀5", "后缀6"],
65
+ "document": ["后缀7", "后缀8"],
66
+ "other": ["后缀9", "后缀10"],
67
+ "framework": ["后缀11", "后缀12"]
68
+ }
69
+ """
70
+ return {
71
+ "extensions": extensions
72
+ }
73
+
74
+ def save_stats(self) -> None:
75
+ """保存统计结果到文件"""
76
+ stats = {
77
+ "extension_counts": self.extension_counts,
78
+ "project_type": self.detect_project_type()
79
+ }
80
+
81
+ # 确保目录存在
82
+ self.stats_file.parent.mkdir(parents=True, exist_ok=True)
83
+
84
+ with open(self.stats_file, "w", encoding="utf-8") as f:
85
+ json.dump(stats, f, indent=2)
86
+
87
+ self.printer.print_in_terminal("stats_saved", path=str(self.stats_file))
88
+
89
+ def load_stats(self) -> Dict[str, any]:
90
+ """从文件加载统计结果"""
91
+ if not self.stats_file.exists():
92
+ self.printer.print_in_terminal("stats_not_found", path=str(self.stats_file))
93
+ return {}
94
+
95
+ with open(self.stats_file, "r", encoding="utf-8") as f:
96
+ return json.load(f)
97
+
98
+ def detect_project_type(self) -> str:
99
+ """根据后缀统计结果推断项目类型"""
100
+ # 获取统计结果
101
+ ext_counts = self.count_extensions()
102
+ # 将后缀分类
103
+ classification = self.classify_extensions.with_llm(self.llm).with_return_type(ExtensionClassifyResult).run(json.dumps(ext_counts,ensure_ascii=False))
104
+ return ",".join(classification.code)
105
+
106
+ def analyze(self) -> Dict[str, any]:
107
+ """执行完整的项目类型分析流程"""
108
+ # 遍历项目目录
109
+ self.traverse_project()
110
+
111
+ # 检测项目类型
112
+ project_type = self.detect_project_type()
113
+
114
+ self.result_manager.add_result(content=project_type, meta={
115
+ "action": "get_project_type",
116
+ "input": {
117
+
118
+ }
119
+ })
120
+ return project_type
@@ -25,8 +25,8 @@ PROVIDER_INFO_LIST = [
25
25
  ProviderInfo(
26
26
  name="volcano",
27
27
  endpoint="https://ark.cn-beijing.volces.com/api/v3",
28
- r1_model="",
29
- v3_model="",
28
+ r1_model="deepseek-r1-250120",
29
+ v3_model="deepseek-v3-241226",
30
30
  api_key="",
31
31
  r1_input_price=2.0,
32
32
  r1_output_price=8.0,
@@ -162,32 +162,32 @@ class ModelProviderSelector:
162
162
  provider_info = provider
163
163
  break
164
164
 
165
- if result == "volcano":
166
- # Get R1 endpoint
167
- r1_endpoint = input_dialog(
168
- title=self.printer.get_message_from_key("model_provider_api_key_title"),
169
- text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
170
- validator=VolcanoEndpointValidator(),
171
- style=dialog_style
172
- ).run()
165
+ # if result == "volcano":
166
+ # # Get R1 endpoint
167
+ # r1_endpoint = input_dialog(
168
+ # title=self.printer.get_message_from_key("model_provider_api_key_title"),
169
+ # text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
170
+ # validator=VolcanoEndpointValidator(),
171
+ # style=dialog_style
172
+ # ).run()
173
173
 
174
- if r1_endpoint is None:
175
- return None
174
+ # if r1_endpoint is None:
175
+ # return None
176
176
 
177
- provider_info.r1_model = r1_endpoint
177
+ # provider_info.r1_model = r1_endpoint
178
178
 
179
- # Get V3 endpoint
180
- v3_endpoint = input_dialog(
181
- title=self.printer.get_message_from_key("model_provider_api_key_title"),
182
- text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
183
- validator=VolcanoEndpointValidator(),
184
- style=dialog_style
185
- ).run()
179
+ # # Get V3 endpoint
180
+ # v3_endpoint = input_dialog(
181
+ # title=self.printer.get_message_from_key("model_provider_api_key_title"),
182
+ # text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
183
+ # validator=VolcanoEndpointValidator(),
184
+ # style=dialog_style
185
+ # ).run()
186
186
 
187
- if v3_endpoint is None:
188
- return None
187
+ # if v3_endpoint is None:
188
+ # return None
189
189
 
190
- provider_info.v3_model = v3_endpoint
190
+ # provider_info.v3_model = v3_endpoint
191
191
 
192
192
  # Get API key for all providers
193
193
  api_key = input_dialog(
autocoder/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.267"
1
+ __version__ = "0.1.269"