jarvis-ai-assistant 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

@@ -0,0 +1,740 @@
1
+ import hashlib
2
+ import os
3
+ import re
4
+ import sqlite3
5
+ import time
6
+ from typing import Dict, Any, List, Optional, Tuple
7
+
8
+ import yaml
9
+ from jarvis.models.base import BasePlatform
10
+ from jarvis.utils import OutputType, PrettyOutput, get_multiline_input, load_env_from_file
11
+ from jarvis.models.registry import PlatformRegistry
12
+
13
+ class JarvisCoder:
14
+ def __init__(self, root_dir: str, language: str):
15
+ """初始化代码修改工具"""
16
+ self.main_model = None
17
+ self.db_path = ""
18
+ self.root_dir = root_dir
19
+ self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM")
20
+ self.model = os.environ.get("JARVIS_CODEGEN_MODEL")
21
+ self.language = language
22
+
23
+ self.root_dir = self._find_git_root_dir(self.root_dir)
24
+ if not self.root_dir:
25
+ self.root_dir = root_dir
26
+
27
+ PrettyOutput.print(f"Git根目录: {self.root_dir}", OutputType.INFO)
28
+
29
+ # 1. 判断代码库路径是否存在,如果不存在,创建
30
+ if not os.path.exists(self.root_dir):
31
+ PrettyOutput.print(
32
+ "Root directory does not exist, creating...", OutputType.INFO)
33
+ os.makedirs(self.root_dir)
34
+
35
+ os.chdir(self.root_dir)
36
+
37
+ self.jarvis_dir = os.path.join(self.root_dir, ".jarvis-coder")
38
+ if not os.path.exists(self.jarvis_dir):
39
+ os.makedirs(self.jarvis_dir)
40
+
41
+ self.index_db_path = os.path.join(self.jarvis_dir, "index.db")
42
+ if not os.path.exists(self.index_db_path):
43
+ self._create_index_db()
44
+
45
+ self.record_dir = os.path.join(self.jarvis_dir, "record")
46
+ if not os.path.exists(self.record_dir):
47
+ os.makedirs(self.record_dir)
48
+
49
+ # 2. 判断代码库是否是git仓库,如果不是,初始化git仓库
50
+ if not os.path.exists(os.path.join(self.root_dir, ".git")):
51
+ PrettyOutput.print(
52
+ "Git repository does not exist, initializing...", OutputType.INFO)
53
+ os.system(f"git init")
54
+ # 2.1 添加所有的文件
55
+ os.system(f"git add .")
56
+ # 2.2 提交
57
+ os.system(f"git commit -m 'Initial commit'")
58
+
59
+ # 3. 查看代码库是否有未提交的文件,如果有,提交一次
60
+ if self._has_uncommitted_files():
61
+ PrettyOutput.print("代码库有未提交的文件,提交一次", OutputType.INFO)
62
+ os.system(f"git add .")
63
+ os.system(f"git commit -m 'commit before code edit'")
64
+
65
+ def _new_model(self):
66
+ """获取大模型"""
67
+ model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
68
+ if self.model:
69
+ model_name = self.model
70
+ model.set_model_name(model_name)
71
+ return model
72
+
73
+ def _has_uncommitted_files(self) -> bool:
74
+ """判断代码库是否有未提交的文件"""
75
+ # 获取未暂存的修改
76
+ unstaged = os.popen("git diff --name-only").read()
77
+ # 获取已暂存但未提交的修改
78
+ staged = os.popen("git diff --cached --name-only").read()
79
+ # 获取未跟踪的文件
80
+ untracked = os.popen("git ls-files --others --exclude-standard").read()
81
+
82
+ return bool(unstaged or staged or untracked)
83
+
84
+ def _call_model_with_retry(self, model: BasePlatform, prompt: str, max_retries: int = 3, initial_delay: float = 1.0) -> Tuple[bool, str]:
85
+ """调用模型并支持重试
86
+
87
+ Args:
88
+ prompt: 提示词
89
+ max_retries: 最大重试次数
90
+ initial_delay: 初始延迟时间(秒)
91
+
92
+ Returns:
93
+ Tuple[bool, str]: (是否成功, 响应内容)
94
+ """
95
+ delay = initial_delay
96
+ for attempt in range(max_retries):
97
+ try:
98
+ response = model.chat(prompt)
99
+ return True, response
100
+ except Exception as e:
101
+ if attempt == max_retries - 1: # 最后一次尝试
102
+ PrettyOutput.print(f"调用模型失败: {str(e)}", OutputType.ERROR)
103
+ return False, str(e)
104
+
105
+ PrettyOutput.print(f"调用模型失败,{delay}秒后重试: {str(e)}", OutputType.WARNING)
106
+ time.sleep(delay)
107
+ delay *= 2 # 指数退避
108
+
109
+ def _get_key_info(self, file_path: str, content: str) -> Optional[Dict[str, Any]]:
110
+ """获取文件的关键信息
111
+
112
+ Args:
113
+ file_path: 文件路径
114
+ content: 文件内容
115
+
116
+ Returns:
117
+ Optional[Dict[str, Any]]: 文件信息,包含文件描述
118
+ """
119
+ model = self._new_model() # 创建新的模型实例
120
+
121
+ prompt = f"""你是一个资深程序员,请根据文件内容,生成文件的关键信息,要求如下,除了代码,不要输出任何内容:
122
+
123
+ 1. 文件路径: {file_path}
124
+ 2. 文件内容:(<CONTENT_START>和<CONTENT_END>之间的部分)
125
+ <CONTENT_START>
126
+ {content}
127
+ <CONTENT_END>
128
+ 3. 关键信息: 请生成文件的功能描述,仅输出以下格式内容
129
+ <FILE_INFO_START>
130
+ file_description: 这个文件的主要功能和作用描述
131
+ <FILE_INFO_END>
132
+ """
133
+ try:
134
+ response = model.chat(prompt)
135
+ model.delete_chat() # 删除会话历史
136
+ old_response = response
137
+ response = response.replace("<FILE_INFO_START>", "").replace("<FILE_INFO_END>", "")
138
+ if old_response != response:
139
+ return yaml.safe_load(response)
140
+ else:
141
+ return None
142
+ except Exception as e:
143
+ PrettyOutput.print(f"解析文件信息失败: {str(e)}", OutputType.ERROR)
144
+ return None
145
+ finally:
146
+ # 确保清理模型资源
147
+ try:
148
+ model.delete_chat()
149
+ except:
150
+ pass
151
+
152
+
153
+
154
+ def _get_file_md5(self, file_path: str) -> str:
155
+ """获取文件MD5"""
156
+ return hashlib.md5(open(file_path, "rb").read()).hexdigest()
157
+
158
+ def _create_index_db(self):
159
+ """创建索引数据库"""
160
+ if not os.path.exists(self.index_db_path):
161
+ PrettyOutput.print("Index database does not exist, creating...", OutputType.INFO)
162
+ index_db = sqlite3.connect(self.index_db_path)
163
+ index_db.execute(
164
+ "CREATE TABLE files (file_path TEXT PRIMARY KEY, file_md5 TEXT, file_description TEXT)")
165
+ index_db.commit()
166
+ index_db.close()
167
+ PrettyOutput.print("Index database created", OutputType.SUCCESS)
168
+ # commit
169
+ os.chdir(self.root_dir)
170
+ os.system(f"git add .gitignore -f")
171
+ os.system(f"git commit -m 'add index database'")
172
+
173
+ def _find_file_by_md5(self, file_md5: str) -> Optional[str]:
174
+ """根据文件MD5查找文件路径"""
175
+ index_db = sqlite3.connect(self.index_db_path)
176
+ cursor = index_db.cursor()
177
+ cursor.execute(
178
+ "SELECT file_path FROM files WHERE file_md5 = ?", (file_md5,))
179
+ result = cursor.fetchone()
180
+ index_db.close()
181
+ return result[0] if result else None
182
+
183
+ def _update_file_path(self, file_path: str, file_md5: str):
184
+ """更新文件路径"""
185
+ index_db = sqlite3.connect(self.index_db_path)
186
+ cursor = index_db.cursor()
187
+ cursor.execute(
188
+ "UPDATE files SET file_path = ? WHERE file_md5 = ?", (file_path, file_md5))
189
+ index_db.commit()
190
+ index_db.close()
191
+
192
+ def _insert_info(self, file_path: str, file_md5: str, file_description: str):
193
+ """插入文件信息"""
194
+ index_db = sqlite3.connect(self.index_db_path)
195
+ cursor = index_db.cursor()
196
+ cursor.execute("DELETE FROM files WHERE file_path = ?", (file_path,))
197
+ cursor.execute("INSERT INTO files (file_path, file_md5, file_description) VALUES (?, ?, ?)",
198
+ (file_path, file_md5, file_description))
199
+ index_db.commit()
200
+ index_db.close()
201
+
202
+ def _is_text_file(self, file_path: str) -> bool:
203
+ """判断文件是否是文本文件"""
204
+ try:
205
+ with open(file_path, 'rb') as f:
206
+ # 读取文件前1024个字节
207
+ chunk = f.read(1024)
208
+ # 检查是否包含空字节
209
+ if b'\x00' in chunk:
210
+ return False
211
+ # 尝试解码为文本
212
+ chunk.decode('utf-8')
213
+ return True
214
+ except:
215
+ return False
216
+
217
+ def _index_project(self):
218
+ """建立代码库索引"""
219
+ git_files = os.popen("git ls-files").read().splitlines()
220
+
221
+ index_db = sqlite3.connect(self.index_db_path)
222
+ cursor = index_db.cursor()
223
+ cursor.execute("SELECT file_path FROM files")
224
+ db_files = [row[0] for row in cursor.fetchall()]
225
+ for db_file in db_files:
226
+ if not os.path.exists(db_file):
227
+ cursor.execute("DELETE FROM files WHERE file_path = ?", (db_file,))
228
+ PrettyOutput.print(f"删除不存在的文件记录: {db_file}", OutputType.INFO)
229
+ index_db.commit()
230
+ index_db.close()
231
+
232
+ # 3. 遍历git管理的文件
233
+ for file_path in git_files:
234
+ if self._is_text_file(file_path):
235
+ # 计算文件MD5
236
+ file_md5 = self._get_file_md5(file_path)
237
+
238
+ # 查找文件
239
+ file_path_in_db = self._find_file_by_md5(file_md5)
240
+ if file_path_in_db:
241
+ PrettyOutput.print(
242
+ f"文件 {file_path} 重复,跳过", OutputType.INFO)
243
+ if file_path_in_db != file_path:
244
+ self._update_file_path(file_path, file_md5)
245
+ PrettyOutput.print(
246
+ f"文件 {file_path} 重复,更新路径为 {file_path}", OutputType.INFO)
247
+ continue
248
+
249
+ with open(file_path, "r", encoding="utf-8") as f:
250
+ file_content = f.read()
251
+ key_info = self._get_key_info(file_path, file_content)
252
+ if not key_info:
253
+ PrettyOutput.print(
254
+ f"文件 {file_path} 索引失败", OutputType.INFO)
255
+ continue
256
+ if "file_description" in key_info:
257
+ self._insert_info(file_path, file_md5, key_info["file_description"])
258
+ PrettyOutput.print(
259
+ f"文件 {file_path} 已建立索引", OutputType.INFO)
260
+ else:
261
+ PrettyOutput.print(
262
+ f"文件 {file_path} 不是代码文件,跳过", OutputType.INFO)
263
+ PrettyOutput.print("项目索引完成", OutputType.INFO)
264
+
265
+ def _find_related_files(self, feature: str) -> List[Dict]:
266
+ """根据需求描述,查找相关文件"""
267
+ try:
268
+ # Get all files from database
269
+ index_db = sqlite3.connect(self.index_db_path)
270
+ cursor = index_db.cursor()
271
+ cursor.execute("SELECT file_path, file_description FROM files")
272
+ all_files = cursor.fetchall()
273
+ index_db.close()
274
+ except sqlite3.Error as e:
275
+ PrettyOutput.print(f"数据库操作失败: {str(e)}", OutputType.ERROR)
276
+ return []
277
+
278
+ batch_size = 100
279
+ batch_results = [] # Store results from each batch with their scores
280
+
281
+ for i in range(0, len(all_files), batch_size):
282
+ batch_files = all_files[i:i + batch_size]
283
+
284
+ prompt = """你是资深程序员,请根据需求描述,从以下文件路径中选出最相关的文件,按相关度从高到低排序。
285
+
286
+ 相关度打分标准(0-9分):
287
+ - 9分:文件名直接包含需求中的关键词,且文件功能与需求完全匹配
288
+ - 7-8分:文件名包含需求相关词,或文件功能与需求高度相关
289
+ - 5-6分:文件名暗示与需求有关,或文件功能与需求部分相关
290
+ - 3-4分:文件可能需要小幅修改以配合需求
291
+ - 1-2分:文件与需求关系较远,但可能需要少量改动
292
+ - 0分:文件与需求完全无关
293
+
294
+ 请输出yaml格式,仅输出以下格式内容:
295
+ <RELEVANT_FILES_START>
296
+ file1.py: 9
297
+ file2.py: 7
298
+ <RELEVANT_FILES_END>
299
+
300
+ 文件列表:
301
+ """
302
+ for file_path, _ in batch_files:
303
+ prompt += f"- {file_path}\n"
304
+ prompt += f"\n需求描述: {feature}\n"
305
+ prompt += "\n注意:\n1. 只输出最相关的文件,不超过5个\n2. 根据上述打分标准判断相关性\n3. 相关度必须是0-9的整数"
306
+
307
+ success, response = self._call_model_with_retry(self._new_model(), prompt)
308
+ if not success:
309
+ continue
310
+
311
+ try:
312
+ response = response.replace("<RELEVANT_FILES_START>", "").replace("<RELEVANT_FILES_END>", "")
313
+ result = yaml.safe_load(response)
314
+
315
+ # Convert results to file objects with scores
316
+ batch_files_dict = {f[0]: f[1] for f in batch_files}
317
+ for file_path, score in result.items():
318
+ if isinstance(file_path, str) and isinstance(score, int):
319
+ score = max(0, min(9, score)) # Ensure score is between 0-9
320
+ if file_path in batch_files_dict:
321
+ batch_results.append({
322
+ "file_path": file_path,
323
+ "file_description": batch_files_dict[file_path],
324
+ "score": score
325
+ })
326
+
327
+ except Exception as e:
328
+ PrettyOutput.print(f"处理批次文件失败: {str(e)}", OutputType.ERROR)
329
+ continue
330
+
331
+ # Sort all results by score
332
+ batch_results.sort(key=lambda x: x["score"], reverse=True)
333
+ top_files = batch_results[:5]
334
+
335
+ # If we don't have enough files, add more from database
336
+ if len(top_files) < 5:
337
+ remaining_files = [f for f in all_files if f[0] not in [tf["file_path"] for tf in top_files]]
338
+ top_files.extend([{
339
+ "file_path": f[0],
340
+ "file_description": f[1],
341
+ "score": 0
342
+ } for f in remaining_files[:5-len(top_files)]])
343
+
344
+ # Now do content relevance analysis on these files
345
+ score = [[], [], [], [], [], [], [], [], [], []]
346
+
347
+ prompt = """你是资深程序员,请根据需求描述,分析文件的相关性。
348
+
349
+ 相关度打分标准(0-9分):
350
+ - 9分:文件内容与需求完全匹配,是实现需求的核心文件
351
+ - 7-8分:文件内容与需求高度相关,需要较大改动
352
+ - 5-6分:文件内容与需求部分相关,需要中等改动
353
+ - 3-4分:文件内容与需求相关性较低,但需要配合修改
354
+ - 1-2分:文件内容与需求关系较远,只需极少改动
355
+ - 0分:文件内容与需求完全无关
356
+
357
+ 文件列表如下:
358
+ <FILE_LIST_START>
359
+ """
360
+ for i, file in enumerate(top_files):
361
+ prompt += f"""{i}. {file["file_path"]} : {file["file_description"]}\n"""
362
+ prompt += f"""需求描述: {feature}\n"""
363
+ prompt += "<FILE_LIST_END>\n"
364
+ prompt += """请根据需求描述和文件描述,分析文件的相关性,输出每个编号的相关性[0~9],仅输出以下格式内容(key为文件编号,value为相关性):
365
+ <FILE_RELATION_START>
366
+ "0": 5
367
+ "1": 3
368
+ <FILE_RELATION_END>"""
369
+
370
+ success, response = self._call_model_with_retry(self._new_model(), prompt)
371
+ if not success:
372
+ return top_files[:5] # Return top 5 files from filename matching if model fails
373
+
374
+ try:
375
+ response = response.replace("<FILE_RELATION_START>", "").replace("<FILE_RELATION_END>", "")
376
+ file_relation = yaml.safe_load(response)
377
+ if not file_relation:
378
+ return top_files[:5]
379
+
380
+ for file_id, relation in file_relation.items():
381
+ id = int(file_id)
382
+ relation = max(0, min(9, relation)) # 确保范围在0-9之间
383
+ score[relation].append(top_files[id])
384
+
385
+ except Exception as e:
386
+ PrettyOutput.print(f"处理文件关系失败: {str(e)}", OutputType.ERROR)
387
+ return top_files[:5]
388
+
389
+ files = []
390
+ score.reverse()
391
+ for i in score:
392
+ files.extend(i)
393
+ if len(files) >= 5: # 直接取相关性最高的5个文件
394
+ break
395
+
396
+ return files[:5]
397
+
398
+ def _remake_patch(self, prompt: str) -> List[str]:
399
+ success, response = self._call_model_with_retry(self.main_model, prompt, max_retries=5) # 增加重试次数
400
+ if not success:
401
+ return []
402
+
403
+ try:
404
+ patches = re.findall(r'<PATCH_START>.*?<PATCH_END>', response, re.DOTALL)
405
+ return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
406
+ for patch in patches if patch.strip()]
407
+ except Exception as e:
408
+ PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.ERROR)
409
+ return []
410
+
411
+ def _make_patch(self, related_files: List[Dict], feature: str) -> List[str]:
412
+ """生成修改方案"""
413
+ prompt = """你是一个资深程序员,请根据需求描述,修改文件内容。
414
+
415
+ 修改格式说明:
416
+ 1. 每个修改块格式如下:
417
+ <PATCH_START>
418
+ >>>>>> path/to/file
419
+ 要替换的内容
420
+ ======
421
+ 新的内容
422
+ <<<<<<
423
+ <PATCH_END>
424
+
425
+ 2. 如果是新文件,格式如下:
426
+ <PATCH_START>
427
+ >>>>>> path/to/new/file
428
+ ======
429
+ 新文件的完整内容
430
+ <<<<<<
431
+ <PATCH_END>
432
+
433
+ 文件列表如下:
434
+ """
435
+ for i, file in enumerate(related_files):
436
+ prompt += f"""{i}. {file["file_path"]} : {file["file_description"]}\n"""
437
+ prompt += f"""文件内容:\n"""
438
+ prompt += f"<FILE_CONTENT_START>\n"
439
+ prompt += f'{file["file_content"]}\n'
440
+ prompt += f"<FILE_CONTENT_END>\n"
441
+
442
+ prompt += f"\n需求描述: {feature}\n"
443
+ prompt += """
444
+ 注意事项:
445
+ 1、仅输出补丁内容,不要输出任何其他内容,每个补丁必须用<PATCH_START>和<PATCH_END>标记
446
+ 2、如果在大段代码中有零星修改,生成多个补丁
447
+ """
448
+
449
+ success, response = self._call_model_with_retry(self.main_model, prompt)
450
+ if not success:
451
+ return []
452
+
453
+ try:
454
+ # 使用正则表达式匹配每个patch块
455
+ patches = re.findall(r'<PATCH_START>.*?<PATCH_END>', response, re.DOTALL)
456
+ return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
457
+ for patch in patches if patch.strip()]
458
+ except Exception as e:
459
+ PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.ERROR)
460
+ return []
461
+
462
+ def _apply_patch(self, related_files: List[Dict], patches: List[str]) -> Tuple[bool, str]:
463
+ """应用补丁"""
464
+ error_info = []
465
+ modified_files = set()
466
+
467
+ # 创建文件内容映射
468
+ file_map = {file["file_path"]: file["file_content"] for file in related_files}
469
+ temp_map = file_map.copy() # 创建临时映射用于尝试应用
470
+
471
+ # 尝试应用所有补丁
472
+ for i, patch in enumerate(patches):
473
+ PrettyOutput.print(f"正在应用补丁 {i+1}/{len(patches)}", OutputType.INFO)
474
+
475
+ try:
476
+ # 解析补丁
477
+ lines = patch.split("\n")
478
+ if not lines:
479
+ continue
480
+
481
+ # 获取文件路径
482
+ file_path_match = re.search(r'>>>>>> (.*)', lines[0])
483
+ if not file_path_match:
484
+ error_info.append(f"无法解析文件路径: {lines[0]}")
485
+ return False, "\n".join(error_info)
486
+
487
+ file_path = file_path_match.group(1).strip()
488
+
489
+ # 解析补丁内容
490
+ patch_content = "\n".join(lines[1:])
491
+ parts = patch_content.split("======")
492
+
493
+ if len(parts) != 2:
494
+ error_info.append(f"补丁格式错误: {file_path}")
495
+ return False, "\n".join(error_info)
496
+
497
+ old_content = parts[0].strip()
498
+ new_content = parts[1].split("<<<<<<")[0].strip()
499
+
500
+ # 处理新文件
501
+ if not old_content:
502
+ temp_map[file_path] = new_content
503
+ modified_files.add(file_path)
504
+ continue
505
+
506
+ # 处理文件修改
507
+ if file_path not in temp_map:
508
+ error_info.append(f"文件不存在: {file_path}")
509
+ return False, "\n".join(error_info)
510
+
511
+ current_content = temp_map[file_path]
512
+
513
+ # 查找并替换代码块
514
+ if old_content not in current_content:
515
+ error_info.append(
516
+ f"补丁应用失败: {file_path}\n"
517
+ f"原因: 未找到要替换的代码\n"
518
+ f"期望找到的代码:\n{old_content}\n"
519
+ f"实际文件内容:\n{current_content[:200]}..." # 只显示前200个字符
520
+ )
521
+ return False, "\n".join(error_info)
522
+
523
+ # 应用更改
524
+ temp_map[file_path] = current_content.replace(old_content, new_content)
525
+ modified_files.add(file_path)
526
+
527
+ except Exception as e:
528
+ error_info.append(f"处理补丁时发生错误: {str(e)}")
529
+ return False, "\n".join(error_info)
530
+
531
+ # 所有补丁都应用成功,更新实际文件
532
+ for file_path in modified_files:
533
+ try:
534
+ dir_path = os.path.dirname(file_path)
535
+ if dir_path and not os.path.exists(dir_path):
536
+ os.makedirs(dir_path, exist_ok=True)
537
+
538
+ with open(file_path, "w", encoding="utf-8") as f:
539
+ f.write(temp_map[file_path])
540
+
541
+ PrettyOutput.print(f"成功修改文件: {file_path}", OutputType.SUCCESS)
542
+
543
+ except Exception as e:
544
+ error_info.append(f"写入文件失败 {file_path}: {str(e)}")
545
+ return False, "\n".join(error_info)
546
+
547
+ return True, ""
548
+
549
+ def _save_edit_record(self, feature: str, patches: List[str]) -> None:
550
+ """保存代码修改记录
551
+
552
+ Args:
553
+ feature: 需求描述
554
+ patches: 补丁列表
555
+ """
556
+
557
+ # 获取下一个序号
558
+ existing_records = [f for f in os.listdir(self.record_dir) if f.endswith('.yaml')]
559
+ next_num = 1
560
+ if existing_records:
561
+ last_num = max(int(f[:4]) for f in existing_records)
562
+ next_num = last_num + 1
563
+
564
+ # 创建记录文件
565
+ record = {
566
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
567
+ "feature": feature,
568
+ "patches": patches
569
+ }
570
+
571
+ record_path = os.path.join(self.record_dir, f"{next_num:04d}.yaml")
572
+ with open(record_path, "w", encoding="utf-8") as f:
573
+ yaml.safe_dump(record, f, allow_unicode=True)
574
+
575
+ PrettyOutput.print(f"已保存修改记录: {record_path}", OutputType.SUCCESS)
576
+
577
+ def _find_git_root_dir(self, root_dir: str) -> str:
578
+ """查找git根目录"""
579
+ while not os.path.exists(os.path.join(root_dir, ".git")):
580
+ root_dir = os.path.dirname(root_dir)
581
+ if root_dir == "/":
582
+ return None
583
+ return root_dir
584
+
585
+ def execute(self, feature: str) -> Dict[str, Any]:
586
+ """执行代码修改
587
+
588
+ Args:
589
+ args: 包含操作参数的字典
590
+ - feature: 要实现的功能描述
591
+ - root_dir: 代码库根目录
592
+ - language: 编程语言
593
+
594
+ Returns:
595
+ Dict[str, Any]: 包含执行结果的字典
596
+ - success: 是否成功
597
+ - stdout: 标准输出信息
598
+ - stderr: 错误信息
599
+ - error: 错误对象(如果有)
600
+ """
601
+ try:
602
+ self.main_model = self._new_model() # 每次执行时重新创建模型
603
+
604
+
605
+ # 4. 开始建立代码库索引
606
+ self._index_project()
607
+
608
+ # 5. 根据索引和需求,查找相关文件
609
+ related_files = self._find_related_files(feature)
610
+ for file in related_files:
611
+ PrettyOutput.print(f"Related file: {file['file_path']}", OutputType.INFO)
612
+ for file in related_files:
613
+ with open(file["file_path"], "r", encoding="utf-8") as f:
614
+ file_content = f.read()
615
+ file["file_content"] = file_content
616
+ patches = self._make_patch(related_files, feature)
617
+ while True:
618
+ # 生成修改方案
619
+ PrettyOutput.print(f"生成{len(patches)}个补丁", OutputType.INFO)
620
+
621
+ if not patches:
622
+ retry_prompt = f"""未生成补丁,请重新生成补丁"""
623
+ patches = self._remake_patch(retry_prompt)
624
+ continue
625
+
626
+ # 尝试应用补丁
627
+ success, error_info = self._apply_patch(related_files, patches)
628
+
629
+ if success:
630
+ # 用户确认修改
631
+ user_confirm = input("是否确认修改?(y/n)")
632
+ if user_confirm.lower() == "y":
633
+ PrettyOutput.print("修改确认成功,提交修改", OutputType.INFO)
634
+
635
+ os.system(f"git add .")
636
+ os.system(f"git commit -m '{feature}'")
637
+
638
+ # 保存修改记录
639
+ self._save_edit_record(feature, patches)
640
+ # 重新建立代码库索引
641
+ self._index_project()
642
+
643
+ return {
644
+ "success": True,
645
+ "stdout": f"已完成功能开发{feature}",
646
+ "stderr": "",
647
+ "error": None
648
+ }
649
+ else:
650
+ PrettyOutput.print("修改已取消,回退更改", OutputType.INFO)
651
+
652
+ os.system(f"git reset --hard") # 回退已修改的文件
653
+ os.system(f"git clean -df") # 删除新创建的文件和目录
654
+
655
+ return {
656
+ "success": False,
657
+ "stdout": "",
658
+ "stderr": "修改被用户取消,文件未发生任何变化",
659
+ "error": UserWarning("用户取消修改")
660
+ }
661
+ else:
662
+ # 补丁应用失败,让模型重新生成
663
+ PrettyOutput.print(f"补丁应用失败,请求重新生成: {error_info}", OutputType.WARNING)
664
+ retry_prompt = f"""补丁应用失败,请根据以下错误信息重新生成补丁:
665
+
666
+ 错误信息:
667
+ {error_info}
668
+
669
+ 请确保:
670
+ 1. 准确定位要修改的代码位置
671
+ 2. 正确处理代码缩进
672
+ 3. 考虑代码上下文
673
+ 4. 对新文件不要包含原始内容
674
+ """
675
+ patches = self._remake_patch(retry_prompt)
676
+ continue
677
+
678
+ except Exception as e:
679
+ return {
680
+ "success": False,
681
+ "stdout": "",
682
+ "stderr": f"执行失败: {str(e)}",
683
+ "error": e
684
+ }
685
+
686
+
687
+ def main():
688
+ """命令行入口"""
689
+ import argparse
690
+
691
+ load_env_from_file()
692
+
693
+ parser = argparse.ArgumentParser(description='代码修改工具')
694
+ parser.add_argument('-p', '--platform', help='AI平台名称', default=os.environ.get('JARVIS_CODEGEN_PLATFORM'))
695
+ parser.add_argument('-m', '--model', help='模型名称', default=os.environ.get('JARVIS_CODEGEN_MODEL'))
696
+ parser.add_argument('-d', '--dir', help='项目根目录', default=os.getcwd())
697
+ parser.add_argument('-l', '--language', help='编程语言', default="python")
698
+ args = parser.parse_args()
699
+
700
+ # 设置平台
701
+ if not args.platform:
702
+ print("错误: 未指定AI平台,请使用 -p 参数")
703
+ # 设置模型
704
+ if args.model:
705
+ os.environ['JARVIS_CODEGEN_MODEL'] = args.model
706
+
707
+ tool = JarvisCoder(args.dir, args.language)
708
+
709
+ # 循环处理需求
710
+ while True:
711
+ try:
712
+ # 获取需求
713
+ feature = get_multiline_input("请输入开发需求 (输入空行退出):")
714
+
715
+ if not feature or feature == "__interrupt__":
716
+ break
717
+
718
+ # 执行修改
719
+ result = tool.execute(feature)
720
+
721
+ # 显示结果
722
+ if result["success"]:
723
+ PrettyOutput.print(result["stdout"], OutputType.SUCCESS)
724
+ else:
725
+ if result["stderr"]:
726
+ PrettyOutput.print(result["stderr"], OutputType.ERROR)
727
+ if result["error"]:
728
+ PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.ERROR)
729
+
730
+ except KeyboardInterrupt:
731
+ print("\n用户中断执行")
732
+ break
733
+ except Exception as e:
734
+ PrettyOutput.print(f"执行出错: {str(e)}", OutputType.ERROR)
735
+ continue
736
+
737
+ return 0
738
+
739
+ if __name__ == "__main__":
740
+ exit(main())