jarvis-ai-assistant 0.1.58__py3-none-any.whl → 0.1.74__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic; further details were linked from the original registry page (the link was not preserved in this copy).
- jarvis/__init__.py +1 -1
- jarvis/agent.py +132 -63
- jarvis/jarvis_codebase/__init__.py +0 -0
- jarvis/jarvis_codebase/main.py +636 -0
- jarvis/jarvis_coder/__init__.py +0 -0
- jarvis/jarvis_coder/main.py +249 -384
- jarvis/main.py +0 -2
- jarvis/models/ai8.py +2 -3
- jarvis/models/base.py +1 -5
- jarvis/models/kimi.py +2 -0
- jarvis/models/openai.py +1 -2
- jarvis/models/oyi.py +2 -5
- jarvis/models/registry.py +8 -7
- jarvis/tools/__init__.py +1 -0
- jarvis/tools/codebase_qa.py +74 -0
- jarvis/tools/coder.py +69 -0
- jarvis/tools/methodology.py +16 -16
- jarvis/tools/registry.py +1 -1
- jarvis/tools/search.py +33 -1
- jarvis/utils.py +8 -1
- {jarvis_ai_assistant-0.1.58.dist-info → jarvis_ai_assistant-0.1.74.dist-info}/METADATA +94 -39
- jarvis_ai_assistant-0.1.74.dist-info/RECORD +33 -0
- {jarvis_ai_assistant-0.1.58.dist-info → jarvis_ai_assistant-0.1.74.dist-info}/entry_points.txt +1 -0
- jarvis/tools/bing_search.py +0 -38
- jarvis_ai_assistant-0.1.58.dist-info/RECORD +0 -29
- {jarvis_ai_assistant-0.1.58.dist-info → jarvis_ai_assistant-0.1.74.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.58.dist-info → jarvis_ai_assistant-0.1.74.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.58.dist-info → jarvis_ai_assistant-0.1.74.dist-info}/top_level.txt +0 -0
jarvis/jarvis_coder/main.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
|
1
|
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
2
|
-
import hashlib
|
|
3
1
|
import os
|
|
4
2
|
import re
|
|
5
|
-
import sqlite3
|
|
6
3
|
import threading
|
|
7
4
|
import time
|
|
8
|
-
from typing import Dict, Any, List,
|
|
5
|
+
from typing import Dict, Any, List, Tuple
|
|
9
6
|
|
|
10
7
|
import yaml
|
|
11
8
|
from jarvis.models.base import BasePlatform
|
|
12
|
-
from jarvis.utils import OutputType, PrettyOutput, get_multiline_input, load_env_from_file
|
|
9
|
+
from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_multiline_input, load_env_from_file
|
|
13
10
|
from jarvis.models.registry import PlatformRegistry
|
|
11
|
+
from jarvis.jarvis_codebase.main import CodeBase
|
|
12
|
+
from prompt_toolkit import PromptSession
|
|
13
|
+
from prompt_toolkit.completion import WordCompleter, Completer, Completion
|
|
14
|
+
from prompt_toolkit.formatted_text import FormattedText
|
|
15
|
+
from prompt_toolkit.styles import Style
|
|
16
|
+
import fnmatch
|
|
14
17
|
|
|
15
18
|
# 全局锁对象
|
|
16
19
|
index_lock = threading.Lock()
|
|
@@ -18,12 +21,15 @@ index_lock = threading.Lock()
|
|
|
18
21
|
class JarvisCoder:
|
|
19
22
|
def __init__(self, root_dir: str, language: str):
|
|
20
23
|
"""初始化代码修改工具"""
|
|
24
|
+
|
|
25
|
+
self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM") or os.environ.get("JARVIS_PLATFORM")
|
|
26
|
+
self.model = os.environ.get("JARVIS_CODEGEN_MODEL") or os.environ.get("JARVIS_MODEL")
|
|
21
27
|
|
|
22
|
-
self.root_dir = root_dir
|
|
23
|
-
self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM")
|
|
24
|
-
self.model = os.environ.get("JARVIS_CODEGEN_MODEL")
|
|
25
28
|
|
|
26
|
-
self.
|
|
29
|
+
if not self.platform or not self.model:
|
|
30
|
+
raise ValueError("JARVIS_CODEGEN_PLATFORM or JARVIS_CODEGEN_MODEL is not set")
|
|
31
|
+
|
|
32
|
+
self.root_dir = find_git_root(root_dir)
|
|
27
33
|
if not self.root_dir:
|
|
28
34
|
self.root_dir = root_dir
|
|
29
35
|
|
|
@@ -41,10 +47,6 @@ class JarvisCoder:
|
|
|
41
47
|
if not os.path.exists(self.jarvis_dir):
|
|
42
48
|
os.makedirs(self.jarvis_dir)
|
|
43
49
|
|
|
44
|
-
self.index_db_path = os.path.join(self.jarvis_dir, "index.db")
|
|
45
|
-
if not os.path.exists(self.index_db_path):
|
|
46
|
-
self._create_index_db()
|
|
47
|
-
|
|
48
50
|
self.record_dir = os.path.join(self.jarvis_dir, "record")
|
|
49
51
|
if not os.path.exists(self.record_dir):
|
|
50
52
|
os.makedirs(self.record_dir)
|
|
@@ -65,6 +67,9 @@ class JarvisCoder:
|
|
|
65
67
|
os.system(f"git add .")
|
|
66
68
|
os.system(f"git commit -m 'commit before code edit'")
|
|
67
69
|
|
|
70
|
+
# 4. 初始化代码库
|
|
71
|
+
self._codebase = CodeBase(self.root_dir)
|
|
72
|
+
|
|
68
73
|
def _new_model(self):
|
|
69
74
|
"""获取大模型"""
|
|
70
75
|
model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
|
|
@@ -109,341 +114,6 @@ class JarvisCoder:
|
|
|
109
114
|
time.sleep(delay)
|
|
110
115
|
delay *= 2 # 指数退避
|
|
111
116
|
|
|
112
|
-
def _get_key_info(self, file_path: str, content: str) -> Optional[Dict[str, Any]]:
|
|
113
|
-
"""获取文件的关键信息
|
|
114
|
-
|
|
115
|
-
Args:
|
|
116
|
-
file_path: 文件路径
|
|
117
|
-
content: 文件内容
|
|
118
|
-
|
|
119
|
-
Returns:
|
|
120
|
-
Optional[Dict[str, Any]]: 文件信息,包含文件描述
|
|
121
|
-
"""
|
|
122
|
-
model = self._new_model() # 创建新的模型实例
|
|
123
|
-
model.set_suppress_output(True)
|
|
124
|
-
|
|
125
|
-
prompt = f"""你是一个资深程序员,请根据文件内容,生成文件的关键信息,要求如下,除了代码,不要输出任何内容:
|
|
126
|
-
|
|
127
|
-
1. 文件路径: {file_path}
|
|
128
|
-
2. 文件内容:(<CONTENT_START>和<CONTENT_END>之间的部分)
|
|
129
|
-
<CONTENT_START>
|
|
130
|
-
{content}
|
|
131
|
-
<CONTENT_END>
|
|
132
|
-
3. 关键信息: 请生成这个文件的主要功能和作用描述,包含的特征符号(函数和类、变量等),不超过100字
|
|
133
|
-
"""
|
|
134
|
-
try:
|
|
135
|
-
return model.chat(prompt)
|
|
136
|
-
except Exception as e:
|
|
137
|
-
PrettyOutput.print(f"解析文件信息失败: {str(e)}", OutputType.ERROR)
|
|
138
|
-
return None
|
|
139
|
-
finally:
|
|
140
|
-
# 确保清理模型资源
|
|
141
|
-
try:
|
|
142
|
-
model.delete_chat()
|
|
143
|
-
except:
|
|
144
|
-
pass
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
def _get_file_md5(self, file_path: str) -> str:
|
|
149
|
-
"""获取文件MD5"""
|
|
150
|
-
return hashlib.md5(open(file_path, "rb").read()).hexdigest()
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
def _create_index_db(self):
|
|
154
|
-
"""创建索引数据库"""
|
|
155
|
-
with index_lock:
|
|
156
|
-
if not os.path.exists(self.index_db_path):
|
|
157
|
-
PrettyOutput.print("Index database does not exist, creating...", OutputType.INFO)
|
|
158
|
-
index_db = sqlite3.connect(self.index_db_path)
|
|
159
|
-
index_db.execute(
|
|
160
|
-
"CREATE TABLE files (file_path TEXT PRIMARY KEY, file_md5 TEXT, file_description TEXT)")
|
|
161
|
-
index_db.commit()
|
|
162
|
-
index_db.close()
|
|
163
|
-
PrettyOutput.print("Index database created", OutputType.SUCCESS)
|
|
164
|
-
# commit
|
|
165
|
-
os.chdir(self.root_dir)
|
|
166
|
-
os.system(f"git add .gitignore -f")
|
|
167
|
-
os.system(f"git commit -m 'add index database'")
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
def _find_file_by_md5(self, file_md5: str) -> Optional[str]:
|
|
171
|
-
"""根据文件MD5查找文件路径"""
|
|
172
|
-
with index_lock:
|
|
173
|
-
index_db = sqlite3.connect(self.index_db_path)
|
|
174
|
-
cursor = index_db.cursor()
|
|
175
|
-
cursor.execute(
|
|
176
|
-
"SELECT file_path FROM files WHERE file_md5 = ?", (file_md5,))
|
|
177
|
-
result = cursor.fetchone()
|
|
178
|
-
index_db.close()
|
|
179
|
-
return result[0] if result else None
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
def _update_file_path(self, file_path: str, file_md5: str):
|
|
183
|
-
"""更新文件路径"""
|
|
184
|
-
with index_lock:
|
|
185
|
-
index_db = sqlite3.connect(self.index_db_path)
|
|
186
|
-
cursor = index_db.cursor()
|
|
187
|
-
cursor.execute(
|
|
188
|
-
"UPDATE files SET file_path = ? WHERE file_md5 = ?", (file_path, file_md5))
|
|
189
|
-
index_db.commit()
|
|
190
|
-
index_db.close()
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
def _insert_info(self, file_path: str, file_md5: str, file_description: str):
|
|
194
|
-
"""插入文件信息"""
|
|
195
|
-
with index_lock:
|
|
196
|
-
index_db = sqlite3.connect(self.index_db_path)
|
|
197
|
-
cursor = index_db.cursor()
|
|
198
|
-
cursor.execute("DELETE FROM files WHERE file_path = ?", (file_path,))
|
|
199
|
-
cursor.execute("INSERT INTO files (file_path, file_md5, file_description) VALUES (?, ?, ?)",
|
|
200
|
-
(file_path, file_md5, file_description))
|
|
201
|
-
index_db.commit()
|
|
202
|
-
index_db.close()
|
|
203
|
-
|
|
204
|
-
def _is_text_file(self, file_path: str) -> bool:
|
|
205
|
-
"""判断文件是否是文本文件"""
|
|
206
|
-
try:
|
|
207
|
-
with open(file_path, 'rb') as f:
|
|
208
|
-
# 读取文件前1024个字节
|
|
209
|
-
chunk = f.read(1024)
|
|
210
|
-
# 检查是否包含空字节
|
|
211
|
-
if b'\x00' in chunk:
|
|
212
|
-
return False
|
|
213
|
-
# 尝试解码为文本
|
|
214
|
-
chunk.decode('utf-8')
|
|
215
|
-
return True
|
|
216
|
-
except:
|
|
217
|
-
return False
|
|
218
|
-
|
|
219
|
-
def _index_project(self):
|
|
220
|
-
"""建立代码库索引"""
|
|
221
|
-
import threading
|
|
222
|
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
223
|
-
|
|
224
|
-
git_files = os.popen("git ls-files").read().splitlines()
|
|
225
|
-
|
|
226
|
-
index_db = sqlite3.connect(self.index_db_path)
|
|
227
|
-
cursor = index_db.cursor()
|
|
228
|
-
cursor.execute("SELECT file_path FROM files")
|
|
229
|
-
db_files = [row[0] for row in cursor.fetchall()]
|
|
230
|
-
for db_file in db_files:
|
|
231
|
-
if not os.path.exists(db_file):
|
|
232
|
-
cursor.execute("DELETE FROM files WHERE file_path = ?", (db_file,))
|
|
233
|
-
PrettyOutput.print(f"删除不存在的文件记录: {db_file}", OutputType.INFO)
|
|
234
|
-
index_db.commit()
|
|
235
|
-
index_db.close()
|
|
236
|
-
|
|
237
|
-
def process_file(file_path: str):
|
|
238
|
-
"""处理单个文件的索引任务"""
|
|
239
|
-
if not self._is_text_file(file_path):
|
|
240
|
-
return
|
|
241
|
-
|
|
242
|
-
# 计算文件MD5
|
|
243
|
-
file_md5 = self._get_file_md5(file_path)
|
|
244
|
-
|
|
245
|
-
# 查找文件
|
|
246
|
-
file_path_in_db = self._find_file_by_md5(file_md5)
|
|
247
|
-
if file_path_in_db:
|
|
248
|
-
PrettyOutput.print(
|
|
249
|
-
f"文件 {file_path} 重复,跳过", OutputType.INFO)
|
|
250
|
-
if file_path_in_db != file_path:
|
|
251
|
-
self._update_file_path(file_path, file_md5)
|
|
252
|
-
PrettyOutput.print(
|
|
253
|
-
f"文件 {file_path} 重复,更新路径为 {file_path}", OutputType.INFO)
|
|
254
|
-
return
|
|
255
|
-
|
|
256
|
-
with open(file_path, "r", encoding="utf-8") as f:
|
|
257
|
-
file_content = f.read()
|
|
258
|
-
key_info = self._get_key_info(file_path, file_content)
|
|
259
|
-
if not key_info:
|
|
260
|
-
PrettyOutput.print(
|
|
261
|
-
f"文件 {file_path} 索引失败", OutputType.INFO)
|
|
262
|
-
return
|
|
263
|
-
|
|
264
|
-
self._insert_info(file_path, file_md5, key_info)
|
|
265
|
-
PrettyOutput.print(
|
|
266
|
-
f"文件 {file_path} 已建立索引", OutputType.INFO)
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
# 使用线程池处理文件索引
|
|
270
|
-
with ThreadPoolExecutor(max_workers=10) as executor:
|
|
271
|
-
futures = [executor.submit(process_file, file_path) for file_path in git_files]
|
|
272
|
-
for future in as_completed(futures):
|
|
273
|
-
try:
|
|
274
|
-
future.result()
|
|
275
|
-
except Exception as e:
|
|
276
|
-
PrettyOutput.print(f"处理文件时发生错误: {str(e)}", OutputType.ERROR)
|
|
277
|
-
|
|
278
|
-
PrettyOutput.print("项目索引完成", OutputType.INFO)
|
|
279
|
-
|
|
280
|
-
def _get_files_from_db(self) -> List[Tuple[str, str]]:
|
|
281
|
-
"""从数据库获取所有文件信息
|
|
282
|
-
|
|
283
|
-
Returns:
|
|
284
|
-
List[Tuple[str, str]]: [(file_path, file_description), ...]
|
|
285
|
-
"""
|
|
286
|
-
try:
|
|
287
|
-
index_db = sqlite3.connect(self.index_db_path)
|
|
288
|
-
cursor = index_db.cursor()
|
|
289
|
-
cursor.execute("SELECT file_path, file_description FROM files")
|
|
290
|
-
all_files = cursor.fetchall()
|
|
291
|
-
index_db.close()
|
|
292
|
-
return all_files
|
|
293
|
-
except sqlite3.Error as e:
|
|
294
|
-
PrettyOutput.print(f"数据库操作失败: {str(e)}", OutputType.ERROR)
|
|
295
|
-
return []
|
|
296
|
-
|
|
297
|
-
def _analyze_files_in_batches(self, all_files: List[Tuple[str, str]], feature: str, batch_size: int = 100) -> List[Dict]:
|
|
298
|
-
"""批量分析文件相关性
|
|
299
|
-
|
|
300
|
-
Args:
|
|
301
|
-
all_files: 所有文件列表
|
|
302
|
-
feature: 需求描述
|
|
303
|
-
batch_size: 批处理大小
|
|
304
|
-
|
|
305
|
-
Returns:
|
|
306
|
-
List[Dict]: 带评分的文件列表
|
|
307
|
-
"""
|
|
308
|
-
batch_results = []
|
|
309
|
-
|
|
310
|
-
with ThreadPoolExecutor(max_workers=10) as executor:
|
|
311
|
-
futures = []
|
|
312
|
-
for i in range(0, len(all_files), batch_size):
|
|
313
|
-
batch_files = all_files[i:i + batch_size]
|
|
314
|
-
prompt = self._create_batch_analysis_prompt(batch_files, feature)
|
|
315
|
-
model = self._new_model()
|
|
316
|
-
model.set_suppress_output(True)
|
|
317
|
-
futures.append(executor.submit(self._call_model_with_retry, model, prompt))
|
|
318
|
-
|
|
319
|
-
for future in as_completed(futures):
|
|
320
|
-
success, response = future.result()
|
|
321
|
-
if not success:
|
|
322
|
-
continue
|
|
323
|
-
|
|
324
|
-
batch_start = futures.index(future) * batch_size
|
|
325
|
-
batch_end = min(batch_start + batch_size, len(all_files))
|
|
326
|
-
current_batch = all_files[batch_start:batch_end]
|
|
327
|
-
|
|
328
|
-
results = self._process_batch_response(response, current_batch)
|
|
329
|
-
batch_results.extend(results)
|
|
330
|
-
|
|
331
|
-
return batch_results
|
|
332
|
-
|
|
333
|
-
def _create_batch_analysis_prompt(self, batch_files: List[Tuple[str, str]], feature: str) -> str:
|
|
334
|
-
"""创建批量分析的提示词
|
|
335
|
-
|
|
336
|
-
Args:
|
|
337
|
-
batch_files: 批次文件列表
|
|
338
|
-
feature: 需求描述
|
|
339
|
-
|
|
340
|
-
Returns:
|
|
341
|
-
str: 提示词
|
|
342
|
-
"""
|
|
343
|
-
prompt = """你是资深程序员,请根据需求描述,从以下文件路径中选出最相关的文件,按相关度从高到低排序。
|
|
344
|
-
|
|
345
|
-
相关度打分标准(0-9分):
|
|
346
|
-
- 9分:文件名直接包含需求中的关键词,且文件功能与需求完全匹配
|
|
347
|
-
- 7-8分:文件名包含需求相关词,或文件功能与需求高度相关
|
|
348
|
-
- 5-6分:文件名暗示与需求有关,或文件功能与需求部分相关
|
|
349
|
-
- 3-4分:文件可能需要小幅修改以配合需求
|
|
350
|
-
- 1-2分:文件与需求关系较远,但可能需要少量改动
|
|
351
|
-
- 0分:文件与需求完全无关
|
|
352
|
-
|
|
353
|
-
请输出yaml格式,仅输出以下格式内容:
|
|
354
|
-
<RELEVANT_FILES_START>
|
|
355
|
-
file1.py: 9
|
|
356
|
-
file2.py: 7
|
|
357
|
-
<RELEVANT_FILES_END>
|
|
358
|
-
|
|
359
|
-
文件列表:
|
|
360
|
-
"""
|
|
361
|
-
for file_path, _ in batch_files:
|
|
362
|
-
prompt += f"- {file_path}\n"
|
|
363
|
-
prompt += f"\n需求描述: {feature}\n"
|
|
364
|
-
prompt += "\n注意:\n1. 只输出最相关的文件,不超过5个\n2. 根据上述打分标准判断相关性\n3. 相关度必须是0-9的整数"
|
|
365
|
-
|
|
366
|
-
return prompt
|
|
367
|
-
|
|
368
|
-
def _process_batch_response(self, response: str, batch_files: List[Tuple[str, str]]) -> List[Dict]:
|
|
369
|
-
"""处理批量分析的响应
|
|
370
|
-
|
|
371
|
-
Args:
|
|
372
|
-
response: 模型响应
|
|
373
|
-
batch_files: 批次文件列表
|
|
374
|
-
|
|
375
|
-
Returns:
|
|
376
|
-
List[Dict]: 处理后的文件列表
|
|
377
|
-
"""
|
|
378
|
-
try:
|
|
379
|
-
response = response.replace("<RELEVANT_FILES_START>", "").replace("<RELEVANT_FILES_END>", "")
|
|
380
|
-
result = yaml.safe_load(response)
|
|
381
|
-
|
|
382
|
-
batch_files_dict = {f[0]: f[1] for f in batch_files}
|
|
383
|
-
results = []
|
|
384
|
-
for file_path, score in result.items():
|
|
385
|
-
if isinstance(file_path, str) and isinstance(score, int):
|
|
386
|
-
score = max(0, min(9, score)) # Ensure score is between 0-9
|
|
387
|
-
if file_path in batch_files_dict:
|
|
388
|
-
results.append({
|
|
389
|
-
"file_path": file_path,
|
|
390
|
-
"file_description": batch_files_dict[file_path],
|
|
391
|
-
"score": score
|
|
392
|
-
})
|
|
393
|
-
return results
|
|
394
|
-
except Exception as e:
|
|
395
|
-
PrettyOutput.print(f"处理批次文件失败: {str(e)}", OutputType.ERROR)
|
|
396
|
-
return []
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
def _process_content_response(self, response: str, top_files: List[Dict]) -> List[Dict]:
|
|
400
|
-
"""处理内容分析的响应"""
|
|
401
|
-
try:
|
|
402
|
-
response = response.replace("<FILE_RELATION_START>", "").replace("<FILE_RELATION_END>", "")
|
|
403
|
-
file_relation = yaml.safe_load(response)
|
|
404
|
-
if not file_relation:
|
|
405
|
-
return top_files[:5]
|
|
406
|
-
|
|
407
|
-
score = [[] for _ in range(10)] # 创建10个空列表,对应0-9分
|
|
408
|
-
for file_id, relation in file_relation.items():
|
|
409
|
-
id = int(file_id)
|
|
410
|
-
relation = max(0, min(9, relation)) # 确保范围在0-9之间
|
|
411
|
-
score[relation].append(top_files[id])
|
|
412
|
-
|
|
413
|
-
files = []
|
|
414
|
-
for scores in reversed(score): # 从高分到低分遍历
|
|
415
|
-
files.extend(scores)
|
|
416
|
-
if len(files) >= 5: # 直接取相关性最高的5个文件
|
|
417
|
-
break
|
|
418
|
-
|
|
419
|
-
return files[:5]
|
|
420
|
-
except Exception as e:
|
|
421
|
-
PrettyOutput.print(f"处理文件关系失败: {str(e)}", OutputType.ERROR)
|
|
422
|
-
return top_files[:5]
|
|
423
|
-
|
|
424
|
-
def _find_related_files(self, feature: str) -> List[Dict]:
|
|
425
|
-
"""根据需求描述,查找相关文件
|
|
426
|
-
|
|
427
|
-
Args:
|
|
428
|
-
feature: 需求描述
|
|
429
|
-
|
|
430
|
-
Returns:
|
|
431
|
-
List[Dict]: 相关文件列表
|
|
432
|
-
"""
|
|
433
|
-
# 1. 从数据库获取所有文件
|
|
434
|
-
all_files = self._get_files_from_db()
|
|
435
|
-
if not all_files:
|
|
436
|
-
return []
|
|
437
|
-
|
|
438
|
-
# 2. 批量分析文件相关性
|
|
439
|
-
batch_results = self._analyze_files_in_batches(all_files, feature)
|
|
440
|
-
|
|
441
|
-
# 3. 排序并获取前5个文件
|
|
442
|
-
batch_results.sort(key=lambda x: x["score"], reverse=True)
|
|
443
|
-
return batch_results[:5]
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
117
|
def _remake_patch(self, prompt: str) -> List[str]:
|
|
448
118
|
success, response = self._call_model_with_retry(self.main_model, prompt, max_retries=5) # 增加重试次数
|
|
449
119
|
if not success:
|
|
@@ -454,7 +124,7 @@ file2.py: 7
|
|
|
454
124
|
return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
|
|
455
125
|
for patch in patches if patch.strip()]
|
|
456
126
|
except Exception as e:
|
|
457
|
-
PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.
|
|
127
|
+
PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.WARNING)
|
|
458
128
|
return []
|
|
459
129
|
|
|
460
130
|
def _make_patch(self, related_files: List[Dict], feature: str) -> List[str]:
|
|
@@ -468,21 +138,32 @@ file2.py: 7
|
|
|
468
138
|
要替换的内容
|
|
469
139
|
=======
|
|
470
140
|
新的内容
|
|
471
|
-
|
|
141
|
+
>>>>>>
|
|
472
142
|
<PATCH_END>
|
|
473
143
|
|
|
474
|
-
2.
|
|
144
|
+
2. 如果是新文件或者替换整个文件内容,格式如下:
|
|
475
145
|
<PATCH_START>
|
|
476
146
|
>>>>>> path/to/new/file
|
|
477
147
|
=======
|
|
478
148
|
新文件的完整内容
|
|
479
|
-
|
|
149
|
+
>>>>>>
|
|
150
|
+
<PATCH_END>
|
|
151
|
+
|
|
152
|
+
3. 如果要删除文件中的某一段,格式如下:
|
|
153
|
+
<PATCH_START>
|
|
154
|
+
>>>>>> path/to/file
|
|
155
|
+
要删除的内容
|
|
156
|
+
=======
|
|
157
|
+
>>>>>>
|
|
480
158
|
<PATCH_END>
|
|
481
159
|
|
|
482
160
|
文件列表如下:
|
|
483
161
|
"""
|
|
484
162
|
for i, file in enumerate(related_files):
|
|
485
|
-
prompt
|
|
163
|
+
if len(prompt) > 30 * 1024:
|
|
164
|
+
PrettyOutput.print(f'避免上下文超限,丢弃低相关度文件:{file["file_path"]}', OutputType.WARNING)
|
|
165
|
+
continue
|
|
166
|
+
prompt += f"""{i}. {file["file_path"]}\n"""
|
|
486
167
|
prompt += f"""文件内容:\n"""
|
|
487
168
|
prompt += f"<FILE_CONTENT_START>\n"
|
|
488
169
|
prompt += f'{file["file_content"]}\n'
|
|
@@ -493,6 +174,8 @@ file2.py: 7
|
|
|
493
174
|
注意事项:
|
|
494
175
|
1、仅输出补丁内容,不要输出任何其他内容,每个补丁必须用<PATCH_START>和<PATCH_END>标记
|
|
495
176
|
2、如果在大段代码中有零星修改,生成多个补丁
|
|
177
|
+
3、要替换的内容,一定要与文件内容完全一致,不要有任何多余或者缺失的内容
|
|
178
|
+
4、每个patch不超过20行,超出20行,请生成多个patch
|
|
496
179
|
"""
|
|
497
180
|
|
|
498
181
|
success, response = self._call_model_with_retry(self.main_model, prompt)
|
|
@@ -505,7 +188,7 @@ file2.py: 7
|
|
|
505
188
|
return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
|
|
506
189
|
for patch in patches if patch.strip()]
|
|
507
190
|
except Exception as e:
|
|
508
|
-
PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.
|
|
191
|
+
PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.WARNING)
|
|
509
192
|
return []
|
|
510
193
|
|
|
511
194
|
def _apply_patch(self, related_files: List[Dict], patches: List[str]) -> Tuple[bool, str]:
|
|
@@ -544,7 +227,7 @@ file2.py: 7
|
|
|
544
227
|
return False, "\n".join(error_info)
|
|
545
228
|
|
|
546
229
|
old_content = parts[0]
|
|
547
|
-
new_content = parts[1].split("
|
|
230
|
+
new_content = parts[1].split(">>>>>>")[0]
|
|
548
231
|
|
|
549
232
|
# 处理新文件
|
|
550
233
|
if not old_content:
|
|
@@ -595,12 +278,12 @@ file2.py: 7
|
|
|
595
278
|
|
|
596
279
|
return True, ""
|
|
597
280
|
|
|
598
|
-
def _save_edit_record(self,
|
|
281
|
+
def _save_edit_record(self, commit_message: str, git_diff: str) -> None:
|
|
599
282
|
"""保存代码修改记录
|
|
600
283
|
|
|
601
284
|
Args:
|
|
602
|
-
|
|
603
|
-
|
|
285
|
+
commit_message: 提交信息
|
|
286
|
+
git_diff: git diff --cached的输出
|
|
604
287
|
"""
|
|
605
288
|
|
|
606
289
|
# 获取下一个序号
|
|
@@ -613,8 +296,8 @@ file2.py: 7
|
|
|
613
296
|
# 创建记录文件
|
|
614
297
|
record = {
|
|
615
298
|
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
616
|
-
"
|
|
617
|
-
"
|
|
299
|
+
"commit_message": commit_message,
|
|
300
|
+
"git_diff": git_diff
|
|
618
301
|
}
|
|
619
302
|
|
|
620
303
|
record_path = os.path.join(self.record_dir, f"{next_num:04d}.yaml")
|
|
@@ -623,28 +306,30 @@ file2.py: 7
|
|
|
623
306
|
|
|
624
307
|
PrettyOutput.print(f"已保存修改记录: {record_path}", OutputType.SUCCESS)
|
|
625
308
|
|
|
626
|
-
|
|
627
|
-
"""查找git根目录"""
|
|
628
|
-
while not os.path.exists(os.path.join(root_dir, ".git")):
|
|
629
|
-
root_dir = os.path.dirname(root_dir)
|
|
630
|
-
if root_dir == "/":
|
|
631
|
-
return None
|
|
632
|
-
return root_dir
|
|
309
|
+
|
|
633
310
|
|
|
634
311
|
|
|
635
312
|
def _prepare_execution(self) -> None:
|
|
636
313
|
"""准备执行环境"""
|
|
637
314
|
self.main_model = self._new_model()
|
|
638
|
-
self.
|
|
315
|
+
self._codebase.generate_codebase()
|
|
316
|
+
|
|
639
317
|
|
|
640
318
|
def _load_related_files(self, feature: str) -> List[Dict]:
|
|
641
319
|
"""加载相关文件内容"""
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
320
|
+
ret = []
|
|
321
|
+
# 确保索引数据库已生成
|
|
322
|
+
if not self._codebase.is_index_generated():
|
|
323
|
+
PrettyOutput.print("检测到索引数据库未生成,正在生成...", OutputType.WARNING)
|
|
324
|
+
self._codebase.generate_codebase()
|
|
325
|
+
|
|
326
|
+
related_files = self._codebase.search_similar(feature)
|
|
327
|
+
for file, score, _ in related_files:
|
|
328
|
+
PrettyOutput.print(f"相关文件: {file} 相关度: {score:.3f}", OutputType.SUCCESS)
|
|
329
|
+
with open(file, "r", encoding="utf-8") as f:
|
|
330
|
+
content = f.read()
|
|
331
|
+
ret.append({"file_path": file, "file_content": content})
|
|
332
|
+
return ret
|
|
648
333
|
|
|
649
334
|
def _handle_patch_application(self, related_files: List[Dict], patches: List[str], feature: str) -> Dict[str, Any]:
|
|
650
335
|
"""处理补丁应用流程"""
|
|
@@ -661,7 +346,7 @@ file2.py: 7
|
|
|
661
346
|
if success:
|
|
662
347
|
user_confirm = input("是否确认修改?(y/n)")
|
|
663
348
|
if user_confirm.lower() == "y":
|
|
664
|
-
self._finalize_changes(feature
|
|
349
|
+
self._finalize_changes(feature)
|
|
665
350
|
return {
|
|
666
351
|
"success": True,
|
|
667
352
|
"stdout": f"已完成功能开发{feature}",
|
|
@@ -691,13 +376,74 @@ file2.py: 7
|
|
|
691
376
|
"""
|
|
692
377
|
patches = self._remake_patch(retry_prompt)
|
|
693
378
|
|
|
694
|
-
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _generate_commit_message(self, git_diff: str, feature: str) -> str:
|
|
383
|
+
"""根据git diff和功能描述生成commit信息
|
|
384
|
+
|
|
385
|
+
Args:
|
|
386
|
+
git_diff: git diff --cached的输出
|
|
387
|
+
feature: 用户的功能描述
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
str: 生成的commit信息
|
|
391
|
+
"""
|
|
392
|
+
|
|
393
|
+
# 生成提示词
|
|
394
|
+
prompt = f"""你是一个经验丰富的程序员,请根据以下代码变更和功能描述生成简洁明了的commit信息:
|
|
395
|
+
|
|
396
|
+
功能描述:
|
|
397
|
+
{feature}
|
|
398
|
+
|
|
399
|
+
代码变更:
|
|
400
|
+
"""
|
|
401
|
+
# 添加git diff内容
|
|
402
|
+
prompt += f"Git Diff:\n{git_diff}\n\n"
|
|
403
|
+
|
|
404
|
+
prompt += """
|
|
405
|
+
请遵循以下规则:
|
|
406
|
+
1. 使用英文编写
|
|
407
|
+
2. 采用常规的commit message格式:<type>(<scope>): <subject>
|
|
408
|
+
3. 保持简洁,不超过50个字符
|
|
409
|
+
4. 准确描述代码变更的主要内容
|
|
410
|
+
5. 优先考虑功能描述和git diff中的变更内容
|
|
411
|
+
"""
|
|
412
|
+
|
|
413
|
+
# 使用normal模型生成commit信息
|
|
414
|
+
model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
|
|
415
|
+
model.set_model_name(self.model)
|
|
416
|
+
model.set_suppress_output(True)
|
|
417
|
+
success, response = self._call_model_with_retry(model, prompt)
|
|
418
|
+
if not success:
|
|
419
|
+
return "Update code changes"
|
|
420
|
+
|
|
421
|
+
# 清理响应内容
|
|
422
|
+
return response.strip().split("\n")[0]
|
|
423
|
+
|
|
424
|
+
def _finalize_changes(self, feature: str) -> None:
|
|
695
425
|
"""完成修改并提交"""
|
|
696
426
|
PrettyOutput.print("修改确认成功,提交修改", OutputType.INFO)
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
427
|
+
|
|
428
|
+
# 只添加已经在 git 控制下的修改文件
|
|
429
|
+
os.system("git add -u")
|
|
430
|
+
|
|
431
|
+
# 然后获取 git diff
|
|
432
|
+
git_diff = os.popen("git diff --cached").read()
|
|
433
|
+
|
|
434
|
+
# 自动生成commit信息,传入feature
|
|
435
|
+
commit_message = self._generate_commit_message(git_diff, feature)
|
|
436
|
+
|
|
437
|
+
# 显示并确认commit信息
|
|
438
|
+
PrettyOutput.print(f"自动生成的commit信息: {commit_message}", OutputType.INFO)
|
|
439
|
+
user_confirm = input("是否使用该commit信息?(y/n) [y]: ") or "y"
|
|
440
|
+
|
|
441
|
+
if user_confirm.lower() != "y":
|
|
442
|
+
commit_message = input("请输入新的commit信息: ")
|
|
443
|
+
|
|
444
|
+
# 不需要再次 git add,因为已经添加过了
|
|
445
|
+
os.system(f"git commit -m '{commit_message}'")
|
|
446
|
+
self._save_edit_record(commit_message, git_diff)
|
|
701
447
|
|
|
702
448
|
def _revert_changes(self) -> None:
|
|
703
449
|
"""回退所有修改"""
|
|
@@ -758,8 +504,8 @@ def main():
|
|
|
758
504
|
# 循环处理需求
|
|
759
505
|
while True:
|
|
760
506
|
try:
|
|
761
|
-
#
|
|
762
|
-
feature = get_multiline_input("请输入开发需求 (输入空行退出):")
|
|
507
|
+
# 获取需求,传入项目根目录
|
|
508
|
+
feature = get_multiline_input("请输入开发需求 (输入空行退出):", tool.root_dir)
|
|
763
509
|
|
|
764
510
|
if not feature or feature == "__interrupt__":
|
|
765
511
|
break
|
|
@@ -772,9 +518,9 @@ def main():
|
|
|
772
518
|
PrettyOutput.print(result["stdout"], OutputType.SUCCESS)
|
|
773
519
|
else:
|
|
774
520
|
if result["stderr"]:
|
|
775
|
-
PrettyOutput.print(result["stderr"], OutputType.
|
|
521
|
+
PrettyOutput.print(result["stderr"], OutputType.WARNING)
|
|
776
522
|
if result["error"]:
|
|
777
|
-
PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.
|
|
523
|
+
PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.WARNING)
|
|
778
524
|
|
|
779
525
|
except KeyboardInterrupt:
|
|
780
526
|
print("\n用户中断执行")
|
|
@@ -787,3 +533,122 @@ def main():
|
|
|
787
533
|
|
|
788
534
|
if __name__ == "__main__":
|
|
789
535
|
exit(main())
|
|
536
|
+
|
|
537
|
+
class FilePathCompleter(Completer):
|
|
538
|
+
"""文件路径自动完成器"""
|
|
539
|
+
|
|
540
|
+
def __init__(self, root_dir: str):
|
|
541
|
+
self.root_dir = root_dir
|
|
542
|
+
self._file_list = None
|
|
543
|
+
|
|
544
|
+
def _get_files(self) -> List[str]:
|
|
545
|
+
"""获取git管理的文件列表"""
|
|
546
|
+
if self._file_list is None:
|
|
547
|
+
try:
|
|
548
|
+
# 切换到项目根目录
|
|
549
|
+
old_cwd = os.getcwd()
|
|
550
|
+
os.chdir(self.root_dir)
|
|
551
|
+
|
|
552
|
+
# 获取git管理的文件列表
|
|
553
|
+
self._file_list = os.popen("git ls-files").read().splitlines()
|
|
554
|
+
|
|
555
|
+
# 恢复工作目录
|
|
556
|
+
os.chdir(old_cwd)
|
|
557
|
+
except Exception as e:
|
|
558
|
+
PrettyOutput.print(f"获取文件列表失败: {str(e)}", OutputType.WARNING)
|
|
559
|
+
self._file_list = []
|
|
560
|
+
return self._file_list
|
|
561
|
+
|
|
562
|
+
def get_completions(self, document, complete_event):
|
|
563
|
+
"""获取补全建议"""
|
|
564
|
+
text_before_cursor = document.text_before_cursor
|
|
565
|
+
|
|
566
|
+
# 检查是否刚输入了@
|
|
567
|
+
if text_before_cursor.endswith('@'):
|
|
568
|
+
# 显示所有文件
|
|
569
|
+
for path in self._get_files():
|
|
570
|
+
yield Completion(path, start_position=0)
|
|
571
|
+
return
|
|
572
|
+
|
|
573
|
+
# 检查之前是否有@,并获取@后的搜索词
|
|
574
|
+
at_pos = text_before_cursor.rfind('@')
|
|
575
|
+
if at_pos == -1:
|
|
576
|
+
return
|
|
577
|
+
|
|
578
|
+
search = text_before_cursor[at_pos + 1:].lower().strip()
|
|
579
|
+
|
|
580
|
+
# 提供匹配的文件建议
|
|
581
|
+
for path in self._get_files():
|
|
582
|
+
path_lower = path.lower()
|
|
583
|
+
if (search in path_lower or # 直接包含
|
|
584
|
+
search in os.path.basename(path_lower) or # 文件名包含
|
|
585
|
+
any(fnmatch.fnmatch(path_lower, f'*{s}*') for s in search.split())): # 通配符匹配
|
|
586
|
+
# 计算正确的start_position
|
|
587
|
+
yield Completion(path, start_position=-(len(search)))
|
|
588
|
+
|
|
589
|
+
class SmartCompleter(Completer):
|
|
590
|
+
"""智能自动完成器,组合词语和文件路径补全"""
|
|
591
|
+
|
|
592
|
+
def __init__(self, word_completer: WordCompleter, file_completer: FilePathCompleter):
|
|
593
|
+
self.word_completer = word_completer
|
|
594
|
+
self.file_completer = file_completer
|
|
595
|
+
|
|
596
|
+
def get_completions(self, document, complete_event):
|
|
597
|
+
"""获取补全建议"""
|
|
598
|
+
# 如果当前行以@结尾,使用文件补全
|
|
599
|
+
if document.text_before_cursor.strip().endswith('@'):
|
|
600
|
+
yield from self.file_completer.get_completions(document, complete_event)
|
|
601
|
+
else:
|
|
602
|
+
# 否则使用词语补全
|
|
603
|
+
yield from self.word_completer.get_completions(document, complete_event)
|
|
604
|
+
|
|
605
|
+
def get_multiline_input(prompt_text: str, root_dir: str = None) -> str:
|
|
606
|
+
"""获取多行输入,支持文件路径自动完成功能
|
|
607
|
+
|
|
608
|
+
Args:
|
|
609
|
+
prompt_text: 提示文本
|
|
610
|
+
root_dir: 项目根目录,用于文件补全
|
|
611
|
+
|
|
612
|
+
Returns:
|
|
613
|
+
str: 用户输入的文本
|
|
614
|
+
"""
|
|
615
|
+
# 创建文件补全器
|
|
616
|
+
file_completer = FilePathCompleter(root_dir or os.getcwd())
|
|
617
|
+
|
|
618
|
+
# 创建提示样式
|
|
619
|
+
style = Style.from_dict({
|
|
620
|
+
'prompt': 'ansicyan bold',
|
|
621
|
+
'input': 'ansiwhite',
|
|
622
|
+
})
|
|
623
|
+
|
|
624
|
+
# 创建会话
|
|
625
|
+
session = PromptSession(
|
|
626
|
+
completer=file_completer,
|
|
627
|
+
style=style,
|
|
628
|
+
multiline=False,
|
|
629
|
+
enable_history_search=True,
|
|
630
|
+
complete_while_typing=True
|
|
631
|
+
)
|
|
632
|
+
|
|
633
|
+
# 显示初始提示文本
|
|
634
|
+
print(f"\n{prompt_text}")
|
|
635
|
+
|
|
636
|
+
# 创建提示符
|
|
637
|
+
prompt = FormattedText([
|
|
638
|
+
('class:prompt', ">>> ")
|
|
639
|
+
])
|
|
640
|
+
|
|
641
|
+
# 获取输入
|
|
642
|
+
lines = []
|
|
643
|
+
try:
|
|
644
|
+
while True:
|
|
645
|
+
line = session.prompt(prompt).strip()
|
|
646
|
+
if not line: # 空行表示输入结束
|
|
647
|
+
break
|
|
648
|
+
lines.append(line)
|
|
649
|
+
except KeyboardInterrupt:
|
|
650
|
+
return "__interrupt__"
|
|
651
|
+
except EOFError:
|
|
652
|
+
pass
|
|
653
|
+
|
|
654
|
+
return "\n".join(lines)
|