jarvis-ai-assistant 0.3.30__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +289 -87
- jarvis/jarvis_agent/agent_manager.py +17 -8
- jarvis/jarvis_agent/edit_file_handler.py +374 -86
- jarvis/jarvis_agent/event_bus.py +1 -1
- jarvis/jarvis_agent/file_context_handler.py +79 -0
- jarvis/jarvis_agent/jarvis.py +601 -43
- jarvis/jarvis_agent/main.py +32 -2
- jarvis/jarvis_agent/rewrite_file_handler.py +141 -0
- jarvis/jarvis_agent/run_loop.py +38 -5
- jarvis/jarvis_agent/share_manager.py +8 -1
- jarvis/jarvis_agent/stdio_redirect.py +295 -0
- jarvis/jarvis_agent/task_analyzer.py +5 -2
- jarvis/jarvis_agent/task_planner.py +496 -0
- jarvis/jarvis_agent/utils.py +5 -1
- jarvis/jarvis_agent/web_bridge.py +189 -0
- jarvis/jarvis_agent/web_output_sink.py +53 -0
- jarvis/jarvis_agent/web_server.py +751 -0
- jarvis/jarvis_c2rust/__init__.py +26 -0
- jarvis/jarvis_c2rust/cli.py +613 -0
- jarvis/jarvis_c2rust/collector.py +258 -0
- jarvis/jarvis_c2rust/library_replacer.py +1122 -0
- jarvis/jarvis_c2rust/llm_module_agent.py +1300 -0
- jarvis/jarvis_c2rust/optimizer.py +960 -0
- jarvis/jarvis_c2rust/scanner.py +1681 -0
- jarvis/jarvis_c2rust/transpiler.py +2325 -0
- jarvis/jarvis_code_agent/build_validation_config.py +133 -0
- jarvis/jarvis_code_agent/code_agent.py +1171 -94
- jarvis/jarvis_code_agent/code_analyzer/__init__.py +62 -0
- jarvis/jarvis_code_agent/code_analyzer/base_language.py +74 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/__init__.py +44 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/base.py +102 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/cmake.py +59 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/detector.py +125 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/fallback.py +69 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/go.py +38 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/java_gradle.py +44 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/java_maven.py +38 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/makefile.py +50 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/nodejs.py +93 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/python.py +129 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/rust.py +54 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator/validator.py +154 -0
- jarvis/jarvis_code_agent/code_analyzer/build_validator.py +43 -0
- jarvis/jarvis_code_agent/code_analyzer/context_manager.py +363 -0
- jarvis/jarvis_code_agent/code_analyzer/context_recommender.py +18 -0
- jarvis/jarvis_code_agent/code_analyzer/dependency_analyzer.py +132 -0
- jarvis/jarvis_code_agent/code_analyzer/file_ignore.py +330 -0
- jarvis/jarvis_code_agent/code_analyzer/impact_analyzer.py +781 -0
- jarvis/jarvis_code_agent/code_analyzer/language_registry.py +185 -0
- jarvis/jarvis_code_agent/code_analyzer/language_support.py +89 -0
- jarvis/jarvis_code_agent/code_analyzer/languages/__init__.py +31 -0
- jarvis/jarvis_code_agent/code_analyzer/languages/c_cpp_language.py +231 -0
- jarvis/jarvis_code_agent/code_analyzer/languages/go_language.py +183 -0
- jarvis/jarvis_code_agent/code_analyzer/languages/python_language.py +219 -0
- jarvis/jarvis_code_agent/code_analyzer/languages/rust_language.py +209 -0
- jarvis/jarvis_code_agent/code_analyzer/llm_context_recommender.py +451 -0
- jarvis/jarvis_code_agent/code_analyzer/symbol_extractor.py +77 -0
- jarvis/jarvis_code_agent/code_analyzer/tree_sitter_extractor.py +48 -0
- jarvis/jarvis_code_agent/lint.py +270 -8
- jarvis/jarvis_code_agent/utils.py +142 -0
- jarvis/jarvis_code_analysis/code_review.py +483 -569
- jarvis/jarvis_data/config_schema.json +97 -8
- jarvis/jarvis_git_utils/git_commiter.py +38 -26
- jarvis/jarvis_mcp/sse_mcp_client.py +2 -2
- jarvis/jarvis_mcp/stdio_mcp_client.py +1 -1
- jarvis/jarvis_memory_organizer/memory_organizer.py +1 -1
- jarvis/jarvis_multi_agent/__init__.py +239 -25
- jarvis/jarvis_multi_agent/main.py +37 -1
- jarvis/jarvis_platform/base.py +103 -51
- jarvis/jarvis_platform/openai.py +26 -1
- jarvis/jarvis_platform/yuanbao.py +1 -1
- jarvis/jarvis_platform_manager/service.py +2 -2
- jarvis/jarvis_rag/cli.py +4 -4
- jarvis/jarvis_sec/__init__.py +3605 -0
- jarvis/jarvis_sec/checkers/__init__.py +32 -0
- jarvis/jarvis_sec/checkers/c_checker.py +2680 -0
- jarvis/jarvis_sec/checkers/rust_checker.py +1108 -0
- jarvis/jarvis_sec/cli.py +116 -0
- jarvis/jarvis_sec/report.py +257 -0
- jarvis/jarvis_sec/status.py +264 -0
- jarvis/jarvis_sec/types.py +20 -0
- jarvis/jarvis_sec/workflow.py +219 -0
- jarvis/jarvis_stats/cli.py +1 -1
- jarvis/jarvis_stats/stats.py +1 -1
- jarvis/jarvis_stats/visualizer.py +1 -1
- jarvis/jarvis_tools/cli/main.py +1 -0
- jarvis/jarvis_tools/execute_script.py +46 -9
- jarvis/jarvis_tools/generate_new_tool.py +3 -1
- jarvis/jarvis_tools/read_code.py +275 -12
- jarvis/jarvis_tools/read_symbols.py +141 -0
- jarvis/jarvis_tools/read_webpage.py +5 -3
- jarvis/jarvis_tools/registry.py +73 -35
- jarvis/jarvis_tools/search_web.py +15 -11
- jarvis/jarvis_tools/sub_agent.py +24 -42
- jarvis/jarvis_tools/sub_code_agent.py +14 -13
- jarvis/jarvis_tools/virtual_tty.py +1 -1
- jarvis/jarvis_utils/config.py +187 -35
- jarvis/jarvis_utils/embedding.py +3 -0
- jarvis/jarvis_utils/git_utils.py +181 -6
- jarvis/jarvis_utils/globals.py +3 -3
- jarvis/jarvis_utils/http.py +1 -1
- jarvis/jarvis_utils/input.py +78 -2
- jarvis/jarvis_utils/methodology.py +25 -19
- jarvis/jarvis_utils/utils.py +644 -359
- {jarvis_ai_assistant-0.3.30.dist-info → jarvis_ai_assistant-0.7.0.dist-info}/METADATA +85 -1
- jarvis_ai_assistant-0.7.0.dist-info/RECORD +192 -0
- {jarvis_ai_assistant-0.3.30.dist-info → jarvis_ai_assistant-0.7.0.dist-info}/entry_points.txt +4 -0
- jarvis/jarvis_agent/config.py +0 -92
- jarvis/jarvis_tools/edit_file.py +0 -179
- jarvis/jarvis_tools/rewrite_file.py +0 -191
- jarvis_ai_assistant-0.3.30.dist-info/RECORD +0 -137
- {jarvis_ai_assistant-0.3.30.dist-info → jarvis_ai_assistant-0.7.0.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.3.30.dist-info → jarvis_ai_assistant-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.3.30.dist-info → jarvis_ai_assistant-0.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
"""智能上下文推荐器。
|
|
2
|
+
|
|
3
|
+
使用LLM进行语义理解,提供更准确的上下文推荐。
|
|
4
|
+
完全基于LLM实现,不依赖硬编码规则。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
import yaml
|
|
11
|
+
from typing import List, Optional, Dict, Any, Set
|
|
12
|
+
|
|
13
|
+
from jarvis.jarvis_platform.registry import PlatformRegistry
|
|
14
|
+
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
|
15
|
+
from jarvis.jarvis_utils.config import get_normal_platform_name, get_normal_model_name
|
|
16
|
+
from jarvis.jarvis_code_agent.utils import get_project_overview
|
|
17
|
+
|
|
18
|
+
from .context_recommender import ContextRecommendation
|
|
19
|
+
from .context_manager import ContextManager
|
|
20
|
+
from .symbol_extractor import Symbol
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ContextRecommender:
    """LLM-driven context recommender.

    Given a free-text description of an edit task, the recommender asks an LLM
    for keywords, collects candidate symbols from the context manager's symbol
    table (by symbol name/signature and by file content), and then asks the LLM
    to pick the most relevant candidates.
    """

    # Upper bound on how many candidates are listed in the selection prompt.
    _MAX_CANDIDATES = 100
    # Upper bound on how many symbols a recommendation may contain.
    _MAX_RECOMMENDED = 10

    def __init__(self, context_manager: ContextManager, parent_model: Optional[Any] = None):
        """Create the recommender and its own LLM platform instance.

        Args:
            context_manager: Context manager supplying the symbol table, the
                project root and file contents.
            parent_model: Model instance of the parent agent. When given, its
                ``model_group``, ``platform_name()`` and ``name()`` are used to
                configure the new platform instance.

        Raises:
            ValueError: If no LLM platform instance could be created.
        """
        self.context_manager = context_manager

        # Build a dedicated model instance, configured like the parent's model.
        try:
            registry = PlatformRegistry.get_global_platform_registry()

            platform_name = None
            model_name = None
            model_group = None

            if parent_model:
                try:
                    # Read model_group first: it is used below to re-derive the
                    # platform and model names.
                    model_group = getattr(parent_model, 'model_group', None)
                    platform_name = parent_model.platform_name()
                    model_name = parent_model.name()
                except Exception:
                    # Best effort: fall back to the default configuration.
                    pass

            # When a model group is known, prefer the names resolved from it
            # over the values reported by parent_model.
            if model_group:
                try:
                    platform_name = get_normal_platform_name(model_group)
                    model_name = get_normal_model_name(model_group)
                except Exception:
                    # Best effort: keep whatever was read from parent_model.
                    pass

            if platform_name:
                self.llm_model = registry.create_platform(platform_name)
                if self.llm_model is None:
                    # Creation failed; use the default platform instead.
                    self.llm_model = registry.get_normal_platform()
            else:
                self.llm_model = registry.get_normal_platform()

            # Apply the model group before the model name.
            if model_group and self.llm_model:
                try:
                    self.llm_model.set_model_group(model_group)
                except Exception:
                    pass

            if model_name and self.llm_model:
                try:
                    self.llm_model.set_model_name(model_name)
                except Exception:
                    pass

            if self.llm_model:
                # This runs as a background task, so silence the model output.
                self.llm_model.set_suppress_output(True)
            else:
                raise ValueError("无法创建LLM模型实例")
        except Exception as e:
            # Chain the cause explicitly so the original traceback is kept.
            raise ValueError(f"无法创建LLM模型: {e}") from e

    def recommend_context(
        self,
        user_input: str,
    ) -> ContextRecommendation:
        """Recommend context symbols for an edit task.

        Args:
            user_input: User input / task description.

        Returns:
            ContextRecommendation: Holds at most ``_MAX_RECOMMENDED`` symbols;
            empty when no keywords or no candidates were found.
        """
        # 1. Ask the LLM for keywords only.
        keywords = self._extract_keywords_with_llm(user_input)

        recommended_symbols: List[Symbol] = []

        if keywords:
            # 2. Gather candidates from the symbol table and from file text.
            by_name = self._search_symbols_by_keywords(keywords)
            by_text = self._search_text_by_keywords(keywords)

            # 3. Merge, de-duplicating on (file_path, name, line_start); the
            #    first occurrence wins, so name matches precede text matches.
            unique_candidates = {}
            for symbol in by_name + by_text:
                key = (symbol.file_path, symbol.name, symbol.line_start)
                unique_candidates.setdefault(key, symbol)
            candidates = list(unique_candidates.values())

            # 4. Let the LLM choose the most relevant candidates.
            if candidates:
                recommended_symbols.extend(
                    self._select_relevant_symbols_with_llm(user_input, keywords, candidates)
                )

        # 5. Cap the number of symbols returned.
        return ContextRecommendation(
            recommended_symbols=recommended_symbols[:self._MAX_RECOMMENDED],
        )

    def _get_project_overview(self) -> str:
        """Return the project overview text for the context manager's project root."""
        return get_project_overview(self.context_manager.project_root)

    def _relative_path(self, file_path: str) -> str:
        """Return ``file_path`` relative to the project root.

        ``os.path.relpath`` raises ``ValueError`` on Windows when the two paths
        live on different drives; in that case the path is returned unchanged.
        """
        try:
            return os.path.relpath(file_path, self.context_manager.project_root)
        except ValueError:
            return file_path

    @staticmethod
    def _extract_tagged_yaml(response: str, tag: str) -> str:
        """Pull the YAML payload out of an LLM response.

        Args:
            response: Raw LLM response text.
            tag: Tag name (without angle brackets) expected to wrap the payload.

        Returns:
            The text between ``<tag>`` and ``</tag>`` when present; otherwise
            the response with a leading/trailing markdown code fence removed.
        """
        text = response.strip()
        pattern = rf'<{re.escape(tag)}>\s*(.*?)\s*</{re.escape(tag)}>'
        tagged = re.search(pattern, text, re.DOTALL)
        if tagged:
            return tagged.group(1).strip()

        # No tag found: strip a surrounding markdown code fence, if any.
        if text.startswith("```yaml"):
            text = text[7:]
        elif text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def _extract_keywords_with_llm(self, user_input: str) -> List[str]:
        """Ask the LLM for keywords describing the task.

        Args:
            user_input: User input / task description.

        Returns:
            Stripped keyword strings longer than one character. An empty list
            is returned when the LLM call or the YAML parsing fails, or when
            the parsed payload is not a list.
        """
        project_overview = self._get_project_overview()

        prompt = f"""分析以下代码编辑任务,提取关键词。关键词应该是与任务相关的核心概念、技术术语、功能模块等。

{project_overview}

任务描述:
{user_input}

请提取5-10个关键词,以YAML数组格式返回,并用<KEYWORDS>标签包裹。
只返回关键词数组,不要包含其他文字。

示例格式:
<KEYWORDS>
- data processing
- validation
- error handling
- API endpoint
- authentication
</KEYWORDS>
"""

        try:
            response = self._call_llm(prompt)
            yaml_content = self._extract_tagged_yaml(response, "KEYWORDS")

            keywords = yaml.safe_load(yaml_content)
            if not isinstance(keywords, list):
                return []

            # Drop non-strings, empty strings and one-character keywords.
            return [
                k.strip()
                for k in keywords
                if k and isinstance(k, str) and len(k.strip()) > 1
            ]
        except Exception as e:
            # Any failure degrades to "no keywords" after a warning.
            PrettyOutput.print(f"LLM关键词提取失败: {e}", OutputType.WARNING)
            return []

    def _search_symbols_by_keywords(self, keywords: List[str]) -> List[Symbol]:
        """Find symbols whose name or signature contains a keyword.

        Matching is a case-insensitive substring test. When a symbol name
        matches, every symbol stored under that name is returned; otherwise
        each symbol's signature is tested individually.

        Args:
            keywords: Keyword list.

        Returns:
            Candidate symbols, de-duplicated on (file_path, name, line_start).
        """
        if not keywords:
            return []

        needles = [k.lower() for k in keywords]
        matches: List[Symbol] = []
        seen_keys = set()

        def remember(symbol) -> None:
            key = (symbol.file_path, symbol.name, symbol.line_start)
            if key not in seen_keys:
                seen_keys.add(key)
                matches.append(symbol)

        symbols_by_name = self.context_manager.symbol_table.symbols_by_name
        for symbol_name, symbols in symbols_by_name.items():
            name_lower = symbol_name.lower()

            if any(needle in name_lower for needle in needles):
                # Name hit: take every symbol sharing this name.
                for symbol in symbols:
                    remember(symbol)
                continue

            # No name hit: fall back to the signature of each symbol.
            for symbol in symbols:
                if not symbol.signature:
                    continue
                signature_lower = symbol.signature.lower()
                if any(needle in signature_lower for needle in needles):
                    remember(symbol)

        return matches

    def _search_text_by_keywords(self, keywords: List[str]) -> List[Symbol]:
        """Find symbols living in files whose content contains a keyword.

        Args:
            keywords: Keyword list.

        Returns:
            All symbols of every file whose text contains at least one keyword
            (case-insensitive), in symbol-table order.
        """
        if not keywords:
            return []

        needles = [k.lower() for k in keywords]
        symbol_table = self.context_manager.symbol_table

        # An insertion-ordered dict (instead of a set) keeps the file order,
        # and therefore the candidate order, stable from run to run.
        known_files = dict.fromkeys(
            symbol.file_path
            for symbols in symbol_table.symbols_by_name.values()
            for symbol in symbols
        )

        found_symbols: List[Symbol] = []
        for file_path in known_files:
            content = self.context_manager._get_file_content(file_path)
            if not content:
                continue

            content_lower = content.lower()
            if any(needle in content_lower for needle in needles):
                found_symbols.extend(symbol_table.get_file_symbols(file_path))

        return found_symbols

    def _select_relevant_symbols_with_llm(
        self, user_input: str, keywords: List[str], candidate_symbols: List[Symbol]
    ) -> List[Symbol]:
        """Ask the LLM to pick the most relevant candidates.

        Args:
            user_input: User input / task description.
            keywords: Keyword list.
            candidate_symbols: Candidate symbols (with location information).

        Returns:
            The chosen symbols in the order the LLM listed them, each at most
            once. An empty list is returned when the LLM call or the parsing
            fails, or when the parsed payload is not a list.
        """
        if not candidate_symbols:
            return []

        # Cap the candidate count so the prompt does not grow too long.
        pool = candidate_symbols[:self._MAX_CANDIDATES]

        # Numbered (1-based) description of every candidate for the prompt.
        symbol_info_list = [
            {
                "序号": idx,
                "name": symbol.name,
                "kind": symbol.kind,
                "file": self._relative_path(symbol.file_path),
                "line": symbol.line_start,
                "signature": symbol.signature or "",
            }
            for idx, symbol in enumerate(pool, start=1)
        ]

        project_overview = self._get_project_overview()

        prompt = f"""根据以下任务描述和关键词,从候选符号列表中选择最相关的符号。

{project_overview}

任务描述:{user_input}
关键词:{', '.join(keywords)}

候选符号列表(已编号,包含位置信息):
{yaml.dump(symbol_info_list, allow_unicode=True, default_flow_style=False)}

请返回最相关的10-20个符号的序号(YAML数组格式),按相关性排序,并用<SELECTED_INDICES>标签包裹。

只返回序号数组,例如:
<SELECTED_INDICES>
- 3
- 7
- 12
- 15
- 23
</SELECTED_INDICES>
"""

        try:
            response = self._call_llm(prompt)
            yaml_content = self._extract_tagged_yaml(response, "SELECTED_INDICES")

            selected_indices = yaml.safe_load(yaml_content)
            if not isinstance(selected_indices, list):
                return []

            selected_symbols = []
            used_indices = set()
            for idx in selected_indices:
                # bool is a subclass of int; a YAML `true` must not count as 1.
                if isinstance(idx, bool) or not isinstance(idx, int):
                    continue
                # Indices are 1-based; ignore out-of-range and repeated ones.
                if not 1 <= idx <= len(pool) or idx in used_indices:
                    continue
                used_indices.add(idx)
                selected_symbols.append(pool[idx - 1])

            return selected_symbols
        except Exception as e:
            # Any failure degrades to "nothing selected" after a warning.
            PrettyOutput.print(f"LLM符号筛选失败: {e}", OutputType.WARNING)
            return []

    def _call_llm(self, prompt: str) -> str:
        """Send ``prompt`` to the LLM and return its response as text.

        Args:
            prompt: Prompt text.

        Returns:
            The response converted with ``str()``.

        Raises:
            ValueError: If no model is available or the model has no
                ``chat_until_success`` method.
            Exception: Whatever the model call raises; it is reported as a
                warning and re-raised.
        """
        if not self.llm_model:
            raise ValueError("LLM model not available")

        try:
            if not hasattr(self.llm_model, 'chat_until_success'):
                raise ValueError("LLM model does not support chat_until_success interface")
            return str(self.llm_model.chat_until_success(prompt))
        except Exception as e:
            PrettyOutput.print(f"LLM调用失败: {e}", OutputType.WARNING)
            raise

    def format_recommendation(self, recommendation: ContextRecommendation) -> str:
        """Render a recommendation as human-readable text.

        Args:
            recommendation: Recommendation result.

        Returns:
            A multi-line block listing each symbol's name, kind, path relative
            to the project root and start line; an empty string when the
            recommendation has no symbols.
        """
        symbols = recommendation.recommended_symbols
        if not symbols:
            return ""

        separator = "─" * 60

        # One bullet per symbol: where the symbol lives in the project.
        symbols_str = "\n ".join(
            f"• 符号 `{s.name}` ({s.kind}) 位于文件 {self._relative_path(s.file_path)} 第 {s.line_start} 行"
            for s in symbols
        )

        lines = [
            "\n💡 智能上下文推荐:",
            separator,
            f"🔗 推荐符号位置 ({len(symbols)}个):\n {symbols_str}",
            separator,
            "",  # trailing empty line
        ]
        return "\n".join(lines)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class Symbol:
    """One named code entity plus the location it was found at."""

    # Identifier of the symbol.
    name: str
    # Category label, e.g. 'function', 'class', 'variable' or 'import'.
    kind: str
    # File the symbol belongs to.
    file_path: str
    # First and last line of the symbol's extent.
    line_start: int
    line_end: int
    # Optional descriptive details; all default to None.
    signature: Optional[str] = None
    docstring: Optional[str] = None
    # Parent scope of the symbol; further optional fields can be appended here.
    parent: Optional[str] = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SymbolTable:
    """Project-wide symbol index, addressable by symbol name and by file."""

    def __init__(self):
        # name -> all symbols carrying that name (a name may occur in many files).
        self.symbols_by_name: Dict[str, List[Symbol]] = {}
        # file path -> all symbols recorded for that file.
        self.symbols_by_file: Dict[str, List[Symbol]] = {}

    def add_symbol(self, symbol: Symbol):
        """Register ``symbol`` in both the by-name and the by-file index."""
        self.symbols_by_name.setdefault(symbol.name, []).append(symbol)
        self.symbols_by_file.setdefault(symbol.file_path, []).append(symbol)

    def find_symbol(self, name: str, file_path: Optional[str] = None) -> List[Symbol]:
        """Look up symbols called ``name``.

        Without ``file_path`` every symbol of that name is returned; with it,
        only those recorded for that file.
        """
        if not file_path:
            return self.symbols_by_name.get(name, [])
        return [
            candidate
            for candidate in self.get_file_symbols(file_path)
            if candidate.name == name
        ]

    def get_file_symbols(self, file_path: str) -> List[Symbol]:
        """Return the symbols recorded for ``file_path`` (empty list if none)."""
        return self.symbols_by_file.get(file_path, [])

    def clear_file_symbols(self, file_path: str):
        """Drop every symbol recorded for ``file_path`` from both indexes."""
        removed = self.symbols_by_file.pop(file_path, None)
        if removed is None:
            return

        for symbol in removed:
            same_name = self.symbols_by_name.get(symbol.name)
            if same_name is None:
                # Already emptied while handling an earlier symbol of this name.
                continue
            survivors = [s for s in same_name if s.file_path != file_path]
            if survivors:
                self.symbols_by_name[symbol.name] = survivors
            else:
                del self.symbols_by_name[symbol.name]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class SymbolExtractor:
    """Base class for extracting symbols from a source code file."""

    def extract_symbols(self, file_path: str, content: str) -> List[Symbol]:
        """Extract symbols (functions, classes, variables, ...) from ``content``.

        The base implementation always raises; language-specific subclasses
        are expected to override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Subclasses must implement this method."
        raise NotImplementedError(message)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
from tree_sitter import Language, Parser, Node
|
|
5
|
+
|
|
6
|
+
from .symbol_extractor import Symbol, SymbolExtractor
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TreeSitterExtractor(SymbolExtractor):
    """Generic symbol extractor that parses code with tree-sitter.

    Subclasses supply the language-specific parts: the tree-sitter ``Language``
    object, the symbol query text, and ``_create_symbol_from_capture``.

    NOTE(review): ``@abstractmethod`` below is not enforced at instantiation
    time because neither this class nor ``SymbolExtractor`` uses an ABC
    metaclass; a missing override only surfaces as ``NotImplementedError``
    inside ``extract_symbols``, where it is caught and reported.
    """

    def __init__(self, language: Language, symbol_query: str):
        """Bind a parser to ``language`` and remember the query text.

        Args:
            language: tree-sitter language used for parsing and querying.
            symbol_query: Query source whose captures describe the symbols.
        """
        self.language = language
        self.parser = Parser()
        if hasattr(self.parser, "set_language"):
            self.parser.set_language(self.language)
        else:
            # Newer py-tree-sitter releases expose the language as a property
            # instead of the set_language() method.
            self.parser.language = self.language
        self.symbol_query = symbol_query
        # (query text, compiled query) of the most recent compilation.
        self._query_cache = None

    def _get_query(self):
        """Return the compiled query, compiling only when the text changed."""
        # getattr: subclasses may override __init__ without calling this one.
        cached = getattr(self, "_query_cache", None)
        if cached is not None and cached[0] == self.symbol_query:
            return cached[1]
        compiled = self.language.query(self.symbol_query)
        self._query_cache = (self.symbol_query, compiled)
        return compiled

    def extract_symbols(self, file_path: str, content: str) -> List[Symbol]:
        """Parse ``content`` and turn the query captures into symbols.

        Args:
            file_path: Path passed on to ``_create_symbol_from_capture`` and
                used in the error message.
            content: Source text; it is parsed as UTF-8 bytes.

        Returns:
            The symbols produced for the captures (captures for which the hook
            returns a falsy value are skipped). Any exception is reported with
            ``print`` and results in an empty list.
        """
        try:
            tree = self.parser.parse(content.encode("utf-8"))
            captures = self._get_query().captures(tree.root_node)

            if isinstance(captures, dict):
                # Newer py-tree-sitter: {capture name: [nodes]}.
                pairs = [
                    (node, capture_name)
                    for capture_name, nodes in captures.items()
                    for node in nodes
                ]
            else:
                # Older py-tree-sitter: iterable of (node, capture name).
                pairs = captures

            symbols = []
            for node, name in pairs:
                symbol = self._create_symbol_from_capture(node, name, file_path)
                if symbol:
                    symbols.append(symbol)
            return symbols
        except Exception as e:
            print(f"Error extracting symbols from {file_path} with tree-sitter: {e}")
            return []

    @abstractmethod
    def _create_symbol_from_capture(self, node: Node, name: str, file_path: str) -> Optional[Symbol]:
        """Create a ``Symbol`` from one tree-sitter query capture.

        Subclasses must implement this to map capture names
        (e.g., "function.name") to ``Symbol`` attributes, returning ``None``
        for captures that should not produce a symbol.
        """
        raise NotImplementedError
|