auto-coder 0.1.288__py3-none-any.whl → 0.1.289__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. (The link to more details was not preserved in this copy.)
- {auto_coder-0.1.288.dist-info → auto_coder-0.1.289.dist-info}/METADATA +1 -1
- {auto_coder-0.1.288.dist-info → auto_coder-0.1.289.dist-info}/RECORD +14 -13
- autocoder/chat_auto_coder_lang.py +16 -16
- autocoder/common/auto_coder_lang.py +16 -4
- autocoder/common/mcp_hub.py +99 -77
- autocoder/common/mcp_server.py +162 -61
- autocoder/index/filter/quick_filter.py +373 -3
- autocoder/rag/long_context_rag.py +22 -9
- autocoder/rag/searchable.py +58 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.288.dist-info → auto_coder-0.1.289.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.288.dist-info → auto_coder-0.1.289.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.288.dist-info → auto_coder-0.1.289.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.288.dist-info → auto_coder-0.1.289.dist-info}/top_level.txt +0 -0
|
@@ -22,6 +22,8 @@ from autocoder.utils.llms import get_llm_names, get_model_info
|
|
|
22
22
|
from loguru import logger
|
|
23
23
|
from byzerllm.utils.client.code_utils import extract_code
|
|
24
24
|
import json
|
|
25
|
+
from autocoder.index.symbols_utils import extract_symbols
|
|
26
|
+
import os.path
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
def get_file_path(file_path):
|
|
@@ -389,15 +391,45 @@ class QuickFilter():
|
|
|
389
391
|
|
|
390
392
|
tokens_len = count_tokens(prompt_str)
|
|
391
393
|
|
|
392
|
-
#
|
|
394
|
+
# 打印当前索引大小
|
|
393
395
|
self.printer.print_in_terminal(
|
|
394
396
|
"quick_filter_tokens_len",
|
|
395
397
|
style="blue",
|
|
396
398
|
tokens_len=tokens_len
|
|
397
399
|
)
|
|
398
|
-
|
|
399
|
-
if tokens_len > self.max_tokens:
|
|
400
|
+
|
|
401
|
+
if tokens_len > self.max_tokens and tokens_len < 4*self.max_tokens:
|
|
402
|
+
# 打印 big_filter 模式的状态
|
|
403
|
+
self.printer.print_in_terminal(
|
|
404
|
+
"filter_mode_big",
|
|
405
|
+
style="yellow",
|
|
406
|
+
tokens_len=tokens_len
|
|
407
|
+
)
|
|
400
408
|
return self.big_filter(index_items)
|
|
409
|
+
elif tokens_len > 4*self.max_tokens:
|
|
410
|
+
# 打印 super_big_filter 模式的状态
|
|
411
|
+
self.printer.print_in_terminal(
|
|
412
|
+
"filter_mode_super_big",
|
|
413
|
+
style="yellow",
|
|
414
|
+
tokens_len=tokens_len
|
|
415
|
+
)
|
|
416
|
+
round1 = self.super_big_filter(index_items)
|
|
417
|
+
round1_index_items = []
|
|
418
|
+
for file_path in round1.files.keys():
|
|
419
|
+
for index_item in index_items:
|
|
420
|
+
if index_item.module_name == file_path:
|
|
421
|
+
round1_index_items.append(index_item)
|
|
422
|
+
|
|
423
|
+
if round1_index_items:
|
|
424
|
+
round2 = self.big_filter(round1_index_items)
|
|
425
|
+
return round2
|
|
426
|
+
return round1
|
|
427
|
+
else:
|
|
428
|
+
# 打印普通过滤模式的状态
|
|
429
|
+
self.printer.print_in_terminal(
|
|
430
|
+
"filter_mode_normal",
|
|
431
|
+
style="blue"
|
|
432
|
+
)
|
|
401
433
|
|
|
402
434
|
try:
|
|
403
435
|
# 获取模型名称
|
|
@@ -520,3 +552,341 @@ class QuickFilter():
|
|
|
520
552
|
has_error=False,
|
|
521
553
|
file_positions=final_file_positions
|
|
522
554
|
)
|
|
555
|
+
|
|
556
|
+
def super_big_filter(self, index_items: List[IndexItem]) -> QuickFilterResult:
    """Filter a very large index by showing the LLM only each file's name and usage.

    Each ``IndexItem`` is reduced to a compact dict with the keys ``index`` (its position in
    ``index_items``), ``filename`` (basename of ``module_name``), ``full_path`` and ``usage``
    (the ``usage`` field extracted from ``symbols``, or a placeholder when it is empty).

    If the prompt rendered from all compact items fits within ``self.max_tokens``, they are
    filtered with a single LLM call. Otherwise they are packed into chunks whose prompts each
    fit the budget, the chunks are filtered concurrently in a ``ThreadPoolExecutor``, and the
    per-chunk results are merged.

    Args:
        index_items: Index entries to filter.

    Returns:
        A ``QuickFilterResult`` with the selected files. An empty ``index_items`` yields an
        empty, error-free result without calling the LLM.
    """
    if not index_items:
        return QuickFilterResult(files={}, has_error=False, file_positions={})

    # Keep only the file name and its usage description to shrink the prompt.
    compact_items = [
        {
            "index": position,
            "filename": os.path.basename(item.module_name),
            "full_path": item.module_name,
            "usage": extract_symbols(item.symbols).usage or "无用途描述",
        }
        for position, item in enumerate(index_items)
    ]

    # Measure the prompt for the whole list first; if it already fits, use a single call.
    full_prompt = self.super_big_quick_filter_files.prompt(compact_items, self.args.query)
    tokens_len = count_tokens(full_prompt)
    if tokens_len <= self.max_tokens:
        return self._process_compact_items(compact_items, index_items)

    chunks = self._split_compact_items_into_chunks(compact_items)

    self.printer.print_in_terminal(
        "super_big_filter_splitting",
        style="yellow",
        tokens_len=tokens_len,
        max_tokens=self.max_tokens,
        split_size=len(chunks)
    )

    def process_chunk(chunk_index: int, chunk: List[dict]) -> QuickFilterResult:
        # Only the first chunk uses the streaming terminal UI; the others run without it
        # (presumably so concurrent streams do not compete for the terminal).
        return self._process_compact_items(
            chunk, index_items, show_ui=(chunk_index == 0), chunk_index=chunk_index)

    results: List[QuickFilterResult] = []
    if chunks:
        with ThreadPoolExecutor() as executor:
            # executor.map yields results in submission order, so results[i] belongs to chunks[i].
            results = list(executor.map(process_chunk, range(len(chunks)), chunks))

    return self._merge_super_big_filter_results(results)


def _split_compact_items_into_chunks(self, compact_items: List[dict], batch_size: int = 100) -> List[List[dict]]:
    """Pack ``compact_items`` into consecutive chunks whose prompts fit ``self.max_tokens``.

    Items are added ``batch_size`` at a time, and the rendered prompt is re-measured after each
    batch. When a batch would overflow the current chunk, the current chunk is closed and the
    batch starts a new one. A batch that is too large even on its own is halved repeatedly
    until every piece fits. A single item that alone exceeds the budget becomes its own chunk.
    Item order is preserved.

    Args:
        compact_items: Compact entries as built by ``super_big_filter``.
        batch_size: Number of items added between two token measurements (minimum 1).

    Returns:
        The list of chunks, in order.
    """
    batch_size = max(1, batch_size)
    query = self.args.query

    def fits(items: List[dict]) -> bool:
        return count_tokens(self.super_big_quick_filter_files.prompt(items, query)) <= self.max_tokens

    def split_oversized(items: List[dict]) -> List[List[dict]]:
        # ``items`` is known not to fit: bisect until each piece fits or is a single item.
        if len(items) <= 1:
            return [items]
        middle = len(items) // 2
        pieces: List[List[dict]] = []
        for half in (items[:middle], items[middle:]):
            pieces.extend([half] if fits(half) else split_oversized(half))
        return pieces

    chunks: List[List[dict]] = []
    current: List[dict] = []  # invariant: fits, unless it is a single oversized item
    for start in range(0, len(compact_items), batch_size):
        batch = compact_items[start:start + batch_size]
        candidate = current + batch
        if fits(candidate):
            current = candidate
            continue
        if current:
            # Close the chunk that still fitted and try the new batch on its own.
            chunks.append(current)
            if fits(batch):
                current = batch
                continue
        pieces = split_oversized(batch)
        chunks.extend(pieces[:-1])
        current = pieces[-1]

    if current:
        chunks.append(current)
    return chunks


@staticmethod
def _merge_super_big_filter_results(results: List[QuickFilterResult]) -> QuickFilterResult:
    """Combine the per-chunk results of ``super_big_filter`` into one ``QuickFilterResult``.

    How each field is built:
      * ``files`` is the union of all chunks' files.
      * ``has_error`` is True if any chunk failed, and the chunks' error messages are joined
        with newlines.
      * ``file_positions`` interleaves the chunks' rankings: every chunk's rank-0 file
        (in chunk order), then every chunk's rank-1 file, and so on, renumbered 0, 1, 2, ...
        A path that appears in several chunks keeps its earliest slot.
    """
    final_files: Dict[str, TargetFile] = {}
    error_messages: List[str] = []
    has_error = False
    ranked = []  # (position within its chunk, chunk index, file path)

    for chunk_index, result in enumerate(results):
        if result.has_error:
            has_error = True
            if result.error_message:
                error_messages.append(result.error_message)
        final_files.update(result.files)
        for file_path, position in (result.file_positions or {}).items():
            ranked.append((position, chunk_index, file_path))

    # A stable sort on (position, chunk) keeps each chunk's own insertion order for ties.
    ranked.sort(key=lambda entry: (entry[0], entry[1]))
    final_file_positions: Dict[str, int] = {}
    for _, _, file_path in ranked:
        if file_path not in final_file_positions:
            final_file_positions[file_path] = len(final_file_positions)

    return QuickFilterResult(
        files=final_files,
        has_error=has_error,
        error_message="\n".join(error_messages) if error_messages else None,
        file_positions=final_file_positions
    )
|
|
693
|
+
|
|
694
|
+
def _process_compact_items(self, compact_items: List[dict], index_items: List[IndexItem], show_ui: bool = True, chunk_index: int = 0) -> QuickFilterResult:
    """Ask the index-filter LLM which of ``compact_items`` are relevant to ``self.args.query``.

    The prompt is rendered by ``super_big_quick_filter_files``. The LLM answers with a
    ``FileNumberList`` whose numbers are positions in ``compact_items``. Each position is mapped
    back into ``index_items`` through the compact item's ``"index"`` key.

    Args:
        compact_items: Compact entries (``index``/``filename``/``full_path``/``usage`` dicts)
            for this call: either the whole index or one chunk of it.
        index_items: The complete, original index list that ``"index"`` refers into.
        show_ui: When True, the response is streamed to the terminal through ``stream_out``.
            When False, the prompt is run directly and only the statistics line is printed
            (and only if the call produced metadata).
        chunk_index: Chunk number, passed through to the statistics message.

    Returns:
        A ``QuickFilterResult``:
          * ``files`` are keyed by ``get_file_path(module_name)``.
          * ``file_positions`` give each file's 0-based rank in the LLM's answer.
          * Out-of-range numbers are reported and skipped; a file listed more than once keeps
            its first rank.
          * If anything raises, the error is printed and an empty result with
            ``has_error=True`` is returned instead of propagating.
    """
    filter_llm = self.index_manager.index_filter_llm
    model_names = get_llm_names(filter_llm)
    model_name = ",".join(model_names)

    # Per-model prices (applied per million tokens) for the cost estimate.
    model_info_map = {}
    for name in model_names:
        info = get_model_info(name, self.args.product_mode)
        if info:
            model_info_map[name] = {
                "input_price": info.get("input_price", 0.0),
                "output_price": info.get("output_price", 0.0)
            }

    def is_valid_number(file_number) -> bool:
        # The LLM answers with positions into compact_items; anything else is ignored.
        return 0 <= file_number < len(compact_items)

    def compute_costs(input_tokens: int, output_tokens: int):
        # One formula for both the streaming and the non-streaming path: token counts times
        # the per-million price of every configured model name, rounded to 4 decimals.
        # NOTE(review): with several model names the tokens are charged to each of them;
        # confirm that this is intended.
        input_cost = 0.0
        output_cost = 0.0
        for name in model_names:
            info = model_info_map.get(name, {})
            input_cost += (input_tokens * info.get("input_price", 0.0)) / 1000000
            output_cost += (output_tokens * info.get("output_price", 0.0)) / 1000000
        return round(input_cost, 4), round(output_cost, 4)

    try:
        start_time = time.monotonic()

        if show_ui:
            prompt = self.super_big_quick_filter_files.prompt(compact_items, self.args.query)
            stream_generator = stream_chat_with_continue(
                filter_llm,
                [{"role": "user", "content": prompt}],
                {}
            )

            def extract_file_number_list(content: str) -> str:
                # Display hook for stream_out: show the selected module names instead of the
                # raw JSON, or the raw content while it does not parse yet.
                try:
                    v = to_model(content, FileNumberList)
                    return "\n".join(
                        index_items[compact_items[file_number]["index"]].module_name
                        for file_number in v.file_list
                        if is_valid_number(file_number)
                    )
                except Exception as e:
                    logger.error(f"Error extracting file number list: {e}")
                    return content

            full_response, last_meta = stream_out(
                stream_generator,
                model_name=model_name,
                title=self.printer.get_message_from_key_with_format(
                    "super_big_filter_title", model_name=model_name),
                args=self.args,
                display_func=extract_file_number_list
            )
            file_number_list = to_model(full_response, FileNumberList)
            meta = {
                "input_tokens_count": last_meta.input_tokens_count,
                "generated_tokens_count": last_meta.generated_tokens_count,
            }
        else:
            meta_holder = MetaHolder()
            file_number_list = self.super_big_quick_filter_files.with_llm(filter_llm).with_meta(
                meta_holder).with_return_type(FileNumberList).run(compact_items, self.args.query)
            meta = meta_holder.get_meta()

        end_time = time.monotonic()
        elapsed = end_time - start_time

        if meta:
            input_tokens = meta.get("input_tokens_count", 0)
            output_tokens = meta.get("generated_tokens_count", 0)
            input_cost, output_cost = compute_costs(input_tokens, output_tokens)
            # Guard: a zero elapsed time would raise ZeroDivisionError and turn an otherwise
            # successful call into has_error=True.
            speed = output_tokens / elapsed if elapsed > 0 else 0.0
            self.printer.print_in_terminal(
                "super_big_filter_stats",
                style="blue",
                elapsed_time=f"{elapsed:.2f}",
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_cost=input_cost,
                output_cost=output_cost,
                model_names=model_name,
                speed=f"{speed:.2f}",
                chunk_index=chunk_index
            )

        files = {}
        file_positions = {}
        if file_number_list:
            for file_number in file_number_list.file_list:
                if not is_valid_number(file_number):
                    self.printer.print_in_terminal(
                        "invalid_file_number",
                        style="yellow",
                        file_number=file_number,
                        total_files=len(compact_items)
                    )
                    continue

                module_name = index_items[compact_items[file_number]["index"]].module_name
                file_path = get_file_path(module_name)
                if file_path in files:
                    # The LLM repeated a file: keep its first (best) rank so positions stay contiguous.
                    continue
                files[file_path] = TargetFile(
                    file_path=module_name,
                    reason=self.printer.get_message_from_key("quick_filter_reason")
                )
                file_positions[file_path] = len(file_positions)

        return QuickFilterResult(
            files=files,
            has_error=False,
            file_positions=file_positions
        )

    except Exception as e:
        self.printer.print_in_terminal(
            "super_big_filter_failed",
            style="red",
            error=str(e)
        )
        return QuickFilterResult(
            files={},
            has_error=True,
            error_message=str(e)
        )
|
|
847
|
+
|
|
848
|
+
# Prompt used by super_big_filter. The docstring below is the prompt template itself (note the
# {{ file_meta_str }} and {{ query }} placeholders), so its text is runtime behavior. The function
# body only builds the values for those placeholders.
@byzerllm.prompt()
def super_big_quick_filter_files(self, compact_items: List[dict], query: str) -> str:
    '''
    当用户提一个需求的时候,我们要找到相关的源码文件。

    下面是简化的索引文件列表,每项包含文件序号(index,##[]括起来的部分)、文件名(filename)和用途描述(usage):

    <index>
    {{ file_meta_str }}
    </index>

    下面是用户的查询需求:

    <query>
    {{ query }}
    </query>

    请根据用户的需求,找到相关的文件,并给出文件序号列表。请返回如下json格式:

    ```json
    {
        "file_list": [
            file_index1,
            file_index2,
            ...
        ]
    }
    ```

    特别注意:
    1. 如果用户的query里有 @文件 或者 @@符号,请匹配对应的文件名,优先返回这些文件。
    2. 根据用户需求找出需要被修改的文件(edited_files),以及可能需要作为参考的文件(reference_files)。
    3. file_list 里的文件序号,按被 @ 的文件、edited_files文件和reference_files文件的顺序排列。
    4. 如果 query 里是一段历史对话,那么对话里提及的文件必须要返回。
    5. 如果用户需求为空,则直接返回空列表即可。
    6. 返回的 json格式数据不允许有注释
    '''
    # Render each entry as "##[<position>]<filename>" with its usage on the following line.
    # <position> is the entry's offset in compact_items (not its "index" key); the caller maps
    # the numbers the LLM returns back through that offset.
    rendered_entries = []
    for position, entry in enumerate(compact_items):
        rendered_entries.append(f"##[{position}]{entry['filename']}\n{entry['usage']}")

    return {
        "file_meta_str": "\n".join(rendered_entries),
        "query": query
    }
|
|
@@ -38,7 +38,7 @@ from pydantic import BaseModel
|
|
|
38
38
|
from byzerllm.utils.types import SingleOutputMeta
|
|
39
39
|
from autocoder.rag.lang import get_message_with_format_and_newline
|
|
40
40
|
from autocoder.rag.qa_conversation_strategy import get_qa_strategy
|
|
41
|
-
|
|
41
|
+
from autocoder.rag.searchable import SearchableResults
|
|
42
42
|
try:
|
|
43
43
|
from autocoder_pro.rag.llm_compute import LLMComputeEngine
|
|
44
44
|
pro_version = version("auto-coder-pro")
|
|
@@ -257,7 +257,7 @@ class LongContextRAG:
|
|
|
257
257
|
请根据提供的文档内容、用户对话历史以及最后一个问题,提取并总结文档中与问题相关的重要信息。
|
|
258
258
|
如果文档中没有相关信息,请回复"该文档中没有与问题相关的信息"。
|
|
259
259
|
提取的信息尽量保持和原文中的一样,并且只输出这些信息。
|
|
260
|
-
"""
|
|
260
|
+
"""
|
|
261
261
|
|
|
262
262
|
def _get_document_retriever_class(self):
|
|
263
263
|
"""Get the document retriever class based on configuration."""
|
|
@@ -500,6 +500,9 @@ class LongContextRAG:
|
|
|
500
500
|
except json.JSONDecodeError:
|
|
501
501
|
pass
|
|
502
502
|
|
|
503
|
+
if not only_contexts and extra_request_params.get("only_contexts", False):
|
|
504
|
+
only_contexts = True
|
|
505
|
+
|
|
503
506
|
logger.info(f"Query: {query} only_contexts: {only_contexts}")
|
|
504
507
|
start_time = time.time()
|
|
505
508
|
|
|
@@ -593,10 +596,19 @@ class LongContextRAG:
|
|
|
593
596
|
)
|
|
594
597
|
|
|
595
598
|
if only_contexts:
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
599
|
+
try:
|
|
600
|
+
searcher = SearchableResults()
|
|
601
|
+
result = searcher.reorder(docs=relevant_docs)
|
|
602
|
+
yield (json.dumps(result.model_dump(), ensure_ascii=False), SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
|
|
603
|
+
generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
|
|
604
|
+
rag_stat.chunk_stat.total_generated_tokens,
|
|
605
|
+
))
|
|
606
|
+
except Exception as e:
|
|
607
|
+
yield (str(e), SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
|
|
608
|
+
generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
|
|
609
|
+
rag_stat.chunk_stat.total_generated_tokens,
|
|
610
|
+
))
|
|
611
|
+
return
|
|
600
612
|
|
|
601
613
|
if not relevant_docs:
|
|
602
614
|
yield ("没有找到可以回答你问题的相关文档", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
|
|
@@ -816,12 +828,13 @@ class LongContextRAG:
|
|
|
816
828
|
|
|
817
829
|
self._print_rag_stats(rag_stat)
|
|
818
830
|
else:
|
|
819
|
-
|
|
820
|
-
qa_strategy = get_qa_strategy(
|
|
831
|
+
|
|
832
|
+
qa_strategy = get_qa_strategy(
|
|
833
|
+
self.args.rag_qa_conversation_strategy)
|
|
821
834
|
new_conversations = qa_strategy.create_conversation(
|
|
822
835
|
documents=[doc.source_code for doc in relevant_docs],
|
|
823
836
|
conversations=conversations
|
|
824
|
-
)
|
|
837
|
+
)
|
|
825
838
|
|
|
826
839
|
chunks = target_llm.stream_chat_oai(
|
|
827
840
|
conversations=new_conversations,
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import Counter
|
|
3
|
+
from typing import Dict, List, Any, Optional, Tuple, Set
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
from autocoder.rag.relevant_utils import FilterDoc
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FileOccurrence(BaseModel):
    """A file path together with how many times it appeared in a set of search results."""
    file_path: str  # path of the original file
    count: int  # number of times file_path occurred
    score: float = 0.0  # Optional relevance score; defaults to 0.0 when not supplied
|
|
13
|
+
|
|
14
|
+
class FileResult(BaseModel):
    """Container for a list of FileOccurrence entries (the value SearchableResults.reorder builds)."""
    files: List[FileOccurrence]  # one entry per distinct file path
|
|
16
|
+
|
|
17
|
+
class SearchableResults:
    """Rank the original source files behind a list of retrieved documents by frequency."""

    def __init__(self):
        """Create a SearchableResults helper; it holds no state."""

    def extract_original_docs(self, docs: List[FilterDoc]) -> List[str]:
        """Collect the originating file path(s) of every document in ``docs``.

        If a document's ``source_code.metadata`` contains ``"original_docs"``, all of those
        entries are collected. Otherwise the document's ``source_code.module_name`` is used.
        Duplicates are kept on purpose so that they can be counted afterwards.

        Args:
            docs: Filtered search results.

        Returns:
            The collected file paths, in document order.
        """
        all_files: List[str] = []
        for doc in docs:
            # Treat missing (None) metadata like metadata without "original_docs" instead of
            # raising TypeError on the membership test.
            metadata = doc.source_code.metadata or {}
            if "original_docs" in metadata:
                all_files.extend(metadata["original_docs"])
            else:
                # Fallback: the document itself is the original file.
                all_files.append(doc.source_code.module_name)
        return all_files

    def count_file_occurrences(self, files: List[str]) -> List[FileOccurrence]:
        """Count how often each path appears in ``files``.

        Args:
            files: File paths, possibly with repetitions.

        Returns:
            One ``FileOccurrence`` per distinct path, sorted by ``count`` in descending order.
            Paths with equal counts keep the order in which they were first seen.
        """
        # Counter.most_common() with no argument sorts by count (descending) and breaks ties
        # by first occurrence.
        return [
            FileOccurrence(file_path=file_path, count=count)
            for file_path, count in Counter(files).most_common()
        ]

    def reorder(self, docs: List[FilterDoc]) -> FileResult:
        """Extract the original files behind ``docs`` and rank them by occurrence (main entry point).

        Args:
            docs: Filtered search results.

        Returns:
            A ``FileResult`` whose ``files`` are ordered from most to least frequent.
        """
        all_files = self.extract_original_docs(docs)
        return FileResult(files=self.count_file_occurrences(all_files))
|
|
57
|
+
|
|
58
|
+
|
autocoder/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.1.
|
|
1
|
+
__version__ = "0.1.289"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|