auto-coder 0.1.287__py3-none-any.whl → 0.1.289__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of auto-coder has been flagged as possibly problematic by the registry. See the registry's advisory page for details.

@@ -22,6 +22,8 @@ from autocoder.utils.llms import get_llm_names, get_model_info
22
22
  from loguru import logger
23
23
  from byzerllm.utils.client.code_utils import extract_code
24
24
  import json
25
+ from autocoder.index.symbols_utils import extract_symbols
26
+ import os.path
25
27
 
26
28
 
27
29
  def get_file_path(file_path):
@@ -389,15 +391,45 @@ class QuickFilter():
389
391
 
390
392
  tokens_len = count_tokens(prompt_str)
391
393
 
392
- # Print current index size
394
+ # 打印当前索引大小
393
395
  self.printer.print_in_terminal(
394
396
  "quick_filter_tokens_len",
395
397
  style="blue",
396
398
  tokens_len=tokens_len
397
399
  )
398
-
399
- if tokens_len > self.max_tokens:
400
+
401
+ if tokens_len > self.max_tokens and tokens_len < 4*self.max_tokens:
402
+ # 打印 big_filter 模式的状态
403
+ self.printer.print_in_terminal(
404
+ "filter_mode_big",
405
+ style="yellow",
406
+ tokens_len=tokens_len
407
+ )
400
408
  return self.big_filter(index_items)
409
+ elif tokens_len > 4*self.max_tokens:
410
+ # 打印 super_big_filter 模式的状态
411
+ self.printer.print_in_terminal(
412
+ "filter_mode_super_big",
413
+ style="yellow",
414
+ tokens_len=tokens_len
415
+ )
416
+ round1 = self.super_big_filter(index_items)
417
+ round1_index_items = []
418
+ for file_path in round1.files.keys():
419
+ for index_item in index_items:
420
+ if index_item.module_name == file_path:
421
+ round1_index_items.append(index_item)
422
+
423
+ if round1_index_items:
424
+ round2 = self.big_filter(round1_index_items)
425
+ return round2
426
+ return round1
427
+ else:
428
+ # 打印普通过滤模式的状态
429
+ self.printer.print_in_terminal(
430
+ "filter_mode_normal",
431
+ style="blue"
432
+ )
401
433
 
402
434
  try:
403
435
  # 获取模型名称
@@ -520,3 +552,341 @@ class QuickFilter():
520
552
  has_error=False,
521
553
  file_positions=final_file_positions
522
554
  )
555
+
556
+ def super_big_filter(self, index_items: List[IndexItem]) -> QuickFilterResult:
557
+ """
558
+ 超大索引过滤方法,通过提取文件的核心信息(文件名和用途)来减少token数量
559
+ 可处理超大规模的索引文件,通过切分成多个chunks并行处理
560
+ """
561
+ compact_items = []
562
+
563
+ # 将每个索引项转换为更紧凑的格式:只保留文件名和用途
564
+ for index, item in enumerate(index_items):
565
+ # 从module_name中提取文件名
566
+ filename = os.path.basename(item.module_name)
567
+
568
+ # 从symbols中提取用途
569
+ symbols_info = extract_symbols(item.symbols)
570
+ usage = symbols_info.usage if symbols_info.usage else "无用途描述"
571
+
572
+ # 创建紧凑的表示
573
+ compact_item = {
574
+ "index": index,
575
+ "filename": filename,
576
+ "full_path": item.module_name,
577
+ "usage": usage
578
+ }
579
+ compact_items.append(compact_item)
580
+
581
+ # 切分compact_items成多个chunks
582
+ chunks = []
583
+ current_chunk = []
584
+ batch_size = 100 # 每100条记录检查一次token数量
585
+
586
+ # 计算总的tokens长度
587
+ full_prompt = self.super_big_quick_filter_files.prompt(compact_items, self.args.query)
588
+ tokens_len = count_tokens(full_prompt)
589
+
590
+ # 如果tokens长度不超过max_tokens,直接处理整个列表
591
+ if tokens_len <= self.max_tokens:
592
+ return self._process_compact_items(compact_items, index_items)
593
+
594
+ # 否则,将compact_items切分成多个chunks,每100条检查一次
595
+ for i, item in enumerate(compact_items):
596
+ current_chunk.append(item)
597
+
598
+ # 每处理batch_size条记录或者到达末尾时检查一次
599
+ if (i + 1) % batch_size == 0 or i == len(compact_items) - 1:
600
+ temp_prompt = self.super_big_quick_filter_files.prompt(current_chunk, self.args.query)
601
+ temp_size = count_tokens(temp_prompt)
602
+
603
+ # 如果当前chunk的token数超过限制,则从当前位置分割
604
+ if temp_size > self.max_tokens:
605
+ # 如果当前chunk为空,添加至少一项
606
+ if len(current_chunk) <= batch_size:
607
+ # 当前批次是第一批,但已经超过限制,则至少保留一半的记录
608
+ split_index = max(1, len(current_chunk) // 2)
609
+ chunks.append(current_chunk[:split_index])
610
+ current_chunk = current_chunk[split_index:]
611
+ else:
612
+ # 将前一批次的items作为一个chunk
613
+ prev_batch_end = len(current_chunk) - (i % batch_size + 1)
614
+ if prev_batch_end > 0:
615
+ chunks.append(current_chunk[:prev_batch_end])
616
+ current_chunk = current_chunk[prev_batch_end:]
617
+ else:
618
+ # 极端情况:即使一条记录也超过了限制,则尝试添加单条记录
619
+ chunks.append([current_chunk[0]])
620
+ current_chunk = current_chunk[1:]
621
+
622
+ # 确保最后的chunk也被添加
623
+ if current_chunk:
624
+ chunks.append(current_chunk)
625
+
626
+ # 打印切分信息
627
+ self.printer.print_in_terminal(
628
+ "super_big_filter_splitting",
629
+ style="yellow",
630
+ tokens_len=tokens_len,
631
+ max_tokens=self.max_tokens,
632
+ split_size=len(chunks)
633
+ )
634
+
635
+ # 定义处理单个chunk的函数
636
+ def process_chunk(chunk_index: int, chunk: List[dict]) -> QuickFilterResult:
637
+ # 为避免在所有chunk上都显示UI,只在第一个chunk上显示
638
+ if chunk_index == 0:
639
+ # 显示UI的处理方式
640
+ return self._process_compact_items(chunk, index_items, show_ui=True, chunk_index=chunk_index)
641
+ else:
642
+ # 非UI显示的处理方式
643
+ return self._process_compact_items(chunk, index_items, show_ui=False, chunk_index=chunk_index)
644
+
645
+ # 使用ThreadPoolExecutor并行处理所有chunks
646
+ results: List[QuickFilterResult] = []
647
+ if chunks:
648
+ with ThreadPoolExecutor() as executor:
649
+ futures = [executor.submit(process_chunk, i, chunk) for i, chunk in enumerate(chunks)]
650
+ for future in futures:
651
+ results.append(future.result())
652
+
653
+ # 合并所有结果
654
+ final_files: Dict[str, TargetFile] = {}
655
+ final_file_positions: Dict[str, int] = {}
656
+ has_error = False
657
+ error_messages: List[str] = []
658
+
659
+ # 收集所有文件和错误信息
660
+ for result in results:
661
+ if result.has_error:
662
+ has_error = True
663
+ if result.error_message:
664
+ error_messages.append(result.error_message)
665
+ final_files.update(result.files)
666
+
667
+ # 处理file_positions的交织排序
668
+ max_position = max([max(pos.values()) for pos in [result.file_positions for result in results if result.file_positions]] + [0])
669
+
670
+ # 创建position映射表
671
+ position_map = {}
672
+ for result in results:
673
+ if result.file_positions:
674
+ for file_path, position in result.file_positions.items():
675
+ if position not in position_map:
676
+ position_map[position] = []
677
+ position_map[position].append(file_path)
678
+
679
+ # 重新排序文件路径
680
+ current_index = 0
681
+ for position in range(max_position + 1):
682
+ if position in position_map:
683
+ for file_path in position_map[position]:
684
+ final_file_positions[file_path] = current_index
685
+ current_index += 1
686
+
687
+ return QuickFilterResult(
688
+ files=final_files,
689
+ has_error=has_error,
690
+ error_message="\n".join(error_messages) if error_messages else None,
691
+ file_positions=final_file_positions
692
+ )
693
+
694
+ def _process_compact_items(self, compact_items: List[dict], index_items: List[IndexItem], show_ui: bool = True, chunk_index: int = 0) -> QuickFilterResult:
695
+ """
696
+ 处理一组compact_items,返回QuickFilterResult
697
+ """
698
+ # 使用流式输出处理
699
+ model_names = get_llm_names(self.index_manager.index_filter_llm)
700
+ model_name = ",".join(model_names)
701
+
702
+ # 获取模型价格信息
703
+ model_info_map = {}
704
+ for name in model_names:
705
+ info = get_model_info(name, self.args.product_mode)
706
+ if info:
707
+ model_info_map[name] = {
708
+ "input_price": info.get("input_price", 0.0),
709
+ "output_price": info.get("output_price", 0.0)
710
+ }
711
+
712
+ try:
713
+ start_time = time.monotonic()
714
+ # 渲染 Prompt 模板
715
+ prompt = self.super_big_quick_filter_files.prompt(compact_items, self.args.query)
716
+
717
+ if show_ui:
718
+ # 使用流式输出处理
719
+ stream_generator = stream_chat_with_continue(
720
+ self.index_manager.index_filter_llm,
721
+ [{"role": "user", "content": prompt}],
722
+ {}
723
+ )
724
+
725
+ def extract_file_number_list(content: str) -> str:
726
+ try:
727
+ v = to_model(content, FileNumberList)
728
+ return "\n".join([index_items[compact_items[file_number]["index"]].module_name for file_number in v.file_list])
729
+ except Exception as e:
730
+ logger.error(f"Error extracting file number list: {e}")
731
+ return content
732
+
733
+ # 获取完整响应
734
+ full_response, last_meta = stream_out(
735
+ stream_generator,
736
+ model_name=model_name,
737
+ title=self.printer.get_message_from_key_with_format(
738
+ "super_big_filter_title", model_name=model_name),
739
+ args=self.args,
740
+ display_func=extract_file_number_list
741
+ )
742
+
743
+ # 解析结果
744
+ file_number_list = to_model(full_response, FileNumberList)
745
+ end_time = time.monotonic()
746
+
747
+ # 计算总成本
748
+ total_input_cost = 0.0
749
+ total_output_cost = 0.0
750
+
751
+ for name in model_names:
752
+ info = model_info_map.get(name, {})
753
+ total_input_cost += (last_meta.input_tokens_count *
754
+ info.get("input_price", 0.0)) / 1000000
755
+ total_output_cost += (last_meta.generated_tokens_count *
756
+ info.get("output_price", 0.0)) / 1000000
757
+
758
+ # 四舍五入到4位小数
759
+ total_input_cost = round(total_input_cost, 4)
760
+ total_output_cost = round(total_output_cost, 4)
761
+ speed = last_meta.generated_tokens_count / (end_time - start_time)
762
+
763
+ # 打印 token 统计信息和成本
764
+ self.printer.print_in_terminal(
765
+ "super_big_filter_stats",
766
+ style="blue",
767
+ elapsed_time=f"{end_time - start_time:.2f}",
768
+ input_tokens=last_meta.input_tokens_count,
769
+ output_tokens=last_meta.generated_tokens_count,
770
+ input_cost=total_input_cost,
771
+ output_cost=total_output_cost,
772
+ model_names=model_name,
773
+ speed=f"{speed:.2f}",
774
+ chunk_index=chunk_index
775
+ )
776
+ else:
777
+ # 非UI模式,直接使用LLM处理
778
+ meta_holder = MetaHolder()
779
+ file_number_list = self.super_big_quick_filter_files.with_llm(self.index_manager.index_filter_llm).with_meta(
780
+ meta_holder).with_return_type(FileNumberList).run(compact_items, self.args.query)
781
+ end_time = time.monotonic()
782
+
783
+ # 打印处理信息
784
+ if meta_holder.get_meta():
785
+ meta_dict = meta_holder.get_meta()
786
+ total_input_cost = meta_dict.get("input_tokens_count", 0) * model_info_map.get(model_name, {}).get("input_price", 0.0) / 1000000
787
+ total_output_cost = meta_dict.get("generated_tokens_count", 0) * model_info_map.get(model_name, {}).get("output_price", 0.0) / 1000000
788
+
789
+ self.printer.print_in_terminal(
790
+ "super_big_filter_stats",
791
+ style="blue",
792
+ input_tokens=meta_dict.get("input_tokens_count", 0),
793
+ output_tokens=meta_dict.get("generated_tokens_count", 0),
794
+ input_cost=total_input_cost,
795
+ output_cost=total_output_cost,
796
+ model_names=model_name,
797
+ elapsed_time=f"{end_time - start_time:.2f}",
798
+ chunk_index=chunk_index
799
+ )
800
+
801
+ # 构建返回结果
802
+ files = {}
803
+ file_positions = {}
804
+
805
+ if file_number_list:
806
+ validated_file_numbers = []
807
+ for file_number in file_number_list.file_list:
808
+ if file_number < 0 or file_number >= len(compact_items):
809
+ self.printer.print_in_terminal(
810
+ "invalid_file_number",
811
+ style="yellow",
812
+ file_number=file_number,
813
+ total_files=len(compact_items)
814
+ )
815
+ continue
816
+
817
+ # 获取实际的index_item索引
818
+ original_index = compact_items[file_number]["index"]
819
+ validated_file_numbers.append(original_index)
820
+
821
+ # 将最终选中的文件加入files
822
+ for index, file_number in enumerate(validated_file_numbers):
823
+ file_path = get_file_path(index_items[file_number].module_name)
824
+ files[file_path] = TargetFile(
825
+ file_path=index_items[file_number].module_name,
826
+ reason=self.printer.get_message_from_key("quick_filter_reason")
827
+ )
828
+ file_positions[file_path] = index
829
+
830
+ return QuickFilterResult(
831
+ files=files,
832
+ has_error=False,
833
+ file_positions=file_positions
834
+ )
835
+
836
+ except Exception as e:
837
+ self.printer.print_in_terminal(
838
+ "super_big_filter_failed",
839
+ style="red",
840
+ error=str(e)
841
+ )
842
+ return QuickFilterResult(
843
+ files={},
844
+ has_error=True,
845
+ error_message=str(e)
846
+ )
847
+
848
    @byzerllm.prompt()
    def super_big_quick_filter_files(self, compact_items: List[dict], query: str) -> str:
        '''
        当用户提一个需求的时候,我们要找到相关的源码文件。

        下面是简化的索引文件列表,每项包含文件序号(index,##[]括起来的部分)、文件名(filename)和用途描述(usage):

        <index>
        {{ file_meta_str }}
        </index>

        下面是用户的查询需求:

        <query>
        {{ query }}
        </query>

        请根据用户的需求,找到相关的文件,并给出文件序号列表。请返回如下json格式:

        ```json
        {
            "file_list": [
                file_index1,
                file_index2,
                ...
            ]
        }
        ```

        特别注意:
        1. 如果用户的query里有 @文件 或者 @@符号,请匹配对应的文件名,优先返回这些文件。
        2. 根据用户需求找出需要被修改的文件(edited_files),以及可能需要作为参考的文件(reference_files)。
        3. file_list 里的文件序号,按被 @ 的文件、edited_files文件和reference_files文件的顺序排列。
        4. 如果 query 里是一段历史对话,那么对话里提及的文件必须要返回。
        5. 如果用户需求为空,则直接返回空列表即可。
        6. 返回的 json格式数据不允许有注释
        '''
        # NOTE(review): the docstring above is not documentation — the
        # @byzerllm.prompt() decorator renders it as a template at runtime
        # ({{ file_meta_str }}, {{ query }}) to build the LLM request, so its
        # text (including the Chinese instructions) is program behavior and
        # must not be translated or reworded casually.
        #
        # Render each compact item as "##[<index>]<filename>" followed by its
        # usage description on the next line; the model answers with these
        # chunk-local indexes in "file_list".
        file_meta_str = "\n".join(
            [f"##[{index}]{item['filename']}\n{item['usage']}" for index, item in enumerate(compact_items)])

        # The returned dict supplies the variables substituted into the
        # docstring template by the decorator.
        context = {
            "file_meta_str": file_meta_str,
            "query": query
        }
        return context