auto-coder 0.1.374__py3-none-any.whl → 0.1.376__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.374.dist-info → auto_coder-0.1.376.dist-info}/METADATA +2 -2
- {auto_coder-0.1.374.dist-info → auto_coder-0.1.376.dist-info}/RECORD +27 -57
- autocoder/agent/base_agentic/base_agent.py +202 -52
- autocoder/agent/base_agentic/default_tools.py +38 -6
- autocoder/agent/base_agentic/tools/list_files_tool_resolver.py +83 -43
- autocoder/agent/base_agentic/tools/read_file_tool_resolver.py +88 -25
- autocoder/agent/base_agentic/tools/replace_in_file_tool_resolver.py +171 -62
- autocoder/agent/base_agentic/tools/search_files_tool_resolver.py +101 -56
- autocoder/agent/base_agentic/tools/talk_to_group_tool_resolver.py +5 -0
- autocoder/agent/base_agentic/tools/talk_to_tool_resolver.py +5 -0
- autocoder/agent/base_agentic/tools/write_to_file_tool_resolver.py +145 -32
- autocoder/auto_coder_rag.py +80 -11
- autocoder/models.py +2 -2
- autocoder/rag/agentic_rag.py +217 -0
- autocoder/rag/cache/local_duckdb_storage_cache.py +63 -33
- autocoder/rag/conversation_to_queries.py +37 -5
- autocoder/rag/long_context_rag.py +161 -41
- autocoder/rag/tools/__init__.py +10 -0
- autocoder/rag/tools/recall_tool.py +163 -0
- autocoder/rag/tools/search_tool.py +126 -0
- autocoder/rag/types.py +36 -0
- autocoder/utils/_markitdown.py +59 -13
- autocoder/version.py +1 -1
- autocoder/agent/agentic_edit.py +0 -833
- autocoder/agent/agentic_edit_tools/__init__.py +0 -28
- autocoder/agent/agentic_edit_tools/ask_followup_question_tool_resolver.py +0 -32
- autocoder/agent/agentic_edit_tools/attempt_completion_tool_resolver.py +0 -29
- autocoder/agent/agentic_edit_tools/base_tool_resolver.py +0 -29
- autocoder/agent/agentic_edit_tools/execute_command_tool_resolver.py +0 -84
- autocoder/agent/agentic_edit_tools/list_code_definition_names_tool_resolver.py +0 -75
- autocoder/agent/agentic_edit_tools/list_files_tool_resolver.py +0 -62
- autocoder/agent/agentic_edit_tools/plan_mode_respond_tool_resolver.py +0 -30
- autocoder/agent/agentic_edit_tools/read_file_tool_resolver.py +0 -36
- autocoder/agent/agentic_edit_tools/replace_in_file_tool_resolver.py +0 -95
- autocoder/agent/agentic_edit_tools/search_files_tool_resolver.py +0 -70
- autocoder/agent/agentic_edit_tools/use_mcp_tool_resolver.py +0 -55
- autocoder/agent/agentic_edit_tools/write_to_file_tool_resolver.py +0 -98
- autocoder/agent/agentic_edit_types.py +0 -124
- autocoder/auto_coder_lang.py +0 -60
- autocoder/auto_coder_rag_client_mcp.py +0 -170
- autocoder/auto_coder_rag_mcp.py +0 -193
- autocoder/common/llm_rerank.py +0 -84
- autocoder/common/model_speed_test.py +0 -392
- autocoder/common/v2/agent/agentic_edit_conversation.py +0 -188
- autocoder/common/v2/agent/ignore_utils.py +0 -50
- autocoder/dispacher/actions/plugins/action_translate.py +0 -214
- autocoder/ignorefiles/__init__.py +0 -4
- autocoder/ignorefiles/ignore_file_utils.py +0 -63
- autocoder/ignorefiles/test_ignore_file_utils.py +0 -91
- autocoder/linters/code_linter.py +0 -588
- autocoder/rag/loaders/test_image_loader.py +0 -209
- autocoder/rag/raw_rag.py +0 -96
- autocoder/rag/simple_directory_reader.py +0 -646
- autocoder/rag/simple_rag.py +0 -404
- autocoder/regex_project/__init__.py +0 -162
- autocoder/utils/coder.py +0 -125
- autocoder/utils/tests.py +0 -37
- {auto_coder-0.1.374.dist-info → auto_coder-0.1.376.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.374.dist-info → auto_coder-0.1.376.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.374.dist-info → auto_coder-0.1.376.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.374.dist-info → auto_coder-0.1.376.dist-info}/top_level.txt +0 -0
|
@@ -29,6 +29,9 @@ from autocoder.rag.searchable import SearchableResults
|
|
|
29
29
|
from autocoder.rag.conversation_to_queries import extract_search_queries
|
|
30
30
|
from autocoder.common import openai_content as OpenAIContentProcessor
|
|
31
31
|
from autocoder.common.save_formatted_log import save_formatted_log
|
|
32
|
+
from autocoder.rag.types import (
|
|
33
|
+
RecallStat,ChunkStat,AnswerStat,OtherStat,RAGStat
|
|
34
|
+
)
|
|
32
35
|
import json, os
|
|
33
36
|
try:
|
|
34
37
|
from autocoder_pro.rag.llm_compute import LLMComputeEngine
|
|
@@ -42,29 +45,6 @@ except ImportError:
|
|
|
42
45
|
LLMComputeEngine = None
|
|
43
46
|
|
|
44
47
|
|
|
45
|
-
class RecallStat(BaseModel):
|
|
46
|
-
total_input_tokens: int
|
|
47
|
-
total_generated_tokens: int
|
|
48
|
-
model_name: str = "unknown"
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class ChunkStat(BaseModel):
|
|
52
|
-
total_input_tokens: int
|
|
53
|
-
total_generated_tokens: int
|
|
54
|
-
model_name: str = "unknown"
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
class AnswerStat(BaseModel):
|
|
58
|
-
total_input_tokens: int
|
|
59
|
-
total_generated_tokens: int
|
|
60
|
-
model_name: str = "unknown"
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class RAGStat(BaseModel):
|
|
64
|
-
recall_stat: RecallStat
|
|
65
|
-
chunk_stat: ChunkStat
|
|
66
|
-
answer_stat: AnswerStat
|
|
67
|
-
|
|
68
48
|
|
|
69
49
|
class LongContextRAG:
|
|
70
50
|
def __init__(
|
|
@@ -690,7 +670,7 @@ class LongContextRAG:
|
|
|
690
670
|
yield gen_item
|
|
691
671
|
|
|
692
672
|
# 打印最终的统计信息
|
|
693
|
-
self._print_rag_stats(rag_stat)
|
|
673
|
+
self._print_rag_stats(rag_stat, conversations)
|
|
694
674
|
return
|
|
695
675
|
|
|
696
676
|
def _process_document_retrieval(self, conversations,
|
|
@@ -716,7 +696,7 @@ class LongContextRAG:
|
|
|
716
696
|
|
|
717
697
|
# 提取查询并检索候选文档
|
|
718
698
|
queries = extract_search_queries(
|
|
719
|
-
conversations=conversations, args=self.args, llm=self.llm, max_queries=self.args.rag_recall_max_queries)
|
|
699
|
+
conversations=conversations, args=self.args, llm=self.llm, max_queries=self.args.rag_recall_max_queries,rag_stat=rag_stat)
|
|
720
700
|
documents = self._retrieve_documents(
|
|
721
701
|
options={"queries": [query] + [query.query for query in queries]})
|
|
722
702
|
|
|
@@ -913,7 +893,7 @@ class LongContextRAG:
|
|
|
913
893
|
rag_stat.answer_stat.total_generated_tokens
|
|
914
894
|
yield chunk
|
|
915
895
|
|
|
916
|
-
def _print_rag_stats(self, rag_stat: RAGStat) -> None:
|
|
896
|
+
def _print_rag_stats(self, rag_stat: RAGStat, conversations: Optional[List[Dict[str, str]]] = None) -> None:
|
|
917
897
|
"""打印RAG执行的详细统计信息"""
|
|
918
898
|
total_input_tokens = (
|
|
919
899
|
rag_stat.recall_stat.total_input_tokens +
|
|
@@ -937,12 +917,46 @@ class LongContextRAG:
|
|
|
937
917
|
rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
|
|
938
918
|
answer_percent = (rag_stat.answer_stat.total_input_tokens +
|
|
939
919
|
rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
920
|
+
|
|
921
|
+
# 计算其他阶段的令牌占比
|
|
922
|
+
other_percents = []
|
|
923
|
+
if total_tokens > 0 and rag_stat.other_stats:
|
|
924
|
+
for other_stat in rag_stat.other_stats:
|
|
925
|
+
other_percent = (other_stat.total_input_tokens +
|
|
926
|
+
other_stat.total_generated_tokens) / total_tokens * 100
|
|
927
|
+
other_percents.append(other_percent)
|
|
928
|
+
|
|
929
|
+
# 计算成本分布百分比
|
|
930
|
+
if rag_stat.cost == 0:
|
|
931
|
+
recall_cost_percent = chunk_cost_percent = answer_cost_percent = 0
|
|
932
|
+
else:
|
|
933
|
+
recall_cost_percent = rag_stat.recall_stat.cost / rag_stat.cost * 100
|
|
934
|
+
chunk_cost_percent = rag_stat.chunk_stat.cost / rag_stat.cost * 100
|
|
935
|
+
answer_cost_percent = rag_stat.answer_stat.cost / rag_stat.cost * 100
|
|
936
|
+
|
|
937
|
+
# 计算其他阶段的成本占比
|
|
938
|
+
other_costs_percent = []
|
|
939
|
+
if rag_stat.cost > 0 and rag_stat.other_stats:
|
|
940
|
+
for other_stat in rag_stat.other_stats:
|
|
941
|
+
other_costs_percent.append(other_stat.cost / rag_stat.cost * 100)
|
|
942
|
+
|
|
943
|
+
## 这里会计算每个阶段的成本
|
|
944
|
+
estimated_cost = self._estimate_token_cost(rag_stat)
|
|
945
|
+
# 构建统计信息字符串
|
|
946
|
+
query_content = ""
|
|
947
|
+
if conversations and len(conversations) > 0:
|
|
948
|
+
query_content = conversations[-1].get("content", "")
|
|
949
|
+
if len(query_content) > 100:
|
|
950
|
+
query_content = query_content[:100] + "..."
|
|
951
|
+
query_content = f"查询内容: {query_content}\n"
|
|
952
|
+
|
|
953
|
+
stats_str = (
|
|
954
|
+
f"=== (RAG 执行统计信息) ===\n"
|
|
955
|
+
f"{query_content}"
|
|
943
956
|
f"总令牌使用: {total_tokens} 令牌\n"
|
|
944
957
|
f" * 输入令牌总数: {total_input_tokens}\n"
|
|
945
958
|
f" * 生成令牌总数: {total_generated_tokens}\n"
|
|
959
|
+
f" * 总成本: {rag_stat.cost:.6f}\n"
|
|
946
960
|
f"\n"
|
|
947
961
|
f"阶段统计:\n"
|
|
948
962
|
f" 1. 文档检索阶段:\n"
|
|
@@ -950,40 +964,146 @@ class LongContextRAG:
|
|
|
950
964
|
f" - 输入令牌: {rag_stat.recall_stat.total_input_tokens}\n"
|
|
951
965
|
f" - 生成令牌: {rag_stat.recall_stat.total_generated_tokens}\n"
|
|
952
966
|
f" - 阶段总计: {rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens}\n"
|
|
967
|
+
f" - 阶段成本: {rag_stat.recall_stat.cost:.6f}\n"
|
|
953
968
|
f"\n"
|
|
954
969
|
f" 2. 文档分块阶段:\n"
|
|
955
970
|
f" - 模型: {rag_stat.chunk_stat.model_name}\n"
|
|
956
971
|
f" - 输入令牌: {rag_stat.chunk_stat.total_input_tokens}\n"
|
|
957
972
|
f" - 生成令牌: {rag_stat.chunk_stat.total_generated_tokens}\n"
|
|
958
973
|
f" - 阶段总计: {rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens}\n"
|
|
974
|
+
f" - 阶段成本: {rag_stat.chunk_stat.cost:.6f}\n"
|
|
959
975
|
f"\n"
|
|
960
976
|
f" 3. 答案生成阶段:\n"
|
|
961
977
|
f" - 模型: {rag_stat.answer_stat.model_name}\n"
|
|
962
978
|
f" - 输入令牌: {rag_stat.answer_stat.total_input_tokens}\n"
|
|
963
979
|
f" - 生成令牌: {rag_stat.answer_stat.total_generated_tokens}\n"
|
|
964
980
|
f" - 阶段总计: {rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens}\n"
|
|
981
|
+
f" - 阶段成本: {rag_stat.answer_stat.cost:.6f}\n"
|
|
965
982
|
f"\n"
|
|
983
|
+
)
|
|
984
|
+
|
|
985
|
+
# 如果存在 other_stats,添加其统计信息
|
|
986
|
+
if rag_stat.other_stats:
|
|
987
|
+
for i, other_stat in enumerate(rag_stat.other_stats):
|
|
988
|
+
stats_str += (
|
|
989
|
+
f" {i+4}. 其他阶段 {i+1}:\n"
|
|
990
|
+
f" - 模型: {other_stat.model_name}\n"
|
|
991
|
+
f" - 输入令牌: {other_stat.total_input_tokens}\n"
|
|
992
|
+
f" - 生成令牌: {other_stat.total_generated_tokens}\n"
|
|
993
|
+
f" - 阶段总计: {other_stat.total_input_tokens + other_stat.total_generated_tokens}\n"
|
|
994
|
+
f" - 阶段成本: {other_stat.cost:.6f}\n"
|
|
995
|
+
f"\n"
|
|
996
|
+
)
|
|
997
|
+
|
|
998
|
+
# 添加令牌分布百分比
|
|
999
|
+
stats_str += (
|
|
966
1000
|
f"令牌分布百分比:\n"
|
|
967
1001
|
f" - 文档检索: {recall_percent:.1f}%\n"
|
|
968
1002
|
f" - 文档分块: {chunk_percent:.1f}%\n"
|
|
969
1003
|
f" - 答案生成: {answer_percent:.1f}%\n"
|
|
970
1004
|
)
|
|
1005
|
+
|
|
1006
|
+
# 如果存在 other_stats,添加其令牌占比
|
|
1007
|
+
if rag_stat.other_stats:
|
|
1008
|
+
for i, other_percent in enumerate(other_percents):
|
|
1009
|
+
if other_percent > 0:
|
|
1010
|
+
stats_str += f" - 其他阶段 {i+1}: {other_percent:.1f}%\n"
|
|
1011
|
+
|
|
1012
|
+
# 添加成本分布百分比
|
|
1013
|
+
stats_str += (
|
|
1014
|
+
f"\n"
|
|
1015
|
+
f"成本分布百分比:\n"
|
|
1016
|
+
f" - 文档检索: {recall_cost_percent:.1f}%\n"
|
|
1017
|
+
f" - 文档分块: {chunk_cost_percent:.1f}%\n"
|
|
1018
|
+
f" - 答案生成: {answer_cost_percent:.1f}%\n"
|
|
1019
|
+
)
|
|
1020
|
+
|
|
1021
|
+
# 如果存在 other_stats,添加其成本占比
|
|
1022
|
+
if rag_stat.other_stats:
|
|
1023
|
+
for i, other_cost_percent in enumerate(other_costs_percent):
|
|
1024
|
+
if other_cost_percent > 0:
|
|
1025
|
+
stats_str += f" - 其他阶段 {i+1}: {other_cost_percent:.1f}%\n"
|
|
1026
|
+
|
|
1027
|
+
# 输出统计信息
|
|
1028
|
+
logger.info(stats_str)
|
|
971
1029
|
|
|
972
1030
|
# 记录原始统计数据,以便调试
|
|
973
1031
|
logger.debug(f"RAG Stat 原始数据: {rag_stat}")
|
|
974
1032
|
|
|
975
|
-
|
|
976
|
-
estimated_cost = self._estimate_token_cost(
|
|
977
|
-
total_input_tokens, total_generated_tokens)
|
|
1033
|
+
|
|
978
1034
|
if estimated_cost > 0:
|
|
979
|
-
logger.info(f"估计成本: 约
|
|
1035
|
+
logger.info(f"估计成本: 约 {estimated_cost:.4f} ")
|
|
980
1036
|
|
|
981
|
-
def _estimate_token_cost(self,
|
|
1037
|
+
def _estimate_token_cost(self, rag_stat: RAGStat) -> float:
|
|
982
1038
|
"""估算当前请求的令牌成本(人民币)"""
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
1039
|
+
from autocoder.models import get_model_by_name
|
|
1040
|
+
|
|
1041
|
+
total_cost = 0.0
|
|
1042
|
+
|
|
1043
|
+
# 计算召回阶段成本
|
|
1044
|
+
if rag_stat.recall_stat.model_name != "unknown":
|
|
1045
|
+
try:
|
|
1046
|
+
recall_model = get_model_by_name(rag_stat.recall_stat.model_name)
|
|
1047
|
+
input_cost = recall_model.get("input_price", 0.0) / 1000000
|
|
1048
|
+
output_cost = recall_model.get("output_price", 0.0) / 1000000
|
|
1049
|
+
recall_cost = (rag_stat.recall_stat.total_input_tokens * input_cost) + \
|
|
1050
|
+
(rag_stat.recall_stat.total_generated_tokens * output_cost)
|
|
1051
|
+
total_cost += recall_cost
|
|
1052
|
+
except Exception as e:
|
|
1053
|
+
logger.warning(f"计算召回阶段成本时出错: {str(e)}")
|
|
1054
|
+
recall_cost = 0.0
|
|
1055
|
+
total_cost += recall_cost
|
|
1056
|
+
rag_stat.recall_stat.cost = recall_cost
|
|
1057
|
+
|
|
1058
|
+
# 计算分块阶段成本
|
|
1059
|
+
if rag_stat.chunk_stat.model_name != "unknown":
|
|
1060
|
+
try:
|
|
1061
|
+
chunk_model = get_model_by_name(rag_stat.chunk_stat.model_name)
|
|
1062
|
+
input_cost = chunk_model.get("input_price", 0.0) / 1000000
|
|
1063
|
+
output_cost = chunk_model.get("output_price", 0.0) / 1000000
|
|
1064
|
+
chunk_cost = (rag_stat.chunk_stat.total_input_tokens * input_cost) + \
|
|
1065
|
+
(rag_stat.chunk_stat.total_generated_tokens * output_cost)
|
|
1066
|
+
total_cost += chunk_cost
|
|
1067
|
+
except Exception as e:
|
|
1068
|
+
logger.warning(f"计算分块阶段成本时出错: {str(e)}")
|
|
1069
|
+
# 使用默认值
|
|
1070
|
+
chunk_cost = 0.0
|
|
1071
|
+
total_cost += chunk_cost
|
|
1072
|
+
rag_stat.chunk_stat.cost = chunk_cost
|
|
1073
|
+
|
|
1074
|
+
# 计算答案生成阶段成本
|
|
1075
|
+
if rag_stat.answer_stat.model_name != "unknown":
|
|
1076
|
+
try:
|
|
1077
|
+
answer_model = get_model_by_name(rag_stat.answer_stat.model_name)
|
|
1078
|
+
input_cost = answer_model.get("input_price", 0.0) / 1000000
|
|
1079
|
+
output_cost = answer_model.get("output_price", 0.0) / 1000000
|
|
1080
|
+
answer_cost = (rag_stat.answer_stat.total_input_tokens * input_cost) + \
|
|
1081
|
+
(rag_stat.answer_stat.total_generated_tokens * output_cost)
|
|
1082
|
+
total_cost += answer_cost
|
|
1083
|
+
except Exception as e:
|
|
1084
|
+
logger.warning(f"计算答案生成阶段成本时出错: {str(e)}")
|
|
1085
|
+
# 使用默认值
|
|
1086
|
+
answer_cost = 0.0
|
|
1087
|
+
total_cost += answer_cost
|
|
1088
|
+
rag_stat.answer_stat.cost = answer_cost
|
|
1089
|
+
|
|
1090
|
+
# 计算其他阶段成本(如果存在)
|
|
1091
|
+
for i, other_stat in enumerate(rag_stat.other_stats):
|
|
1092
|
+
if other_stat.model_name != "unknown":
|
|
1093
|
+
try:
|
|
1094
|
+
other_model = get_model_by_name(other_stat.model_name)
|
|
1095
|
+
input_cost = other_model.get("input_price", 0.0) / 1000000
|
|
1096
|
+
output_cost = other_model.get("output_price", 0.0) / 1000000
|
|
1097
|
+
other_cost = (other_stat.total_input_tokens * input_cost) + \
|
|
1098
|
+
(other_stat.total_generated_tokens * output_cost)
|
|
1099
|
+
total_cost += other_cost
|
|
1100
|
+
except Exception as e:
|
|
1101
|
+
logger.warning(f"计算其他阶段 {i+1} 成本时出错: {str(e)}")
|
|
1102
|
+
# 使用默认值
|
|
1103
|
+
other_cost = 0.0
|
|
1104
|
+
total_cost += other_cost
|
|
1105
|
+
rag_stat.other_stats[i].cost = other_cost
|
|
1106
|
+
|
|
1107
|
+
# 将总成本保存到 rag_stat
|
|
1108
|
+
rag_stat.cost = total_cost
|
|
1109
|
+
return total_cost
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# 导出 SearchTool 相关类和函数
|
|
2
|
+
from .search_tool import SearchTool, SearchToolResolver, register_search_tool
|
|
3
|
+
|
|
4
|
+
# 导出 RecallTool 相关类和函数
|
|
5
|
+
from .recall_tool import RecallTool, RecallToolResolver, register_recall_tool
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
'SearchTool', 'SearchToolResolver', 'register_search_tool',
|
|
9
|
+
'RecallTool', 'RecallToolResolver', 'register_recall_tool'
|
|
10
|
+
]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RecallTool 模块
|
|
3
|
+
|
|
4
|
+
该模块实现了 RecallTool 和 RecallToolResolver 类,用于在 BaseAgent 框架中
|
|
5
|
+
提供基于 LongContextRAG 的文档内容召回功能。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import traceback
|
|
10
|
+
from typing import Dict, Any, List, Optional
|
|
11
|
+
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
from autocoder.agent.base_agentic.types import BaseTool, ToolResult
|
|
15
|
+
from autocoder.agent.base_agentic.tool_registry import ToolRegistry
|
|
16
|
+
from autocoder.agent.base_agentic.tools.base_tool_resolver import BaseToolResolver
|
|
17
|
+
from autocoder.agent.base_agentic.types import ToolDescription, ToolExample
|
|
18
|
+
from autocoder.common import AutoCoderArgs
|
|
19
|
+
from autocoder.rag.long_context_rag import LongContextRAG
|
|
20
|
+
from autocoder.rag.types import RecallStat, ChunkStat, AnswerStat, RAGStat
|
|
21
|
+
from autocoder.rag.relevant_utils import FilterDoc, DocRelevance, DocFilterResult
|
|
22
|
+
from autocoder.common import SourceCode
|
|
23
|
+
from autocoder.rag.relevant_utils import TaskTiming
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RecallTool(BaseTool):
|
|
27
|
+
"""召回工具,用于获取与查询相关的文档内容"""
|
|
28
|
+
query: str # 用户查询
|
|
29
|
+
file_paths: Optional[List[str]] = None # 指定要处理的文件路径列表,如果为空则自动搜索
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RecallToolResolver(BaseToolResolver):
|
|
33
|
+
"""召回工具解析器,实现召回逻辑"""
|
|
34
|
+
def __init__(self, agent, tool, args):
|
|
35
|
+
super().__init__(agent, tool, args)
|
|
36
|
+
self.tool: RecallTool = tool
|
|
37
|
+
|
|
38
|
+
def resolve(self) -> ToolResult:
|
|
39
|
+
"""实现召回工具的解析逻辑"""
|
|
40
|
+
try:
|
|
41
|
+
# 获取参数
|
|
42
|
+
query = self.tool.query
|
|
43
|
+
file_paths = self.tool.file_paths
|
|
44
|
+
rag:LongContextRAG = self.agent.rag
|
|
45
|
+
# 构建对话历史
|
|
46
|
+
conversations = [
|
|
47
|
+
{"role": "user", "content": query}
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
# 创建 RAGStat 对象
|
|
51
|
+
|
|
52
|
+
rag_stat = RAGStat(
|
|
53
|
+
recall_stat=RecallStat(total_input_tokens=0, total_generated_tokens=0),
|
|
54
|
+
chunk_stat=ChunkStat(total_input_tokens=0, total_generated_tokens=0),
|
|
55
|
+
answer_stat=AnswerStat(total_input_tokens=0, total_generated_tokens=0)
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# 如果提供了文件路径,则直接使用;否则,执行搜索
|
|
59
|
+
if file_paths:
|
|
60
|
+
|
|
61
|
+
# 创建 FilterDoc 对象
|
|
62
|
+
relevant_docs = []
|
|
63
|
+
for file_path in file_paths:
|
|
64
|
+
try:
|
|
65
|
+
with open(file_path, 'r', encoding='utf-8') as f:
|
|
66
|
+
content = f.read()
|
|
67
|
+
|
|
68
|
+
source_code = SourceCode(
|
|
69
|
+
module_name=file_path,
|
|
70
|
+
source_code=content
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
doc = FilterDoc(
|
|
74
|
+
source_code=source_code,
|
|
75
|
+
relevance=DocRelevance(is_relevant=True, relevant_score=5), # 默认相关性
|
|
76
|
+
task_timing=TaskTiming()
|
|
77
|
+
)
|
|
78
|
+
relevant_docs.append(doc)
|
|
79
|
+
except Exception as e:
|
|
80
|
+
logger.error(f"读取文件 {file_path} 失败: {str(e)}")
|
|
81
|
+
else:
|
|
82
|
+
# 调用文档检索处理
|
|
83
|
+
generator = rag._process_document_retrieval(conversations, query, rag_stat)
|
|
84
|
+
|
|
85
|
+
# 获取检索结果
|
|
86
|
+
relevant_docs = None
|
|
87
|
+
for item in generator:
|
|
88
|
+
if isinstance(item, dict) and "result" in item:
|
|
89
|
+
relevant_docs = item["result"]
|
|
90
|
+
|
|
91
|
+
if not relevant_docs:
|
|
92
|
+
return ToolResult(
|
|
93
|
+
success=False,
|
|
94
|
+
message="未找到相关文档",
|
|
95
|
+
content=[]
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
# 调用文档分块处理
|
|
99
|
+
relevant_docs = [doc.source_code for doc in relevant_docs]
|
|
100
|
+
doc_chunking_generator = rag._process_document_chunking(relevant_docs, conversations, rag_stat, 0)
|
|
101
|
+
|
|
102
|
+
# 获取分块结果
|
|
103
|
+
final_relevant_docs = None
|
|
104
|
+
for item in doc_chunking_generator:
|
|
105
|
+
if isinstance(item, dict) and "result" in item:
|
|
106
|
+
final_relevant_docs = item["result"]
|
|
107
|
+
|
|
108
|
+
if not final_relevant_docs:
|
|
109
|
+
return ToolResult(
|
|
110
|
+
success=False,
|
|
111
|
+
message="文档分块处理失败",
|
|
112
|
+
content=[]
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
# 格式化结果
|
|
116
|
+
doc_contents = []
|
|
117
|
+
for doc in final_relevant_docs:
|
|
118
|
+
doc_contents.append({
|
|
119
|
+
"path": doc.module_name,
|
|
120
|
+
"content": doc.source_code
|
|
121
|
+
})
|
|
122
|
+
|
|
123
|
+
return ToolResult(
|
|
124
|
+
success=True,
|
|
125
|
+
message=f"成功召回 {len(doc_contents)} 个相关文档片段",
|
|
126
|
+
content=doc_contents
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
except Exception as e:
|
|
130
|
+
import traceback
|
|
131
|
+
return ToolResult(
|
|
132
|
+
success=False,
|
|
133
|
+
message=f"召回工具执行失败: {str(e)}",
|
|
134
|
+
content=traceback.format_exc()
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def register_recall_tool():
|
|
139
|
+
"""注册召回工具"""
|
|
140
|
+
# 准备工具描述
|
|
141
|
+
description = ToolDescription(
|
|
142
|
+
description="召回与查询相关的文档内容",
|
|
143
|
+
parameters="query: 搜索查询\nfile_paths: 指定要处理的文件路径列表(可选)",
|
|
144
|
+
usage="用于根据查询获取相关文档的内容片段"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
# 准备工具示例
|
|
148
|
+
example = ToolExample(
|
|
149
|
+
title="召回工具使用示例",
|
|
150
|
+
body="""<recall>
|
|
151
|
+
<query>如何实现文件监控功能</query>
|
|
152
|
+
</recall>"""
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
# 注册工具
|
|
156
|
+
ToolRegistry.register_tool(
|
|
157
|
+
tool_tag="recall", # XML标签名
|
|
158
|
+
tool_cls=RecallTool, # 工具类
|
|
159
|
+
resolver_cls=RecallToolResolver, # 解析器类
|
|
160
|
+
description=description, # 工具描述
|
|
161
|
+
example=example, # 工具示例
|
|
162
|
+
use_guideline="此工具用于根据用户查询召回相关文档内容,返回经过分块和重排序的文档片段。适用于需要深入了解特定功能实现细节的场景。" # 使用指南
|
|
163
|
+
)
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SearchTool 模块
|
|
3
|
+
|
|
4
|
+
该模块实现了 SearchTool 和 SearchToolResolver 类,用于在 BaseAgent 框架中
|
|
5
|
+
提供基于 LongContextRAG 的文档搜索功能。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from typing import Dict, Any, List, Optional
|
|
10
|
+
|
|
11
|
+
from loguru import logger
|
|
12
|
+
|
|
13
|
+
from autocoder.agent.base_agentic.types import BaseTool, ToolResult
|
|
14
|
+
from autocoder.agent.base_agentic.tool_registry import ToolRegistry
|
|
15
|
+
from autocoder.agent.base_agentic.tools.base_tool_resolver import BaseToolResolver
|
|
16
|
+
from autocoder.agent.base_agentic.types import ToolDescription, ToolExample
|
|
17
|
+
from autocoder.common import AutoCoderArgs
|
|
18
|
+
from autocoder.rag.long_context_rag import LongContextRAG
|
|
19
|
+
from autocoder.rag.types import RecallStat, ChunkStat, AnswerStat, RAGStat
|
|
20
|
+
from autocoder.rag.relevant_utils import FilterDoc, DocRelevance, DocFilterResult
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SearchTool(BaseTool):
|
|
24
|
+
"""搜索工具,用于获取与查询相关的文件列表"""
|
|
25
|
+
query: str # 用户查询
|
|
26
|
+
max_files: Optional[int] = 10 # 最大返回文件数量
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SearchToolResolver(BaseToolResolver):
|
|
30
|
+
"""搜索工具解析器,实现搜索逻辑"""
|
|
31
|
+
def __init__(self, agent, tool, args):
|
|
32
|
+
super().__init__(agent, tool, args)
|
|
33
|
+
self.tool: SearchTool = tool
|
|
34
|
+
|
|
35
|
+
def resolve(self) -> ToolResult:
|
|
36
|
+
"""实现搜索工具的解析逻辑"""
|
|
37
|
+
try:
|
|
38
|
+
# 获取参数
|
|
39
|
+
query = self.tool.query
|
|
40
|
+
max_files = self.tool.max_files
|
|
41
|
+
rag = self.agent.rag
|
|
42
|
+
# 构建对话历史
|
|
43
|
+
conversations = [
|
|
44
|
+
{"role": "user", "content": query}
|
|
45
|
+
]
|
|
46
|
+
|
|
47
|
+
# 创建 RAGStat 对象
|
|
48
|
+
rag_stat = RAGStat(
|
|
49
|
+
recall_stat=RecallStat(total_input_tokens=0, total_generated_tokens=0),
|
|
50
|
+
chunk_stat=ChunkStat(total_input_tokens=0, total_generated_tokens=0),
|
|
51
|
+
answer_stat=AnswerStat(total_input_tokens=0, total_generated_tokens=0)
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# 调用文档检索处理
|
|
55
|
+
generator = rag._process_document_retrieval(conversations, query, rag_stat)
|
|
56
|
+
|
|
57
|
+
# 获取最终结果
|
|
58
|
+
result = None
|
|
59
|
+
for item in generator:
|
|
60
|
+
if isinstance(item, dict) and "result" in item:
|
|
61
|
+
result = item["result"]
|
|
62
|
+
|
|
63
|
+
if not result:
|
|
64
|
+
return ToolResult(
|
|
65
|
+
success=False,
|
|
66
|
+
message="未找到相关文档",
|
|
67
|
+
content=[]
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# 格式化结果
|
|
71
|
+
file_list = []
|
|
72
|
+
for doc in result:
|
|
73
|
+
file_list.append({
|
|
74
|
+
"path": doc.source_code.module_name,
|
|
75
|
+
"relevance": doc.relevance.relevant_score if doc.relevance else 0,
|
|
76
|
+
"is_relevant": doc.relevance.is_relevant if doc.relevance else False
|
|
77
|
+
})
|
|
78
|
+
|
|
79
|
+
# 按相关性排序
|
|
80
|
+
file_list.sort(key=lambda x: x["relevance"], reverse=True)
|
|
81
|
+
|
|
82
|
+
# 限制返回数量
|
|
83
|
+
file_list = file_list[:max_files]
|
|
84
|
+
|
|
85
|
+
return ToolResult(
|
|
86
|
+
success=True,
|
|
87
|
+
message=f"成功检索到 {len(file_list)} 个相关文件",
|
|
88
|
+
content=file_list
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
except Exception as e:
|
|
92
|
+
import traceback
|
|
93
|
+
return ToolResult(
|
|
94
|
+
success=False,
|
|
95
|
+
message=f"搜索工具执行失败: {str(e)}",
|
|
96
|
+
content=traceback.format_exc()
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def register_search_tool():
|
|
101
|
+
"""注册搜索工具"""
|
|
102
|
+
# 准备工具描述
|
|
103
|
+
description = ToolDescription(
|
|
104
|
+
description="搜索与查询相关的文件",
|
|
105
|
+
parameters="query: 搜索查询\nmax_files: 最大返回文件数量(可选,默认为10)",
|
|
106
|
+
usage="用于根据查询找到相关的代码文件"
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
# 准备工具示例
|
|
110
|
+
example = ToolExample(
|
|
111
|
+
title="搜索工具使用示例",
|
|
112
|
+
body="""<search>
|
|
113
|
+
<query>如何实现文件监控功能</query>
|
|
114
|
+
<max_files>5</max_files>
|
|
115
|
+
</search>"""
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
# 注册工具
|
|
119
|
+
ToolRegistry.register_tool(
|
|
120
|
+
tool_tag="search", # XML标签名
|
|
121
|
+
tool_cls=SearchTool, # 工具类
|
|
122
|
+
resolver_cls=SearchToolResolver, # 解析器类
|
|
123
|
+
description=description, # 工具描述
|
|
124
|
+
example=example, # 工具示例
|
|
125
|
+
use_guideline="此工具用于根据用户查询搜索相关代码文件,返回文件路径及其相关性分数。适用于需要快速找到与特定功能或概念相关的代码文件的场景。" # 使用指南
|
|
126
|
+
)
|
autocoder/rag/types.py
CHANGED
|
@@ -3,10 +3,46 @@ import os
|
|
|
3
3
|
import json
|
|
4
4
|
import time
|
|
5
5
|
import pydantic
|
|
6
|
+
from pydantic import BaseModel
|
|
6
7
|
from typing import Dict, Any, Optional, List
|
|
7
8
|
import psutil
|
|
8
9
|
import glob
|
|
9
10
|
|
|
11
|
+
class RecallStat(BaseModel):
|
|
12
|
+
total_input_tokens: int
|
|
13
|
+
total_generated_tokens: int
|
|
14
|
+
model_name: str = "unknown"
|
|
15
|
+
cost:float = 0.0
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ChunkStat(BaseModel):
|
|
19
|
+
total_input_tokens: int
|
|
20
|
+
total_generated_tokens: int
|
|
21
|
+
model_name: str = "unknown"
|
|
22
|
+
cost:float = 0.0
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnswerStat(BaseModel):
|
|
26
|
+
total_input_tokens: int
|
|
27
|
+
total_generated_tokens: int
|
|
28
|
+
model_name: str = "unknown"
|
|
29
|
+
cost:float = 0.0
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OtherStat(BaseModel):
|
|
33
|
+
total_input_tokens: int = 0
|
|
34
|
+
total_generated_tokens: int = 0
|
|
35
|
+
model_name: str = "unknown"
|
|
36
|
+
cost:float = 0.0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RAGStat(BaseModel):
|
|
40
|
+
recall_stat: RecallStat
|
|
41
|
+
chunk_stat: ChunkStat
|
|
42
|
+
answer_stat: AnswerStat
|
|
43
|
+
other_stats: List[OtherStat] = []
|
|
44
|
+
cost:float = 0.0
|
|
45
|
+
|
|
10
46
|
class RAGServiceInfo(pydantic.BaseModel):
|
|
11
47
|
host: str
|
|
12
48
|
port: int
|