jarvis-ai-assistant 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

@@ -1,7 +1,9 @@
1
+ from concurrent.futures import ThreadPoolExecutor, as_completed
1
2
  import hashlib
2
3
  import os
3
4
  import re
4
5
  import sqlite3
6
+ import threading
5
7
  import time
6
8
  from typing import Dict, Any, List, Optional, Tuple
7
9
 
@@ -9,16 +11,22 @@ import yaml
9
11
  from jarvis.models.base import BasePlatform
10
12
  from jarvis.utils import OutputType, PrettyOutput, get_multiline_input, load_env_from_file
11
13
  from jarvis.models.registry import PlatformRegistry
14
+ from prompt_toolkit import PromptSession
15
+ from prompt_toolkit.completion import WordCompleter, Completer, Completion
16
+ from prompt_toolkit.formatted_text import FormattedText
17
+ from prompt_toolkit.styles import Style
18
+ import fnmatch
19
+
20
+ # 全局锁对象
21
+ index_lock = threading.Lock()
12
22
 
13
23
  class JarvisCoder:
14
24
  def __init__(self, root_dir: str, language: str):
15
25
  """初始化代码修改工具"""
16
- self.main_model = None
17
- self.db_path = ""
26
+
18
27
  self.root_dir = root_dir
19
28
  self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM")
20
29
  self.model = os.environ.get("JARVIS_CODEGEN_MODEL")
21
- self.language = language
22
30
 
23
31
  self.root_dir = self._find_git_root_dir(self.root_dir)
24
32
  if not self.root_dir:
@@ -117,6 +125,7 @@ class JarvisCoder:
117
125
  Optional[Dict[str, Any]]: 文件信息,包含文件描述
118
126
  """
119
127
  model = self._new_model() # 创建新的模型实例
128
+ model.set_suppress_output(True)
120
129
 
121
130
  prompt = f"""你是一个资深程序员,请根据文件内容,生成文件的关键信息,要求如下,除了代码,不要输出任何内容:
122
131
 
@@ -125,20 +134,10 @@ class JarvisCoder:
125
134
  <CONTENT_START>
126
135
  {content}
127
136
  <CONTENT_END>
128
- 3. 关键信息: 请生成文件的功能描述,仅输出以下格式内容
129
- <FILE_INFO_START>
130
- file_description: 这个文件的主要功能和作用描述
131
- <FILE_INFO_END>
137
+ 3. 关键信息: 请生成这个文件的主要功能和作用描述,包含的特征符号(函数和类、变量等),不超过100字
132
138
  """
133
139
  try:
134
- response = model.chat(prompt)
135
- model.delete_chat() # 删除会话历史
136
- old_response = response
137
- response = response.replace("<FILE_INFO_START>", "").replace("<FILE_INFO_END>", "")
138
- if old_response != response:
139
- return yaml.safe_load(response)
140
- else:
141
- return None
140
+ return model.chat(prompt)
142
141
  except Exception as e:
143
142
  PrettyOutput.print(f"解析文件信息失败: {str(e)}", OutputType.ERROR)
144
143
  return None
@@ -155,49 +154,57 @@ file_description: 这个文件的主要功能和作用描述
155
154
  """获取文件MD5"""
156
155
  return hashlib.md5(open(file_path, "rb").read()).hexdigest()
157
156
 
157
+
158
158
  def _create_index_db(self):
159
159
  """创建索引数据库"""
160
- if not os.path.exists(self.index_db_path):
161
- PrettyOutput.print("Index database does not exist, creating...", OutputType.INFO)
162
- index_db = sqlite3.connect(self.index_db_path)
163
- index_db.execute(
164
- "CREATE TABLE files (file_path TEXT PRIMARY KEY, file_md5 TEXT, file_description TEXT)")
165
- index_db.commit()
166
- index_db.close()
167
- PrettyOutput.print("Index database created", OutputType.SUCCESS)
168
- # commit
169
- os.chdir(self.root_dir)
170
- os.system(f"git add .gitignore -f")
171
- os.system(f"git commit -m 'add index database'")
160
+ with index_lock:
161
+ if not os.path.exists(self.index_db_path):
162
+ PrettyOutput.print("Index database does not exist, creating...", OutputType.INFO)
163
+ index_db = sqlite3.connect(self.index_db_path)
164
+ index_db.execute(
165
+ "CREATE TABLE files (file_path TEXT PRIMARY KEY, file_md5 TEXT, file_description TEXT)")
166
+ index_db.commit()
167
+ index_db.close()
168
+ PrettyOutput.print("Index database created", OutputType.SUCCESS)
169
+ # commit
170
+ os.chdir(self.root_dir)
171
+ os.system(f"git add .gitignore -f")
172
+ os.system(f"git commit -m 'add index database'")
172
173
 
174
+
173
175
  def _find_file_by_md5(self, file_md5: str) -> Optional[str]:
174
176
  """根据文件MD5查找文件路径"""
175
- index_db = sqlite3.connect(self.index_db_path)
176
- cursor = index_db.cursor()
177
- cursor.execute(
178
- "SELECT file_path FROM files WHERE file_md5 = ?", (file_md5,))
179
- result = cursor.fetchone()
180
- index_db.close()
181
- return result[0] if result else None
177
+ with index_lock:
178
+ index_db = sqlite3.connect(self.index_db_path)
179
+ cursor = index_db.cursor()
180
+ cursor.execute(
181
+ "SELECT file_path FROM files WHERE file_md5 = ?", (file_md5,))
182
+ result = cursor.fetchone()
183
+ index_db.close()
184
+ return result[0] if result else None
182
185
 
186
+
183
187
  def _update_file_path(self, file_path: str, file_md5: str):
184
188
  """更新文件路径"""
185
- index_db = sqlite3.connect(self.index_db_path)
186
- cursor = index_db.cursor()
187
- cursor.execute(
188
- "UPDATE files SET file_path = ? WHERE file_md5 = ?", (file_path, file_md5))
189
- index_db.commit()
190
- index_db.close()
189
+ with index_lock:
190
+ index_db = sqlite3.connect(self.index_db_path)
191
+ cursor = index_db.cursor()
192
+ cursor.execute(
193
+ "UPDATE files SET file_path = ? WHERE file_md5 = ?", (file_path, file_md5))
194
+ index_db.commit()
195
+ index_db.close()
191
196
 
197
+
192
198
  def _insert_info(self, file_path: str, file_md5: str, file_description: str):
193
199
  """插入文件信息"""
194
- index_db = sqlite3.connect(self.index_db_path)
195
- cursor = index_db.cursor()
196
- cursor.execute("DELETE FROM files WHERE file_path = ?", (file_path,))
197
- cursor.execute("INSERT INTO files (file_path, file_md5, file_description) VALUES (?, ?, ?)",
198
- (file_path, file_md5, file_description))
199
- index_db.commit()
200
- index_db.close()
200
+ with index_lock:
201
+ index_db = sqlite3.connect(self.index_db_path)
202
+ cursor = index_db.cursor()
203
+ cursor.execute("DELETE FROM files WHERE file_path = ?", (file_path,))
204
+ cursor.execute("INSERT INTO files (file_path, file_md5, file_description) VALUES (?, ?, ?)",
205
+ (file_path, file_md5, file_description))
206
+ index_db.commit()
207
+ index_db.close()
201
208
 
202
209
  def _is_text_file(self, file_path: str) -> bool:
203
210
  """判断文件是否是文本文件"""
@@ -216,6 +223,9 @@ file_description: 这个文件的主要功能和作用描述
216
223
 
217
224
  def _index_project(self):
218
225
  """建立代码库索引"""
226
+ import threading
227
+ from concurrent.futures import ThreadPoolExecutor, as_completed
228
+
219
229
  git_files = os.popen("git ls-files").read().splitlines()
220
230
 
221
231
  index_db = sqlite3.connect(self.index_db_path)
@@ -229,59 +239,113 @@ file_description: 这个文件的主要功能和作用描述
229
239
  index_db.commit()
230
240
  index_db.close()
231
241
 
232
- # 3. 遍历git管理的文件
233
- for file_path in git_files:
234
- if self._is_text_file(file_path):
235
- # 计算文件MD5
236
- file_md5 = self._get_file_md5(file_path)
242
+ def process_file(file_path: str):
243
+ """处理单个文件的索引任务"""
244
+ if not self._is_text_file(file_path):
245
+ return
246
+
247
+ # 计算文件MD5
248
+ file_md5 = self._get_file_md5(file_path)
249
+
250
+ # 查找文件
251
+ file_path_in_db = self._find_file_by_md5(file_md5)
252
+ if file_path_in_db:
253
+ PrettyOutput.print(
254
+ f"文件 {file_path} 重复,跳过", OutputType.INFO)
255
+ if file_path_in_db != file_path:
256
+ self._update_file_path(file_path, file_md5)
257
+ PrettyOutput.print(
258
+ f"文件 {file_path} 重复,更新路径为 {file_path}", OutputType.INFO)
259
+ return
237
260
 
238
- # 查找文件
239
- file_path_in_db = self._find_file_by_md5(file_md5)
240
- if file_path_in_db:
261
+ with open(file_path, "r", encoding="utf-8") as f:
262
+ file_content = f.read()
263
+ key_info = self._get_key_info(file_path, file_content)
264
+ if not key_info:
241
265
  PrettyOutput.print(
242
- f"文件 {file_path} 重复,跳过", OutputType.INFO)
243
- if file_path_in_db != file_path:
244
- self._update_file_path(file_path, file_md5)
245
- PrettyOutput.print(
246
- f"文件 {file_path} 重复,更新路径为 {file_path}", OutputType.INFO)
247
- continue
266
+ f"文件 {file_path} 索引失败", OutputType.INFO)
267
+ return
268
+
269
+ self._insert_info(file_path, file_md5, key_info)
270
+ PrettyOutput.print(
271
+ f"文件 {file_path} 已建立索引", OutputType.INFO)
272
+
273
+
274
+ # 使用线程池处理文件索引
275
+ with ThreadPoolExecutor(max_workers=10) as executor:
276
+ futures = [executor.submit(process_file, file_path) for file_path in git_files]
277
+ for future in as_completed(futures):
278
+ try:
279
+ future.result()
280
+ except Exception as e:
281
+ PrettyOutput.print(f"处理文件时发生错误: {str(e)}", OutputType.ERROR)
248
282
 
249
- with open(file_path, "r", encoding="utf-8") as f:
250
- file_content = f.read()
251
- key_info = self._get_key_info(file_path, file_content)
252
- if not key_info:
253
- PrettyOutput.print(
254
- f"文件 {file_path} 索引失败", OutputType.INFO)
255
- continue
256
- if "file_description" in key_info:
257
- self._insert_info(file_path, file_md5, key_info["file_description"])
258
- PrettyOutput.print(
259
- f"文件 {file_path} 已建立索引", OutputType.INFO)
260
- else:
261
- PrettyOutput.print(
262
- f"文件 {file_path} 不是代码文件,跳过", OutputType.INFO)
263
283
  PrettyOutput.print("项目索引完成", OutputType.INFO)
264
284
 
265
- def _find_related_files(self, feature: str) -> List[Dict]:
266
- """根据需求描述,查找相关文件"""
285
+ def _get_files_from_db(self) -> List[Tuple[str, str]]:
286
+ """从数据库获取所有文件信息
287
+
288
+ Returns:
289
+ List[Tuple[str, str]]: [(file_path, file_description), ...]
290
+ """
267
291
  try:
268
- # Get all files from database
269
292
  index_db = sqlite3.connect(self.index_db_path)
270
293
  cursor = index_db.cursor()
271
294
  cursor.execute("SELECT file_path, file_description FROM files")
272
295
  all_files = cursor.fetchall()
273
296
  index_db.close()
297
+ return all_files
274
298
  except sqlite3.Error as e:
275
299
  PrettyOutput.print(f"数据库操作失败: {str(e)}", OutputType.ERROR)
276
300
  return []
277
301
 
278
- batch_size = 100
279
- batch_results = [] # Store results from each batch with their scores
302
+ def _analyze_files_in_batches(self, all_files: List[Tuple[str, str]], feature: str, batch_size: int = 100) -> List[Dict]:
303
+ """批量分析文件相关性
304
+
305
+ Args:
306
+ all_files: 所有文件列表
307
+ feature: 需求描述
308
+ batch_size: 批处理大小
309
+
310
+ Returns:
311
+ List[Dict]: 带评分的文件列表
312
+ """
313
+ batch_results = []
314
+
315
+ with ThreadPoolExecutor(max_workers=10) as executor:
316
+ futures = []
317
+ for i in range(0, len(all_files), batch_size):
318
+ batch_files = all_files[i:i + batch_size]
319
+ prompt = self._create_batch_analysis_prompt(batch_files, feature)
320
+ model = self._new_model()
321
+ model.set_suppress_output(True)
322
+ futures.append(executor.submit(self._call_model_with_retry, model, prompt))
323
+
324
+ for future in as_completed(futures):
325
+ success, response = future.result()
326
+ if not success:
327
+ continue
328
+
329
+ batch_start = futures.index(future) * batch_size
330
+ batch_end = min(batch_start + batch_size, len(all_files))
331
+ current_batch = all_files[batch_start:batch_end]
332
+
333
+ results = self._process_batch_response(response, current_batch)
334
+ batch_results.extend(results)
335
+
336
+ return batch_results
337
+
338
+ def _create_batch_analysis_prompt(self, batch_files: List[Tuple[str, str]], feature: str) -> str:
339
+ """创建批量分析的提示词
280
340
 
281
- for i in range(0, len(all_files), batch_size):
282
- batch_files = all_files[i:i + batch_size]
341
+ Args:
342
+ batch_files: 批次文件列表
343
+ feature: 需求描述
283
344
 
284
- prompt = """你是资深程序员,请根据需求描述,从以下文件路径中选出最相关的文件,按相关度从高到低排序。
345
+ Returns:
346
+ str: 提示词
347
+ """
348
+ prompt = """你是资深程序员,请根据需求描述,从以下文件路径中选出最相关的文件,按相关度从高到低排序。
285
349
 
286
350
  相关度打分标准(0-9分):
287
351
  - 9分:文件名直接包含需求中的关键词,且文件功能与需求完全匹配
@@ -299,102 +363,92 @@ file2.py: 7
299
363
 
300
364
  文件列表:
301
365
  """
302
- for file_path, _ in batch_files:
303
- prompt += f"- {file_path}\n"
304
- prompt += f"\n需求描述: {feature}\n"
305
- prompt += "\n注意:\n1. 只输出最相关的文件,不超过5个\n2. 根据上述打分标准判断相关性\n3. 相关度必须是0-9的整数"
306
-
307
- success, response = self._call_model_with_retry(self._new_model(), prompt)
308
- if not success:
309
- continue
310
-
311
- try:
312
- response = response.replace("<RELEVANT_FILES_START>", "").replace("<RELEVANT_FILES_END>", "")
313
- result = yaml.safe_load(response)
314
-
315
- # Convert results to file objects with scores
316
- batch_files_dict = {f[0]: f[1] for f in batch_files}
317
- for file_path, score in result.items():
318
- if isinstance(file_path, str) and isinstance(score, int):
319
- score = max(0, min(9, score)) # Ensure score is between 0-9
320
- if file_path in batch_files_dict:
321
- batch_results.append({
322
- "file_path": file_path,
323
- "file_description": batch_files_dict[file_path],
324
- "score": score
325
- })
326
-
327
- except Exception as e:
328
- PrettyOutput.print(f"处理批次文件失败: {str(e)}", OutputType.ERROR)
329
- continue
330
-
331
- # Sort all results by score
332
- batch_results.sort(key=lambda x: x["score"], reverse=True)
333
- top_files = batch_results[:5]
366
+ for file_path, _ in batch_files:
367
+ prompt += f"- {file_path}\n"
368
+ prompt += f"\n需求描述: {feature}\n"
369
+ prompt += "\n注意:\n1. 只输出最相关的文件,不超过5个\n2. 根据上述打分标准判断相关性\n3. 相关度必须是0-9的整数"
334
370
 
335
- # If we don't have enough files, add more from database
336
- if len(top_files) < 5:
337
- remaining_files = [f for f in all_files if f[0] not in [tf["file_path"] for tf in top_files]]
338
- top_files.extend([{
339
- "file_path": f[0],
340
- "file_description": f[1],
341
- "score": 0
342
- } for f in remaining_files[:5-len(top_files)]])
371
+ return prompt
343
372
 
344
- # Now do content relevance analysis on these files
345
- score = [[], [], [], [], [], [], [], [], [], []]
373
+ def _process_batch_response(self, response: str, batch_files: List[Tuple[str, str]]) -> List[Dict]:
374
+ """处理批量分析的响应
346
375
 
347
- prompt = """你是资深程序员,请根据需求描述,分析文件的相关性。
376
+ Args:
377
+ response: 模型响应
378
+ batch_files: 批次文件列表
379
+
380
+ Returns:
381
+ List[Dict]: 处理后的文件列表
382
+ """
383
+ try:
384
+ response = response.replace("<RELEVANT_FILES_START>", "").replace("<RELEVANT_FILES_END>", "")
385
+ result = yaml.safe_load(response)
386
+
387
+ batch_files_dict = {f[0]: f[1] for f in batch_files}
388
+ results = []
389
+ for file_path, score in result.items():
390
+ if isinstance(file_path, str) and isinstance(score, int):
391
+ score = max(0, min(9, score)) # Ensure score is between 0-9
392
+ if file_path in batch_files_dict:
393
+ results.append({
394
+ "file_path": file_path,
395
+ "file_description": batch_files_dict[file_path],
396
+ "score": score
397
+ })
398
+ return results
399
+ except Exception as e:
400
+ PrettyOutput.print(f"处理批次文件失败: {str(e)}", OutputType.ERROR)
401
+ return []
348
402
 
349
- 相关度打分标准(0-9分):
350
- - 9分:文件内容与需求完全匹配,是实现需求的核心文件
351
- - 7-8分:文件内容与需求高度相关,需要较大改动
352
- - 5-6分:文件内容与需求部分相关,需要中等改动
353
- - 3-4分:文件内容与需求相关性较低,但需要配合修改
354
- - 1-2分:文件内容与需求关系较远,只需极少改动
355
- - 0分:文件内容与需求完全无关
356
403
 
357
- 文件列表如下:
358
- <FILE_LIST_START>
359
- """
360
- for i, file in enumerate(top_files):
361
- prompt += f"""{i}. {file["file_path"]} : {file["file_description"]}\n"""
362
- prompt += f"""需求描述: {feature}\n"""
363
- prompt += "<FILE_LIST_END>\n"
364
- prompt += """请根据需求描述和文件描述,分析文件的相关性,输出每个编号的相关性[0~9],仅输出以下格式内容(key为文件编号,value为相关性):
365
- <FILE_RELATION_START>
366
- "0": 5
367
- "1": 3
368
- <FILE_RELATION_END>"""
369
-
370
- success, response = self._call_model_with_retry(self._new_model(), prompt)
371
- if not success:
372
- return top_files[:5] # Return top 5 files from filename matching if model fails
373
-
404
+ def _process_content_response(self, response: str, top_files: List[Dict]) -> List[Dict]:
405
+ """处理内容分析的响应"""
374
406
  try:
375
407
  response = response.replace("<FILE_RELATION_START>", "").replace("<FILE_RELATION_END>", "")
376
408
  file_relation = yaml.safe_load(response)
377
409
  if not file_relation:
378
410
  return top_files[:5]
379
411
 
412
+ score = [[] for _ in range(10)] # 创建10个空列表,对应0-9分
380
413
  for file_id, relation in file_relation.items():
381
414
  id = int(file_id)
382
415
  relation = max(0, min(9, relation)) # 确保范围在0-9之间
383
416
  score[relation].append(top_files[id])
384
417
 
418
+ files = []
419
+ for scores in reversed(score): # 从高分到低分遍历
420
+ files.extend(scores)
421
+ if len(files) >= 5: # 直接取相关性最高的5个文件
422
+ break
423
+
424
+ return files[:5]
385
425
  except Exception as e:
386
426
  PrettyOutput.print(f"处理文件关系失败: {str(e)}", OutputType.ERROR)
387
427
  return top_files[:5]
428
+
429
+ def _find_related_files(self, feature: str) -> List[Dict]:
430
+ """根据需求描述,查找相关文件
388
431
 
389
- files = []
390
- score.reverse()
391
- for i in score:
392
- files.extend(i)
393
- if len(files) >= 5: # 直接取相关性最高的5个文件
394
- break
432
+ Args:
433
+ feature: 需求描述
434
+
435
+ Returns:
436
+ List[Dict]: 相关文件列表
437
+ """
438
+ # 1. 从数据库获取所有文件
439
+ all_files = self._get_files_from_db()
440
+ if not all_files:
441
+ return []
395
442
 
396
- return files[:5]
397
-
443
+ # 2. 批量分析文件相关性
444
+ batch_results = self._analyze_files_in_batches(all_files, feature)
445
+
446
+ # 3. 排序并获取前5个文件
447
+ batch_results.sort(key=lambda x: x["score"], reverse=True)
448
+ return batch_results[:5]
449
+
450
+
451
+
398
452
  def _remake_patch(self, prompt: str) -> List[str]:
399
453
  success, response = self._call_model_with_retry(self.main_model, prompt, max_retries=5) # 增加重试次数
400
454
  if not success:
@@ -417,7 +471,7 @@ file2.py: 7
417
471
  <PATCH_START>
418
472
  >>>>>> path/to/file
419
473
  要替换的内容
420
- ======
474
+ =======
421
475
  新的内容
422
476
  <<<<<<
423
477
  <PATCH_END>
@@ -425,7 +479,7 @@ file2.py: 7
425
479
  2. 如果是新文件,格式如下:
426
480
  <PATCH_START>
427
481
  >>>>>> path/to/new/file
428
- ======
482
+ =======
429
483
  新文件的完整内容
430
484
  <<<<<<
431
485
  <PATCH_END>
@@ -444,6 +498,7 @@ file2.py: 7
444
498
  注意事项:
445
499
  1、仅输出补丁内容,不要输出任何其他内容,每个补丁必须用<PATCH_START>和<PATCH_END>标记
446
500
  2、如果在大段代码中有零星修改,生成多个补丁
501
+ 3、要替换的内容,一定要与文件内容完全一致,不要有任何多余或者缺失的内容
447
502
  """
448
503
 
449
504
  success, response = self._call_model_with_retry(self.main_model, prompt)
@@ -488,14 +543,14 @@ file2.py: 7
488
543
 
489
544
  # 解析补丁内容
490
545
  patch_content = "\n".join(lines[1:])
491
- parts = patch_content.split("======")
546
+ parts = patch_content.split("=======")
492
547
 
493
548
  if len(parts) != 2:
494
549
  error_info.append(f"补丁格式错误: {file_path}")
495
550
  return False, "\n".join(error_info)
496
551
 
497
- old_content = parts[0].strip()
498
- new_content = parts[1].split("<<<<<<")[0].strip()
552
+ old_content = parts[0]
553
+ new_content = parts[1].split("<<<<<<")[0]
499
554
 
500
555
  # 处理新文件
501
556
  if not old_content:
@@ -582,14 +637,85 @@ file2.py: 7
582
637
  return None
583
638
  return root_dir
584
639
 
640
+
641
+ def _prepare_execution(self) -> None:
642
+ """准备执行环境"""
643
+ self.main_model = self._new_model()
644
+ self._index_project()
645
+
646
+ def _load_related_files(self, feature: str) -> List[Dict]:
647
+ """加载相关文件内容"""
648
+ related_files = self._find_related_files(feature)
649
+ for file in related_files:
650
+ PrettyOutput.print(f"Related file: {file['file_path']}", OutputType.INFO)
651
+ with open(file["file_path"], "r", encoding="utf-8") as f:
652
+ file["file_content"] = f.read()
653
+ return related_files
654
+
655
+ def _handle_patch_application(self, related_files: List[Dict], patches: List[str], feature: str) -> Dict[str, Any]:
656
+ """处理补丁应用流程"""
657
+ while True:
658
+ PrettyOutput.print(f"生成{len(patches)}个补丁", OutputType.INFO)
659
+
660
+ if not patches:
661
+ retry_prompt = f"""未生成补丁,请重新生成补丁"""
662
+ patches = self._remake_patch(retry_prompt)
663
+ continue
664
+
665
+ success, error_info = self._apply_patch(related_files, patches)
666
+
667
+ if success:
668
+ user_confirm = input("是否确认修改?(y/n)")
669
+ if user_confirm.lower() == "y":
670
+ self._finalize_changes(feature, patches)
671
+ return {
672
+ "success": True,
673
+ "stdout": f"已完成功能开发{feature}",
674
+ "stderr": "",
675
+ "error": None
676
+ }
677
+ else:
678
+ self._revert_changes()
679
+ return {
680
+ "success": False,
681
+ "stdout": "",
682
+ "stderr": "修改被用户取消,文件未发生任何变化",
683
+ "error": UserWarning("用户取消修改")
684
+ }
685
+ else:
686
+ PrettyOutput.print(f"补丁应用失败,请求重新生成: {error_info}", OutputType.WARNING)
687
+ retry_prompt = f"""补丁应用失败,请根据以下错误信息重新生成补丁:
688
+
689
+ 错误信息:
690
+ {error_info}
691
+
692
+ 请确保:
693
+ 1. 准确定位要修改的代码位置
694
+ 2. 正确处理代码缩进
695
+ 3. 考虑代码上下文
696
+ 4. 对新文件不要包含原始内容
697
+ """
698
+ patches = self._remake_patch(retry_prompt)
699
+
700
+ def _finalize_changes(self, feature: str, patches: List[str]) -> None:
701
+ """完成修改并提交"""
702
+ PrettyOutput.print("修改确认成功,提交修改", OutputType.INFO)
703
+ os.system(f"git add .")
704
+ os.system(f"git commit -m '{feature}'")
705
+ self._save_edit_record(feature, patches)
706
+ self._index_project()
707
+
708
+ def _revert_changes(self) -> None:
709
+ """回退所有修改"""
710
+ PrettyOutput.print("修改已取消,回退更改", OutputType.INFO)
711
+ os.system(f"git reset --hard")
712
+ os.system(f"git clean -df")
713
+
585
714
  def execute(self, feature: str) -> Dict[str, Any]:
586
715
  """执行代码修改
587
716
 
588
717
  Args:
589
- args: 包含操作参数的字典
590
- - feature: 要实现的功能描述
591
- - root_dir: 代码库根目录
592
- - language: 编程语言
718
+ feature: 要实现的功能描述
593
719
 
594
720
  Returns:
595
721
  Dict[str, Any]: 包含执行结果的字典
@@ -599,81 +725,10 @@ file2.py: 7
599
725
  - error: 错误对象(如果有)
600
726
  """
601
727
  try:
602
- self.main_model = self._new_model() # 每次执行时重新创建模型
603
-
604
-
605
- # 4. 开始建立代码库索引
606
- self._index_project()
607
-
608
- # 5. 根据索引和需求,查找相关文件
609
- related_files = self._find_related_files(feature)
610
- for file in related_files:
611
- PrettyOutput.print(f"Related file: {file['file_path']}", OutputType.INFO)
612
- for file in related_files:
613
- with open(file["file_path"], "r", encoding="utf-8") as f:
614
- file_content = f.read()
615
- file["file_content"] = file_content
728
+ self._prepare_execution()
729
+ related_files = self._load_related_files(feature)
616
730
  patches = self._make_patch(related_files, feature)
617
- while True:
618
- # 生成修改方案
619
- PrettyOutput.print(f"生成{len(patches)}个补丁", OutputType.INFO)
620
-
621
- if not patches:
622
- retry_prompt = f"""未生成补丁,请重新生成补丁"""
623
- patches = self._remake_patch(retry_prompt)
624
- continue
625
-
626
- # 尝试应用补丁
627
- success, error_info = self._apply_patch(related_files, patches)
628
-
629
- if success:
630
- # 用户确认修改
631
- user_confirm = input("是否确认修改?(y/n)")
632
- if user_confirm.lower() == "y":
633
- PrettyOutput.print("修改确认成功,提交修改", OutputType.INFO)
634
-
635
- os.system(f"git add .")
636
- os.system(f"git commit -m '{feature}'")
637
-
638
- # 保存修改记录
639
- self._save_edit_record(feature, patches)
640
- # 重新建立代码库索引
641
- self._index_project()
642
-
643
- return {
644
- "success": True,
645
- "stdout": f"已完成功能开发{feature}",
646
- "stderr": "",
647
- "error": None
648
- }
649
- else:
650
- PrettyOutput.print("修改已取消,回退更改", OutputType.INFO)
651
-
652
- os.system(f"git reset --hard") # 回退已修改的文件
653
- os.system(f"git clean -df") # 删除新创建的文件和目录
654
-
655
- return {
656
- "success": False,
657
- "stdout": "",
658
- "stderr": "修改被用户取消,文件未发生任何变化",
659
- "error": UserWarning("用户取消修改")
660
- }
661
- else:
662
- # 补丁应用失败,让模型重新生成
663
- PrettyOutput.print(f"补丁应用失败,请求重新生成: {error_info}", OutputType.WARNING)
664
- retry_prompt = f"""补丁应用失败,请根据以下错误信息重新生成补丁:
665
-
666
- 错误信息:
667
- {error_info}
668
-
669
- 请确保:
670
- 1. 准确定位要修改的代码位置
671
- 2. 正确处理代码缩进
672
- 3. 考虑代码上下文
673
- 4. 对新文件不要包含原始内容
674
- """
675
- patches = self._remake_patch(retry_prompt)
676
- continue
731
+ return self._handle_patch_application(related_files, patches, feature)
677
732
 
678
733
  except Exception as e:
679
734
  return {
@@ -709,8 +764,8 @@ def main():
709
764
  # 循环处理需求
710
765
  while True:
711
766
  try:
712
- # 获取需求
713
- feature = get_multiline_input("请输入开发需求 (输入空行退出):")
767
+ # 获取需求,传入项目根目录
768
+ feature = get_multiline_input("请输入开发需求 (输入空行退出):", tool.root_dir)
714
769
 
715
770
  if not feature or feature == "__interrupt__":
716
771
  break
@@ -738,3 +793,122 @@ def main():
738
793
 
739
794
  if __name__ == "__main__":
740
795
  exit(main())
796
+
797
+ class FilePathCompleter(Completer):
798
+ """文件路径自动完成器"""
799
+
800
+ def __init__(self, root_dir: str):
801
+ self.root_dir = root_dir
802
+ self._file_list = None
803
+
804
+ def _get_files(self) -> List[str]:
805
+ """获取git管理的文件列表"""
806
+ if self._file_list is None:
807
+ try:
808
+ # 切换到项目根目录
809
+ old_cwd = os.getcwd()
810
+ os.chdir(self.root_dir)
811
+
812
+ # 获取git管理的文件列表
813
+ self._file_list = os.popen("git ls-files").read().splitlines()
814
+
815
+ # 恢复工作目录
816
+ os.chdir(old_cwd)
817
+ except Exception as e:
818
+ PrettyOutput.print(f"获取文件列表失败: {str(e)}", OutputType.WARNING)
819
+ self._file_list = []
820
+ return self._file_list
821
+
822
+ def get_completions(self, document, complete_event):
823
+ """获取补全建议"""
824
+ text_before_cursor = document.text_before_cursor
825
+
826
+ # 检查是否刚输入了@
827
+ if text_before_cursor.endswith('@'):
828
+ # 显示所有文件
829
+ for path in self._get_files():
830
+ yield Completion(path, start_position=0)
831
+ return
832
+
833
+ # 检查之前是否有@,并获取@后的搜索词
834
+ at_pos = text_before_cursor.rfind('@')
835
+ if at_pos == -1:
836
+ return
837
+
838
+ search = text_before_cursor[at_pos + 1:].lower().strip()
839
+
840
+ # 提供匹配的文件建议
841
+ for path in self._get_files():
842
+ path_lower = path.lower()
843
+ if (search in path_lower or # 直接包含
844
+ search in os.path.basename(path_lower) or # 文件名包含
845
+ any(fnmatch.fnmatch(path_lower, f'*{s}*') for s in search.split())): # 通配符匹配
846
+ # 计算正确的start_position
847
+ yield Completion(path, start_position=-(len(search)))
848
+
849
+ class SmartCompleter(Completer):
850
+ """智能自动完成器,组合词语和文件路径补全"""
851
+
852
+ def __init__(self, word_completer: WordCompleter, file_completer: FilePathCompleter):
853
+ self.word_completer = word_completer
854
+ self.file_completer = file_completer
855
+
856
+ def get_completions(self, document, complete_event):
857
+ """获取补全建议"""
858
+ # 如果当前行以@结尾,使用文件补全
859
+ if document.text_before_cursor.strip().endswith('@'):
860
+ yield from self.file_completer.get_completions(document, complete_event)
861
+ else:
862
+ # 否则使用词语补全
863
+ yield from self.word_completer.get_completions(document, complete_event)
864
+
865
+ def get_multiline_input(prompt_text: str, root_dir: str = None) -> str:
866
+ """获取多行输入,支持文件路径自动完成功能
867
+
868
+ Args:
869
+ prompt_text: 提示文本
870
+ root_dir: 项目根目录,用于文件补全
871
+
872
+ Returns:
873
+ str: 用户输入的文本
874
+ """
875
+ # 创建文件补全器
876
+ file_completer = FilePathCompleter(root_dir or os.getcwd())
877
+
878
+ # 创建提示样式
879
+ style = Style.from_dict({
880
+ 'prompt': 'ansicyan bold',
881
+ 'input': 'ansiwhite',
882
+ })
883
+
884
+ # 创建会话
885
+ session = PromptSession(
886
+ completer=file_completer,
887
+ style=style,
888
+ multiline=False,
889
+ enable_history_search=True,
890
+ complete_while_typing=True
891
+ )
892
+
893
+ # 显示初始提示文本
894
+ print(f"\n{prompt_text}")
895
+
896
+ # 创建提示符
897
+ prompt = FormattedText([
898
+ ('class:prompt', ">>> ")
899
+ ])
900
+
901
+ # 获取输入
902
+ lines = []
903
+ try:
904
+ while True:
905
+ line = session.prompt(prompt).strip()
906
+ if not line: # 空行表示输入结束
907
+ break
908
+ lines.append(line)
909
+ except KeyboardInterrupt:
910
+ return "__interrupt__"
911
+ except EOFError:
912
+ pass
913
+
914
+ return "\n".join(lines)