auto-coder 0.1.373__py3-none-any.whl → 0.1.375__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.373.dist-info → auto_coder-0.1.375.dist-info}/METADATA +2 -2
- {auto_coder-0.1.373.dist-info → auto_coder-0.1.375.dist-info}/RECORD +27 -23
- autocoder/agent/base_agentic/base_agent.py +193 -44
- autocoder/agent/base_agentic/default_tools.py +38 -6
- autocoder/agent/base_agentic/tools/list_files_tool_resolver.py +83 -43
- autocoder/agent/base_agentic/tools/read_file_tool_resolver.py +88 -25
- autocoder/agent/base_agentic/tools/replace_in_file_tool_resolver.py +171 -62
- autocoder/agent/base_agentic/tools/search_files_tool_resolver.py +101 -56
- autocoder/agent/base_agentic/tools/talk_to_group_tool_resolver.py +5 -0
- autocoder/agent/base_agentic/tools/talk_to_tool_resolver.py +5 -0
- autocoder/agent/base_agentic/tools/write_to_file_tool_resolver.py +145 -32
- autocoder/auto_coder_rag.py +68 -11
- autocoder/common/v2/agent/agentic_edit_tools/replace_in_file_tool_resolver.py +47 -141
- autocoder/common/v2/agent/agentic_edit_tools/write_to_file_tool_resolver.py +47 -102
- autocoder/index/index.py +1 -1
- autocoder/linters/linter_factory.py +4 -1
- autocoder/linters/normal_linter.py +2 -4
- autocoder/linters/python_linter.py +18 -115
- autocoder/rag/agentic_rag.py +217 -0
- autocoder/rag/tools/__init__.py +10 -0
- autocoder/rag/tools/recall_tool.py +162 -0
- autocoder/rag/tools/search_tool.py +125 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.373.dist-info → auto_coder-0.1.375.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.373.dist-info → auto_coder-0.1.375.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.373.dist-info → auto_coder-0.1.375.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.373.dist-info → auto_coder-0.1.375.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,3 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Module for linting Python code.
|
|
3
|
-
This module provides functionality to analyze Python code for quality and best practices.
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
1
|
import os
|
|
7
2
|
import sys
|
|
8
3
|
import json
|
|
@@ -25,8 +20,7 @@ class PythonLinter(BaseLinter):
|
|
|
25
20
|
Args:
|
|
26
21
|
verbose (bool): Whether to display verbose output.
|
|
27
22
|
"""
|
|
28
|
-
super().__init__(verbose)
|
|
29
|
-
BaseLinter.tt()
|
|
23
|
+
super().__init__(verbose)
|
|
30
24
|
|
|
31
25
|
def get_supported_extensions(self) -> List[str]:
|
|
32
26
|
"""
|
|
@@ -50,8 +44,7 @@ class PythonLinter(BaseLinter):
|
|
|
50
44
|
|
|
51
45
|
# Check for pylint or flake8
|
|
52
46
|
has_pylint = False
|
|
53
|
-
has_flake8 = False
|
|
54
|
-
has_black = False
|
|
47
|
+
has_flake8 = False
|
|
55
48
|
|
|
56
49
|
try:
|
|
57
50
|
subprocess.run([sys.executable, "-m", "pylint", "--version"],
|
|
@@ -67,18 +60,10 @@ class PythonLinter(BaseLinter):
|
|
|
67
60
|
has_flake8 = True
|
|
68
61
|
except (subprocess.SubprocessError, FileNotFoundError):
|
|
69
62
|
if self.verbose:
|
|
70
|
-
print("Flake8 not found.")
|
|
71
|
-
|
|
72
|
-
try:
|
|
73
|
-
subprocess.run([sys.executable, "-m", "black", "--version"],
|
|
74
|
-
check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
75
|
-
has_black = True
|
|
76
|
-
except (subprocess.SubprocessError, FileNotFoundError):
|
|
77
|
-
if self.verbose:
|
|
78
|
-
print("Black not found.")
|
|
63
|
+
print("Flake8 not found.")
|
|
79
64
|
|
|
80
65
|
# Need at least one linter
|
|
81
|
-
return has_pylint or has_flake8
|
|
66
|
+
return has_pylint or has_flake8
|
|
82
67
|
|
|
83
68
|
except (subprocess.SubprocessError, FileNotFoundError):
|
|
84
69
|
return False
|
|
@@ -173,12 +158,15 @@ class PythonLinter(BaseLinter):
|
|
|
173
158
|
|
|
174
159
|
# Process pylint issues
|
|
175
160
|
for message in pylint_output:
|
|
176
|
-
severity = "
|
|
161
|
+
severity = "info"
|
|
177
162
|
if message.get('type') in ['error', 'fatal']:
|
|
178
163
|
severity = "error"
|
|
179
164
|
result['error_count'] += 1
|
|
180
|
-
|
|
165
|
+
elif message.get('type') in ['warning']:
|
|
166
|
+
severity = "warning"
|
|
181
167
|
result['warning_count'] += 1
|
|
168
|
+
else:
|
|
169
|
+
result['info_count'] += 1
|
|
182
170
|
|
|
183
171
|
issue = {
|
|
184
172
|
'file': message.get('path', ''),
|
|
@@ -265,13 +253,16 @@ class PythonLinter(BaseLinter):
|
|
|
265
253
|
message = code_message
|
|
266
254
|
|
|
267
255
|
# Determine severity based on error code
|
|
268
|
-
severity = "
|
|
256
|
+
severity = "info"
|
|
269
257
|
# E errors are generally more serious than F warnings
|
|
270
|
-
if code.startswith('
|
|
258
|
+
if code.startswith('F'):
|
|
271
259
|
severity = "error"
|
|
272
260
|
result['error_count'] += 1
|
|
273
|
-
|
|
261
|
+
elif code.startswith('E'):
|
|
262
|
+
severity = "warning"
|
|
274
263
|
result['warning_count'] += 1
|
|
264
|
+
else:
|
|
265
|
+
result['info_count'] += 1
|
|
275
266
|
|
|
276
267
|
issue = {
|
|
277
268
|
'file': file_path,
|
|
@@ -294,76 +285,7 @@ class PythonLinter(BaseLinter):
|
|
|
294
285
|
|
|
295
286
|
return result
|
|
296
287
|
|
|
297
|
-
|
|
298
|
-
"""
|
|
299
|
-
Run black on the target file or directory to check formatting or fix it.
|
|
300
|
-
|
|
301
|
-
Args:
|
|
302
|
-
target (str): Path to the file or directory to format.
|
|
303
|
-
fix (bool): Whether to automatically fix formatting issues.
|
|
304
|
-
|
|
305
|
-
Returns:
|
|
306
|
-
Dict[str, Any]: The black results.
|
|
307
|
-
"""
|
|
308
|
-
result = {
|
|
309
|
-
'error_count': 0,
|
|
310
|
-
'warning_count': 0,
|
|
311
|
-
'issues': []
|
|
312
|
-
}
|
|
313
|
-
|
|
314
|
-
try:
|
|
315
|
-
# Build command
|
|
316
|
-
cmd = [
|
|
317
|
-
sys.executable,
|
|
318
|
-
"-m",
|
|
319
|
-
"black",
|
|
320
|
-
]
|
|
321
|
-
|
|
322
|
-
# Check-only mode if not fixing
|
|
323
|
-
if not fix:
|
|
324
|
-
cmd.append("--check")
|
|
325
|
-
|
|
326
|
-
# Add target
|
|
327
|
-
cmd.append(target)
|
|
328
|
-
|
|
329
|
-
process = subprocess.run(
|
|
330
|
-
cmd,
|
|
331
|
-
stdout=subprocess.PIPE,
|
|
332
|
-
stderr=subprocess.PIPE,
|
|
333
|
-
text=True
|
|
334
|
-
)
|
|
335
|
-
|
|
336
|
-
# Black exit code is 0 if no changes, 1 if changes were needed
|
|
337
|
-
if process.returncode == 1 and not fix:
|
|
338
|
-
# Parse output to find which files would be reformatted
|
|
339
|
-
for line in process.stdout.splitlines():
|
|
340
|
-
if line.startswith("would reformat"):
|
|
341
|
-
file_path = line.replace("would reformat ", "").strip()
|
|
342
|
-
|
|
343
|
-
result['warning_count'] += 1
|
|
344
|
-
|
|
345
|
-
issue = {
|
|
346
|
-
'file': file_path,
|
|
347
|
-
'line': 0, # Black doesn't provide line numbers
|
|
348
|
-
'column': 0,
|
|
349
|
-
'severity': "warning",
|
|
350
|
-
'message': "Code formatting doesn't match Black style",
|
|
351
|
-
'rule': "formatting",
|
|
352
|
-
'tool': 'black'
|
|
353
|
-
}
|
|
354
|
-
|
|
355
|
-
result['issues'].append(issue)
|
|
356
|
-
|
|
357
|
-
# If auto-fixing and Black reports changes
|
|
358
|
-
if fix and process.returncode == 0 and "reformatted" in process.stderr:
|
|
359
|
-
# This is good - it means Black fixed some issues
|
|
360
|
-
pass
|
|
361
|
-
|
|
362
|
-
except Exception as e:
|
|
363
|
-
if self.verbose:
|
|
364
|
-
print(f"Error running black: {str(e)}")
|
|
365
|
-
|
|
366
|
-
return result
|
|
288
|
+
|
|
367
289
|
|
|
368
290
|
def lint_file(self, file_path: str, fix: bool = False) -> Dict[str, Any]:
|
|
369
291
|
"""
|
|
@@ -400,17 +322,7 @@ class PythonLinter(BaseLinter):
|
|
|
400
322
|
# Try to install dependencies
|
|
401
323
|
if not self._install_dependencies_if_needed():
|
|
402
324
|
result['error'] = "Required dependencies are not installed and could not be installed automatically"
|
|
403
|
-
return result
|
|
404
|
-
|
|
405
|
-
# Run black first (to format if fix=True)
|
|
406
|
-
try:
|
|
407
|
-
black_result = self._run_black(file_path, fix)
|
|
408
|
-
result['issues'].extend(black_result['issues'])
|
|
409
|
-
result['error_count'] += black_result['error_count']
|
|
410
|
-
result['warning_count'] += black_result['warning_count']
|
|
411
|
-
except Exception as e:
|
|
412
|
-
if self.verbose:
|
|
413
|
-
print(f"Error running black: {str(e)}")
|
|
325
|
+
return result
|
|
414
326
|
|
|
415
327
|
# Run pylint
|
|
416
328
|
try:
|
|
@@ -477,16 +389,7 @@ class PythonLinter(BaseLinter):
|
|
|
477
389
|
python_files.append(os.path.join(root, file))
|
|
478
390
|
|
|
479
391
|
result['files_analyzed'] = len(python_files)
|
|
480
|
-
|
|
481
|
-
# Run black first (to format if fix=True)
|
|
482
|
-
try:
|
|
483
|
-
black_result = self._run_black(project_path, fix)
|
|
484
|
-
result['issues'].extend(black_result['issues'])
|
|
485
|
-
result['error_count'] += black_result['error_count']
|
|
486
|
-
result['warning_count'] += black_result['warning_count']
|
|
487
|
-
except Exception as e:
|
|
488
|
-
if self.verbose:
|
|
489
|
-
print(f"Error running black: {str(e)}")
|
|
392
|
+
|
|
490
393
|
|
|
491
394
|
# Run pylint
|
|
492
395
|
try:
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import time
|
|
4
|
+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import pathspec
|
|
7
|
+
from byzerllm import ByzerLLM
|
|
8
|
+
import byzerllm
|
|
9
|
+
from loguru import logger
|
|
10
|
+
import traceback
|
|
11
|
+
|
|
12
|
+
from autocoder.common import AutoCoderArgs, SourceCode
|
|
13
|
+
from importlib.metadata import version
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
from autocoder.common import openai_content as OpenAIContentProcessor
|
|
16
|
+
from autocoder.rag.long_context_rag import LongContextRAG
|
|
17
|
+
import json, os
|
|
18
|
+
from autocoder.agent.base_agentic.base_agent import BaseAgent
|
|
19
|
+
from autocoder.agent.base_agentic.types import AgentRequest
|
|
20
|
+
from autocoder.common import SourceCodeList
|
|
21
|
+
from autocoder.rag.tools import register_search_tool, register_recall_tool
|
|
22
|
+
from byzerllm.utils.types import SingleOutputMeta
|
|
23
|
+
from autocoder.utils.llms import get_single_llm
|
|
24
|
+
try:
|
|
25
|
+
from autocoder_pro.rag.llm_compute import LLMComputeEngine
|
|
26
|
+
pro_version = version("auto-coder-pro")
|
|
27
|
+
autocoder_version = version("auto-coder")
|
|
28
|
+
logger.warning(
|
|
29
|
+
f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
|
|
30
|
+
except ImportError:
|
|
31
|
+
logger.warning(
|
|
32
|
+
"Please install auto-coder-pro to enhance llm compute ability")
|
|
33
|
+
LLMComputeEngine = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RAGAgent(BaseAgent):
    """Agent that answers questions over a knowledge base through the ``recall`` tool.

    The ``LongContextRAG`` instance is stored on ``self.rag`` so the recall tool
    resolver can reach it via ``self.agent.rag``.
    """

    def __init__(self, name: str,
                 llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM],
                 files: SourceCodeList,
                 args: AutoCoderArgs,
                 rag: LongContextRAG,
                 conversation_history: Optional[List[Dict[str, Any]]] = None):
        """Create the agent.

        Args:
            name: Agent name forwarded to ``BaseAgent``.
            llm: Default LLM client. If ``get_sub_client("agentic_model")``
                returns a client, that one drives the agent; if
                ``get_sub_client("context_prune_model")`` returns one, it is kept
                as ``context_prune_llm``. Otherwise both fall back to ``llm``.
            files: Source files forwarded to ``BaseAgent``.
            args: AutoCoder configuration forwarded to ``BaseAgent``.
            rag: RAG engine used by the recall tool.
            conversation_history: Optional prior messages forwarded to ``BaseAgent``.
        """
        # Use the constructor argument. ``self.llm`` is not assigned anywhere
        # before this line, so reading it here would fail.
        self.default_llm = llm
        self.context_prune_llm = (
            self.default_llm.get_sub_client("context_prune_model")
            or self.default_llm
        )
        self.llm = (
            self.default_llm.get_sub_client("agentic_model")
            or self.default_llm
        )

        self.rag = rag
        super().__init__(name, self.llm, files, args, conversation_history,
                         default_tools_list=["read_file"])
        # Register the RAG tools. Only recall is enabled; the search tool
        # (register_search_tool) is intentionally left off.
        register_recall_tool()
|
|
58
|
+
|
|
59
|
+
class AgenticRAG:
    """Agent-driven RAG front end with an OpenAI-style streaming chat entry point.

    Each chat request creates a fresh ``RAGAgent`` that gathers context with the
    ``recall`` tool (backed by the ``LongContextRAG`` built here) before answering.
    """

    def __init__(
        self,
        llm: ByzerLLM,
        args: AutoCoderArgs,
        path: str,
        tokenizer_path: Optional[str] = None,
    ) -> None:
        """Store configuration and build the underlying ``LongContextRAG``.

        Args:
            llm: LLM client shared by the RAG engine and the per-request agent.
            args: AutoCoder configuration (``local_image_host`` is read by the system prompt).
            path: Knowledge-base path forwarded to ``LongContextRAG``.
            tokenizer_path: Optional tokenizer path forwarded to ``LongContextRAG``.
        """
        self.llm = llm
        self.args = args
        self.path = path
        self.tokenizer_path = tokenizer_path
        self.rag = LongContextRAG(llm=self.llm, args=self.args, path=self.path,
                                  tokenizer_path=self.tokenizer_path)

    def build(self):
        """Placeholder: no build step is implemented yet."""
        pass

    def search(self, query: str) -> List[SourceCode]:
        """Placeholder: always returns an empty list."""
        return []

    def stream_chat_oai(
        self,
        conversations,
        model: Optional[str] = None,
        role_mapping=None,
        llm_config: Optional[Dict[str, Any]] = None,
        extra_request_params: Optional[Dict[str, Any]] = None,
    ):
        """Answer an OpenAI-style conversation as a stream.

        Returns:
            On success, a ``(generator, context)`` pair. The generator yields
            ``(content, SingleOutputMeta)`` tuples.

            If setting up the stream raises, the error is logged and
            ``(["出现错误,请稍后再试。"], [])`` is returned.

            Errors raised while the generator is being consumed are NOT caught
            here, because the generator body runs lazily.
        """
        try:
            return self._stream_chat_oai(
                conversations,
                model=model,
                role_mapping=role_mapping,
                llm_config=llm_config,
                extra_request_params=extra_request_params
            )
        except Exception as e:
            logger.error(f"Error in stream_chat_oai: {str(e)}")
            traceback.print_exc()
            return ["出现错误,请稍后再试。"], []

    # Render the conversation as one recall query. Every message except the
    # last becomes history; the last message's content is the current question.
    @byzerllm.prompt()
    def conversation_to_query(self, messages: List[Dict[str, Any]]):
        '''
        【历史对话】按时间顺序排列,从旧到新:
        {% for message in messages %}
        <message>
        {% if message.role == "user" %}【用户】{% else %}【助手】{% endif %}
        <content>
        {{ message.content }}
        </content>
        </message>
        {% endfor %}

        【当前问题】用户的最新需求如下:
        <current_query>
        {{ query }}
        </current_query>
        '''
        if not messages:
            # Nothing to render. Avoid IndexError on messages[-1].
            return {"messages": [], "query": ""}
        return {
            "messages": messages[:-1],
            "query": messages[-1]["content"]
        }

    # System prompt for the per-request RAGAgent. The decorator is required:
    # callers use ``self.system_prompt.prompt()``, which only exists on
    # byzerllm prompt functions.
    @byzerllm.prompt()
    def system_prompt(self):
        '''
        你是一个基于知识库的智能助手,你的核心能力是通过检索增强生成(RAG)技术来回答用户问题。

        你的工作流程如下:
        1. 当用户提出问题时,你会首先理解问题的核心意图和关键信息需求
        2. 你会从多个角度分析问题,确定最佳的检索策略和关键词,然后使用召回工具 recall 获取与问题最相关的详细内容,只有在特别有必要的情况下,你才会使用 read_file 来获得相关文件更详细的信息。
        3. 如果获得的信息足够回答用户问题,你会直接生成回答。
        4. 如果获得的信息不足以回答用户问题,你会继续使用 recall 工具,直到你确信已经获取了足够的信息来回答用户问题。
        5. 有的问题可能需要拆解成多个问题,分别进行recall,然后最终得到的结果才是完整信息,最后才能进行回答。

        此外,你回答会遵循以下要求:

        1. 严格基于召回的文档内容回答
        - 如果召回的文档提供的信息无法回答问题,请明确回复:"抱歉,文档中没有足够的信息来回答这个问题。"
        - 不要添加、推测或扩展文档未提及的信息

        2. 格式如 ![image](./path.png) 的 Markdown 图片处理
        - 根据Markdown 图片前后文本内容推测该图片与问题的相关性,有相关性则在回答中输出该Markdown图片路径
        - 根据相关图片在文档中的位置,自然融入答复内容,保持上下文连贯
        - 完整保留原始图片路径,不省略任何部分

        3. 回答格式要求
        - 使用markdown格式提升可读性
        {% if local_image_host %}
        4. 图片路径处理
        - 图片地址需返回绝对路径,
        - 对于Windows风格的路径,需要转换为Linux风格, 例如:C:\\Users\\user\\Desktop\\image.png 转换为 C:/Users/user/Desktop/image.png
        - 为请求图片资源 需增加 http://{{ local_image_host }}/static/ 作为前缀
        例如:/path/to/images/image.png, 返回 http://{{ local_image_host }}/static/path/to/images/image.png
        {% endif %}
        '''
        return {
            "local_image_host": self.args.local_image_host
        }

    def _stream_chat_oai(
        self,
        conversations,
        model: Optional[str] = None,
        role_mapping=None,
        llm_config: Optional[Dict[str, Any]] = None,
        extra_request_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the lazy answer stream used by ``stream_chat_oai``.

        The RAGAgent is only created when the returned generator is iterated.
        ``model`` and ``role_mapping`` are accepted for API compatibility but are
        not used.

        Returns:
            ``(generator, context)``. ``context`` is currently always an empty list.
        """
        # Merge into a fresh dict so the caller's dict is never mutated.
        # NOTE(review): the merged config is not forwarded to the agent yet.
        llm_config = dict(llm_config or {})
        if extra_request_params:
            llm_config.update(extra_request_params)

        conversations = OpenAIContentProcessor.process_conversations(conversations)

        context = []

        def _generate_stream():
            recall_request = AgentRequest(
                user_input=self.conversation_to_query.prompt(conversations))
            rag_agent = RAGAgent(
                name="RAGAgent",
                llm=self.llm,
                files=SourceCodeList(sources=[]),
                args=self.args,
                rag=self.rag,
                conversation_history=[]
            )

            rag_agent.who_am_i(self.system_prompt.prompt())

            for event_type, content in rag_agent.run_with_generator(recall_request):
                if event_type == "thinking":
                    # Reasoning goes into the metadata; the visible content stays empty.
                    yield ("", SingleOutputMeta(
                        generated_tokens_count=0,
                        input_tokens_count=0,
                        reasoning_content=content,
                    ))
                else:
                    yield (content, SingleOutputMeta(
                        generated_tokens_count=0,
                        input_tokens_count=0,
                        reasoning_content="",
                    ))

        return _generate_stream(), context
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# 导出 SearchTool 相关类和函数
|
|
2
|
+
from .search_tool import SearchTool, SearchToolResolver, register_search_tool
|
|
3
|
+
|
|
4
|
+
# 导出 RecallTool 相关类和函数
|
|
5
|
+
from .recall_tool import RecallTool, RecallToolResolver, register_recall_tool
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
'SearchTool', 'SearchToolResolver', 'register_search_tool',
|
|
9
|
+
'RecallTool', 'RecallToolResolver', 'register_recall_tool'
|
|
10
|
+
]
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RecallTool 模块
|
|
3
|
+
|
|
4
|
+
该模块实现了 RecallTool 和 RecallToolResolver 类,用于在 BaseAgent 框架中
|
|
5
|
+
提供基于 LongContextRAG 的文档内容召回功能。
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import traceback
|
|
10
|
+
from typing import Dict, Any, List, Optional
|
|
11
|
+
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
from autocoder.agent.base_agentic.types import BaseTool, ToolResult
|
|
15
|
+
from autocoder.agent.base_agentic.tool_registry import ToolRegistry
|
|
16
|
+
from autocoder.agent.base_agentic.tools.base_tool_resolver import BaseToolResolver
|
|
17
|
+
from autocoder.agent.base_agentic.types import ToolDescription, ToolExample
|
|
18
|
+
from autocoder.common import AutoCoderArgs
|
|
19
|
+
from autocoder.rag.long_context_rag import LongContextRAG, RecallStat, ChunkStat, AnswerStat, RAGStat
|
|
20
|
+
from autocoder.rag.relevant_utils import FilterDoc, DocRelevance, DocFilterResult
|
|
21
|
+
from autocoder.common import SourceCode
|
|
22
|
+
from autocoder.rag.relevant_utils import TaskTiming
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class RecallTool(BaseTool):
    """Arguments for the ``recall`` tool, which fetches document content relevant to a query."""

    # The query the recalled documents should be relevant to.
    query: str
    # Explicit files to chunk. When None or empty, the resolver runs knowledge-base retrieval instead.
    file_paths: Optional[List[str]] = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RecallToolResolver(BaseToolResolver):
    """Resolver for the ``recall`` tool.

    It retrieves (or loads) documents, chunks them with the agent's
    ``LongContextRAG``, and returns the resulting fragments.
    """

    def __init__(self, agent, tool, args):
        super().__init__(agent, tool, args)
        self.tool: RecallTool = tool

    @staticmethod
    def _last_result(generator) -> Any:
        """Drain a RAG progress generator and return the last ``"result"`` payload, or None."""
        result = None
        for item in generator:
            if isinstance(item, dict) and "result" in item:
                result = item["result"]
        return result

    @staticmethod
    def _load_docs_from_paths(file_paths: List[str]) -> List[FilterDoc]:
        """Wrap each readable file as a FilterDoc marked relevant.

        Files that fail to load are logged and skipped.
        """
        # NOTE(review): paths come straight from the LLM's tool call and are not
        # restricted to the knowledge-base directory. Consider validating them.
        docs = []
        for file_path in file_paths:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                docs.append(FilterDoc(
                    source_code=SourceCode(module_name=file_path, source_code=content),
                    # Explicitly requested files are treated as relevant by default.
                    relevance=DocRelevance(is_relevant=True, relevant_score=5),
                    task_timing=TaskTiming()
                ))
            except Exception as e:
                logger.error(f"读取文件 {file_path} 失败: {str(e)}")
        return docs

    def resolve(self) -> ToolResult:
        """Run recall and return the chunked document fragments.

        Returns:
            On success: ``ToolResult(success=True)`` whose content is a list of
            ``{"path", "content"}`` dicts.

            When no document is found, or chunking yields nothing:
            ``ToolResult(success=False)`` with an empty list as content.

            On any exception: ``ToolResult(success=False)`` with the formatted
            traceback as content.
        """
        try:
            query = self.tool.query
            rag: LongContextRAG = self.agent.rag
            # Single-turn conversation built from the tool's query.
            conversations = [{"role": "user", "content": query}]
            rag_stat = RAGStat(
                recall_stat=RecallStat(total_input_tokens=0, total_generated_tokens=0),
                chunk_stat=ChunkStat(total_input_tokens=0, total_generated_tokens=0),
                answer_stat=AnswerStat(total_input_tokens=0, total_generated_tokens=0)
            )

            # Use explicitly provided files if any; otherwise run retrieval.
            if self.tool.file_paths:
                relevant_docs = self._load_docs_from_paths(self.tool.file_paths)
            else:
                relevant_docs = self._last_result(
                    rag._process_document_retrieval(conversations, query, rag_stat))

            if not relevant_docs:
                return ToolResult(success=False, message="未找到相关文档", content=[])

            source_docs = [doc.source_code for doc in relevant_docs]
            final_relevant_docs = self._last_result(
                rag._process_document_chunking(source_docs, conversations, rag_stat, 0))

            if not final_relevant_docs:
                return ToolResult(success=False, message="文档分块处理失败", content=[])

            doc_contents = [
                {"path": doc.module_name, "content": doc.source_code}
                for doc in final_relevant_docs
            ]
            return ToolResult(
                success=True,
                message=f"成功召回 {len(doc_contents)} 个相关文档片段",
                content=doc_contents
            )

        except Exception as e:
            # Log with traceback so failures are visible server-side, not only in the tool result.
            logger.exception(f"召回工具执行失败: {str(e)}")
            return ToolResult(
                success=False,
                message=f"召回工具执行失败: {str(e)}",
                content=traceback.format_exc()
            )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def register_recall_tool():
    """Register the ``recall`` tool and its resolver in the global ToolRegistry."""
    # Keyword arguments are evaluated in order, so the description is built
    # before the example, just as separate statements would do.
    ToolRegistry.register_tool(
        tool_tag="recall",  # XML tag the agent emits to invoke the tool
        tool_cls=RecallTool,
        resolver_cls=RecallToolResolver,
        description=ToolDescription(
            description="召回与查询相关的文档内容",
            parameters="query: 搜索查询\nfile_paths: 指定要处理的文件路径列表(可选)",
            usage="用于根据查询获取相关文档的内容片段"
        ),
        example=ToolExample(
            title="召回工具使用示例",
            body="<recall>\n<query>如何实现文件监控功能</query>\n</recall>"
        ),
        use_guideline="此工具用于根据用户查询召回相关文档内容,返回经过分块和重排序的文档片段。适用于需要深入了解特定功能实现细节的场景。"
    )
|