jarvis-ai-assistant 0.1.58__py3-none-any.whl → 0.1.74__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

@@ -1,16 +1,19 @@
1
- from concurrent.futures import ThreadPoolExecutor, as_completed
2
- import hashlib
3
1
  import os
4
2
  import re
5
- import sqlite3
6
3
  import threading
7
4
  import time
8
- from typing import Dict, Any, List, Optional, Tuple
5
+ from typing import Dict, Any, List, Tuple
9
6
 
10
7
  import yaml
11
8
  from jarvis.models.base import BasePlatform
12
- from jarvis.utils import OutputType, PrettyOutput, get_multiline_input, load_env_from_file
9
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_multiline_input, load_env_from_file
13
10
  from jarvis.models.registry import PlatformRegistry
11
+ from jarvis.jarvis_codebase.main import CodeBase
12
+ from prompt_toolkit import PromptSession
13
+ from prompt_toolkit.completion import WordCompleter, Completer, Completion
14
+ from prompt_toolkit.formatted_text import FormattedText
15
+ from prompt_toolkit.styles import Style
16
+ import fnmatch
14
17
 
15
18
  # 全局锁对象
16
19
  index_lock = threading.Lock()
@@ -18,12 +21,15 @@ index_lock = threading.Lock()
18
21
  class JarvisCoder:
19
22
  def __init__(self, root_dir: str, language: str):
20
23
  """初始化代码修改工具"""
24
+
25
+ self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM") or os.environ.get("JARVIS_PLATFORM")
26
+ self.model = os.environ.get("JARVIS_CODEGEN_MODEL") or os.environ.get("JARVIS_MODEL")
21
27
 
22
- self.root_dir = root_dir
23
- self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM")
24
- self.model = os.environ.get("JARVIS_CODEGEN_MODEL")
25
28
 
26
- self.root_dir = self._find_git_root_dir(self.root_dir)
29
+ if not self.platform or not self.model:
30
+ raise ValueError("JARVIS_CODEGEN_PLATFORM or JARVIS_CODEGEN_MODEL is not set")
31
+
32
+ self.root_dir = find_git_root(root_dir)
27
33
  if not self.root_dir:
28
34
  self.root_dir = root_dir
29
35
 
@@ -41,10 +47,6 @@ class JarvisCoder:
41
47
  if not os.path.exists(self.jarvis_dir):
42
48
  os.makedirs(self.jarvis_dir)
43
49
 
44
- self.index_db_path = os.path.join(self.jarvis_dir, "index.db")
45
- if not os.path.exists(self.index_db_path):
46
- self._create_index_db()
47
-
48
50
  self.record_dir = os.path.join(self.jarvis_dir, "record")
49
51
  if not os.path.exists(self.record_dir):
50
52
  os.makedirs(self.record_dir)
@@ -65,6 +67,9 @@ class JarvisCoder:
65
67
  os.system(f"git add .")
66
68
  os.system(f"git commit -m 'commit before code edit'")
67
69
 
70
+ # 4. 初始化代码库
71
+ self._codebase = CodeBase(self.root_dir)
72
+
68
73
  def _new_model(self):
69
74
  """获取大模型"""
70
75
  model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
@@ -109,341 +114,6 @@ class JarvisCoder:
109
114
  time.sleep(delay)
110
115
  delay *= 2 # 指数退避
111
116
 
112
- def _get_key_info(self, file_path: str, content: str) -> Optional[Dict[str, Any]]:
113
- """获取文件的关键信息
114
-
115
- Args:
116
- file_path: 文件路径
117
- content: 文件内容
118
-
119
- Returns:
120
- Optional[Dict[str, Any]]: 文件信息,包含文件描述
121
- """
122
- model = self._new_model() # 创建新的模型实例
123
- model.set_suppress_output(True)
124
-
125
- prompt = f"""你是一个资深程序员,请根据文件内容,生成文件的关键信息,要求如下,除了代码,不要输出任何内容:
126
-
127
- 1. 文件路径: {file_path}
128
- 2. 文件内容:(<CONTENT_START>和<CONTENT_END>之间的部分)
129
- <CONTENT_START>
130
- {content}
131
- <CONTENT_END>
132
- 3. 关键信息: 请生成这个文件的主要功能和作用描述,包含的特征符号(函数和类、变量等),不超过100字
133
- """
134
- try:
135
- return model.chat(prompt)
136
- except Exception as e:
137
- PrettyOutput.print(f"解析文件信息失败: {str(e)}", OutputType.ERROR)
138
- return None
139
- finally:
140
- # 确保清理模型资源
141
- try:
142
- model.delete_chat()
143
- except:
144
- pass
145
-
146
-
147
-
148
- def _get_file_md5(self, file_path: str) -> str:
149
- """获取文件MD5"""
150
- return hashlib.md5(open(file_path, "rb").read()).hexdigest()
151
-
152
-
153
- def _create_index_db(self):
154
- """创建索引数据库"""
155
- with index_lock:
156
- if not os.path.exists(self.index_db_path):
157
- PrettyOutput.print("Index database does not exist, creating...", OutputType.INFO)
158
- index_db = sqlite3.connect(self.index_db_path)
159
- index_db.execute(
160
- "CREATE TABLE files (file_path TEXT PRIMARY KEY, file_md5 TEXT, file_description TEXT)")
161
- index_db.commit()
162
- index_db.close()
163
- PrettyOutput.print("Index database created", OutputType.SUCCESS)
164
- # commit
165
- os.chdir(self.root_dir)
166
- os.system(f"git add .gitignore -f")
167
- os.system(f"git commit -m 'add index database'")
168
-
169
-
170
- def _find_file_by_md5(self, file_md5: str) -> Optional[str]:
171
- """根据文件MD5查找文件路径"""
172
- with index_lock:
173
- index_db = sqlite3.connect(self.index_db_path)
174
- cursor = index_db.cursor()
175
- cursor.execute(
176
- "SELECT file_path FROM files WHERE file_md5 = ?", (file_md5,))
177
- result = cursor.fetchone()
178
- index_db.close()
179
- return result[0] if result else None
180
-
181
-
182
- def _update_file_path(self, file_path: str, file_md5: str):
183
- """更新文件路径"""
184
- with index_lock:
185
- index_db = sqlite3.connect(self.index_db_path)
186
- cursor = index_db.cursor()
187
- cursor.execute(
188
- "UPDATE files SET file_path = ? WHERE file_md5 = ?", (file_path, file_md5))
189
- index_db.commit()
190
- index_db.close()
191
-
192
-
193
- def _insert_info(self, file_path: str, file_md5: str, file_description: str):
194
- """插入文件信息"""
195
- with index_lock:
196
- index_db = sqlite3.connect(self.index_db_path)
197
- cursor = index_db.cursor()
198
- cursor.execute("DELETE FROM files WHERE file_path = ?", (file_path,))
199
- cursor.execute("INSERT INTO files (file_path, file_md5, file_description) VALUES (?, ?, ?)",
200
- (file_path, file_md5, file_description))
201
- index_db.commit()
202
- index_db.close()
203
-
204
- def _is_text_file(self, file_path: str) -> bool:
205
- """判断文件是否是文本文件"""
206
- try:
207
- with open(file_path, 'rb') as f:
208
- # 读取文件前1024个字节
209
- chunk = f.read(1024)
210
- # 检查是否包含空字节
211
- if b'\x00' in chunk:
212
- return False
213
- # 尝试解码为文本
214
- chunk.decode('utf-8')
215
- return True
216
- except:
217
- return False
218
-
219
- def _index_project(self):
220
- """建立代码库索引"""
221
- import threading
222
- from concurrent.futures import ThreadPoolExecutor, as_completed
223
-
224
- git_files = os.popen("git ls-files").read().splitlines()
225
-
226
- index_db = sqlite3.connect(self.index_db_path)
227
- cursor = index_db.cursor()
228
- cursor.execute("SELECT file_path FROM files")
229
- db_files = [row[0] for row in cursor.fetchall()]
230
- for db_file in db_files:
231
- if not os.path.exists(db_file):
232
- cursor.execute("DELETE FROM files WHERE file_path = ?", (db_file,))
233
- PrettyOutput.print(f"删除不存在的文件记录: {db_file}", OutputType.INFO)
234
- index_db.commit()
235
- index_db.close()
236
-
237
- def process_file(file_path: str):
238
- """处理单个文件的索引任务"""
239
- if not self._is_text_file(file_path):
240
- return
241
-
242
- # 计算文件MD5
243
- file_md5 = self._get_file_md5(file_path)
244
-
245
- # 查找文件
246
- file_path_in_db = self._find_file_by_md5(file_md5)
247
- if file_path_in_db:
248
- PrettyOutput.print(
249
- f"文件 {file_path} 重复,跳过", OutputType.INFO)
250
- if file_path_in_db != file_path:
251
- self._update_file_path(file_path, file_md5)
252
- PrettyOutput.print(
253
- f"文件 {file_path} 重复,更新路径为 {file_path}", OutputType.INFO)
254
- return
255
-
256
- with open(file_path, "r", encoding="utf-8") as f:
257
- file_content = f.read()
258
- key_info = self._get_key_info(file_path, file_content)
259
- if not key_info:
260
- PrettyOutput.print(
261
- f"文件 {file_path} 索引失败", OutputType.INFO)
262
- return
263
-
264
- self._insert_info(file_path, file_md5, key_info)
265
- PrettyOutput.print(
266
- f"文件 {file_path} 已建立索引", OutputType.INFO)
267
-
268
-
269
- # 使用线程池处理文件索引
270
- with ThreadPoolExecutor(max_workers=10) as executor:
271
- futures = [executor.submit(process_file, file_path) for file_path in git_files]
272
- for future in as_completed(futures):
273
- try:
274
- future.result()
275
- except Exception as e:
276
- PrettyOutput.print(f"处理文件时发生错误: {str(e)}", OutputType.ERROR)
277
-
278
- PrettyOutput.print("项目索引完成", OutputType.INFO)
279
-
280
- def _get_files_from_db(self) -> List[Tuple[str, str]]:
281
- """从数据库获取所有文件信息
282
-
283
- Returns:
284
- List[Tuple[str, str]]: [(file_path, file_description), ...]
285
- """
286
- try:
287
- index_db = sqlite3.connect(self.index_db_path)
288
- cursor = index_db.cursor()
289
- cursor.execute("SELECT file_path, file_description FROM files")
290
- all_files = cursor.fetchall()
291
- index_db.close()
292
- return all_files
293
- except sqlite3.Error as e:
294
- PrettyOutput.print(f"数据库操作失败: {str(e)}", OutputType.ERROR)
295
- return []
296
-
297
- def _analyze_files_in_batches(self, all_files: List[Tuple[str, str]], feature: str, batch_size: int = 100) -> List[Dict]:
298
- """批量分析文件相关性
299
-
300
- Args:
301
- all_files: 所有文件列表
302
- feature: 需求描述
303
- batch_size: 批处理大小
304
-
305
- Returns:
306
- List[Dict]: 带评分的文件列表
307
- """
308
- batch_results = []
309
-
310
- with ThreadPoolExecutor(max_workers=10) as executor:
311
- futures = []
312
- for i in range(0, len(all_files), batch_size):
313
- batch_files = all_files[i:i + batch_size]
314
- prompt = self._create_batch_analysis_prompt(batch_files, feature)
315
- model = self._new_model()
316
- model.set_suppress_output(True)
317
- futures.append(executor.submit(self._call_model_with_retry, model, prompt))
318
-
319
- for future in as_completed(futures):
320
- success, response = future.result()
321
- if not success:
322
- continue
323
-
324
- batch_start = futures.index(future) * batch_size
325
- batch_end = min(batch_start + batch_size, len(all_files))
326
- current_batch = all_files[batch_start:batch_end]
327
-
328
- results = self._process_batch_response(response, current_batch)
329
- batch_results.extend(results)
330
-
331
- return batch_results
332
-
333
- def _create_batch_analysis_prompt(self, batch_files: List[Tuple[str, str]], feature: str) -> str:
334
- """创建批量分析的提示词
335
-
336
- Args:
337
- batch_files: 批次文件列表
338
- feature: 需求描述
339
-
340
- Returns:
341
- str: 提示词
342
- """
343
- prompt = """你是资深程序员,请根据需求描述,从以下文件路径中选出最相关的文件,按相关度从高到低排序。
344
-
345
- 相关度打分标准(0-9分):
346
- - 9分:文件名直接包含需求中的关键词,且文件功能与需求完全匹配
347
- - 7-8分:文件名包含需求相关词,或文件功能与需求高度相关
348
- - 5-6分:文件名暗示与需求有关,或文件功能与需求部分相关
349
- - 3-4分:文件可能需要小幅修改以配合需求
350
- - 1-2分:文件与需求关系较远,但可能需要少量改动
351
- - 0分:文件与需求完全无关
352
-
353
- 请输出yaml格式,仅输出以下格式内容:
354
- <RELEVANT_FILES_START>
355
- file1.py: 9
356
- file2.py: 7
357
- <RELEVANT_FILES_END>
358
-
359
- 文件列表:
360
- """
361
- for file_path, _ in batch_files:
362
- prompt += f"- {file_path}\n"
363
- prompt += f"\n需求描述: {feature}\n"
364
- prompt += "\n注意:\n1. 只输出最相关的文件,不超过5个\n2. 根据上述打分标准判断相关性\n3. 相关度必须是0-9的整数"
365
-
366
- return prompt
367
-
368
- def _process_batch_response(self, response: str, batch_files: List[Tuple[str, str]]) -> List[Dict]:
369
- """处理批量分析的响应
370
-
371
- Args:
372
- response: 模型响应
373
- batch_files: 批次文件列表
374
-
375
- Returns:
376
- List[Dict]: 处理后的文件列表
377
- """
378
- try:
379
- response = response.replace("<RELEVANT_FILES_START>", "").replace("<RELEVANT_FILES_END>", "")
380
- result = yaml.safe_load(response)
381
-
382
- batch_files_dict = {f[0]: f[1] for f in batch_files}
383
- results = []
384
- for file_path, score in result.items():
385
- if isinstance(file_path, str) and isinstance(score, int):
386
- score = max(0, min(9, score)) # Ensure score is between 0-9
387
- if file_path in batch_files_dict:
388
- results.append({
389
- "file_path": file_path,
390
- "file_description": batch_files_dict[file_path],
391
- "score": score
392
- })
393
- return results
394
- except Exception as e:
395
- PrettyOutput.print(f"处理批次文件失败: {str(e)}", OutputType.ERROR)
396
- return []
397
-
398
-
399
- def _process_content_response(self, response: str, top_files: List[Dict]) -> List[Dict]:
400
- """处理内容分析的响应"""
401
- try:
402
- response = response.replace("<FILE_RELATION_START>", "").replace("<FILE_RELATION_END>", "")
403
- file_relation = yaml.safe_load(response)
404
- if not file_relation:
405
- return top_files[:5]
406
-
407
- score = [[] for _ in range(10)] # 创建10个空列表,对应0-9分
408
- for file_id, relation in file_relation.items():
409
- id = int(file_id)
410
- relation = max(0, min(9, relation)) # 确保范围在0-9之间
411
- score[relation].append(top_files[id])
412
-
413
- files = []
414
- for scores in reversed(score): # 从高分到低分遍历
415
- files.extend(scores)
416
- if len(files) >= 5: # 直接取相关性最高的5个文件
417
- break
418
-
419
- return files[:5]
420
- except Exception as e:
421
- PrettyOutput.print(f"处理文件关系失败: {str(e)}", OutputType.ERROR)
422
- return top_files[:5]
423
-
424
- def _find_related_files(self, feature: str) -> List[Dict]:
425
- """根据需求描述,查找相关文件
426
-
427
- Args:
428
- feature: 需求描述
429
-
430
- Returns:
431
- List[Dict]: 相关文件列表
432
- """
433
- # 1. 从数据库获取所有文件
434
- all_files = self._get_files_from_db()
435
- if not all_files:
436
- return []
437
-
438
- # 2. 批量分析文件相关性
439
- batch_results = self._analyze_files_in_batches(all_files, feature)
440
-
441
- # 3. 排序并获取前5个文件
442
- batch_results.sort(key=lambda x: x["score"], reverse=True)
443
- return batch_results[:5]
444
-
445
-
446
-
447
117
  def _remake_patch(self, prompt: str) -> List[str]:
448
118
  success, response = self._call_model_with_retry(self.main_model, prompt, max_retries=5) # 增加重试次数
449
119
  if not success:
@@ -454,7 +124,7 @@ file2.py: 7
454
124
  return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
455
125
  for patch in patches if patch.strip()]
456
126
  except Exception as e:
457
- PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.ERROR)
127
+ PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.WARNING)
458
128
  return []
459
129
 
460
130
  def _make_patch(self, related_files: List[Dict], feature: str) -> List[str]:
@@ -468,21 +138,32 @@ file2.py: 7
468
138
  要替换的内容
469
139
  =======
470
140
  新的内容
471
- <<<<<<
141
+ >>>>>>
472
142
  <PATCH_END>
473
143
 
474
- 2. 如果是新文件,格式如下:
144
+ 2. 如果是新文件或者替换整个文件内容,格式如下:
475
145
  <PATCH_START>
476
146
  >>>>>> path/to/new/file
477
147
  =======
478
148
  新文件的完整内容
479
- <<<<<<
149
+ >>>>>>
150
+ <PATCH_END>
151
+
152
+ 3. 如果要删除文件中的某一段,格式如下:
153
+ <PATCH_START>
154
+ >>>>>> path/to/file
155
+ 要删除的内容
156
+ =======
157
+ >>>>>>
480
158
  <PATCH_END>
481
159
 
482
160
  文件列表如下:
483
161
  """
484
162
  for i, file in enumerate(related_files):
485
- prompt += f"""{i}. {file["file_path"]} : {file["file_description"]}\n"""
163
+ if len(prompt) > 30 * 1024:
164
+ PrettyOutput.print(f'避免上下文超限,丢弃低相关度文件:{file["file_path"]}', OutputType.WARNING)
165
+ continue
166
+ prompt += f"""{i}. {file["file_path"]}\n"""
486
167
  prompt += f"""文件内容:\n"""
487
168
  prompt += f"<FILE_CONTENT_START>\n"
488
169
  prompt += f'{file["file_content"]}\n'
@@ -493,6 +174,8 @@ file2.py: 7
493
174
  注意事项:
494
175
  1、仅输出补丁内容,不要输出任何其他内容,每个补丁必须用<PATCH_START>和<PATCH_END>标记
495
176
  2、如果在大段代码中有零星修改,生成多个补丁
177
+ 3、要替换的内容,一定要与文件内容完全一致,不要有任何多余或者缺失的内容
178
+ 4、每个patch不超过20行,超出20行,请生成多个patch
496
179
  """
497
180
 
498
181
  success, response = self._call_model_with_retry(self.main_model, prompt)
@@ -505,7 +188,7 @@ file2.py: 7
505
188
  return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
506
189
  for patch in patches if patch.strip()]
507
190
  except Exception as e:
508
- PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.ERROR)
191
+ PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.WARNING)
509
192
  return []
510
193
 
511
194
  def _apply_patch(self, related_files: List[Dict], patches: List[str]) -> Tuple[bool, str]:
@@ -544,7 +227,7 @@ file2.py: 7
544
227
  return False, "\n".join(error_info)
545
228
 
546
229
  old_content = parts[0]
547
- new_content = parts[1].split("<<<<<<")[0]
230
+ new_content = parts[1].split(">>>>>>")[0]
548
231
 
549
232
  # 处理新文件
550
233
  if not old_content:
@@ -595,12 +278,12 @@ file2.py: 7
595
278
 
596
279
  return True, ""
597
280
 
598
- def _save_edit_record(self, feature: str, patches: List[str]) -> None:
281
+ def _save_edit_record(self, commit_message: str, git_diff: str) -> None:
599
282
  """保存代码修改记录
600
283
 
601
284
  Args:
602
- feature: 需求描述
603
- patches: 补丁列表
285
+ commit_message: 提交信息
286
+ git_diff: git diff --cached的输出
604
287
  """
605
288
 
606
289
  # 获取下一个序号
@@ -613,8 +296,8 @@ file2.py: 7
613
296
  # 创建记录文件
614
297
  record = {
615
298
  "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
616
- "feature": feature,
617
- "patches": patches
299
+ "commit_message": commit_message,
300
+ "git_diff": git_diff
618
301
  }
619
302
 
620
303
  record_path = os.path.join(self.record_dir, f"{next_num:04d}.yaml")
@@ -623,28 +306,30 @@ file2.py: 7
623
306
 
624
307
  PrettyOutput.print(f"已保存修改记录: {record_path}", OutputType.SUCCESS)
625
308
 
626
- def _find_git_root_dir(self, root_dir: str) -> str:
627
- """查找git根目录"""
628
- while not os.path.exists(os.path.join(root_dir, ".git")):
629
- root_dir = os.path.dirname(root_dir)
630
- if root_dir == "/":
631
- return None
632
- return root_dir
309
+
633
310
 
634
311
 
635
312
  def _prepare_execution(self) -> None:
636
313
  """准备执行环境"""
637
314
  self.main_model = self._new_model()
638
- self._index_project()
315
+ self._codebase.generate_codebase()
316
+
639
317
 
640
318
  def _load_related_files(self, feature: str) -> List[Dict]:
641
319
  """加载相关文件内容"""
642
- related_files = self._find_related_files(feature)
643
- for file in related_files:
644
- PrettyOutput.print(f"Related file: {file['file_path']}", OutputType.INFO)
645
- with open(file["file_path"], "r", encoding="utf-8") as f:
646
- file["file_content"] = f.read()
647
- return related_files
320
+ ret = []
321
+ # 确保索引数据库已生成
322
+ if not self._codebase.is_index_generated():
323
+ PrettyOutput.print("检测到索引数据库未生成,正在生成...", OutputType.WARNING)
324
+ self._codebase.generate_codebase()
325
+
326
+ related_files = self._codebase.search_similar(feature)
327
+ for file, score, _ in related_files:
328
+ PrettyOutput.print(f"相关文件: {file} 相关度: {score:.3f}", OutputType.SUCCESS)
329
+ with open(file, "r", encoding="utf-8") as f:
330
+ content = f.read()
331
+ ret.append({"file_path": file, "file_content": content})
332
+ return ret
648
333
 
649
334
  def _handle_patch_application(self, related_files: List[Dict], patches: List[str], feature: str) -> Dict[str, Any]:
650
335
  """处理补丁应用流程"""
@@ -661,7 +346,7 @@ file2.py: 7
661
346
  if success:
662
347
  user_confirm = input("是否确认修改?(y/n)")
663
348
  if user_confirm.lower() == "y":
664
- self._finalize_changes(feature, patches)
349
+ self._finalize_changes(feature)
665
350
  return {
666
351
  "success": True,
667
352
  "stdout": f"已完成功能开发{feature}",
@@ -691,13 +376,74 @@ file2.py: 7
691
376
  """
692
377
  patches = self._remake_patch(retry_prompt)
693
378
 
694
- def _finalize_changes(self, feature: str, patches: List[str]) -> None:
379
+
380
+
381
+
382
+ def _generate_commit_message(self, git_diff: str, feature: str) -> str:
383
+ """根据git diff和功能描述生成commit信息
384
+
385
+ Args:
386
+ git_diff: git diff --cached的输出
387
+ feature: 用户的功能描述
388
+
389
+ Returns:
390
+ str: 生成的commit信息
391
+ """
392
+
393
+ # 生成提示词
394
+ prompt = f"""你是一个经验丰富的程序员,请根据以下代码变更和功能描述生成简洁明了的commit信息:
395
+
396
+ 功能描述:
397
+ {feature}
398
+
399
+ 代码变更:
400
+ """
401
+ # 添加git diff内容
402
+ prompt += f"Git Diff:\n{git_diff}\n\n"
403
+
404
+ prompt += """
405
+ 请遵循以下规则:
406
+ 1. 使用英文编写
407
+ 2. 采用常规的commit message格式:<type>(<scope>): <subject>
408
+ 3. 保持简洁,不超过50个字符
409
+ 4. 准确描述代码变更的主要内容
410
+ 5. 优先考虑功能描述和git diff中的变更内容
411
+ """
412
+
413
+ # 使用normal模型生成commit信息
414
+ model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
415
+ model.set_model_name(self.model)
416
+ model.set_suppress_output(True)
417
+ success, response = self._call_model_with_retry(model, prompt)
418
+ if not success:
419
+ return "Update code changes"
420
+
421
+ # 清理响应内容
422
+ return response.strip().split("\n")[0]
423
+
424
+ def _finalize_changes(self, feature: str) -> None:
695
425
  """完成修改并提交"""
696
426
  PrettyOutput.print("修改确认成功,提交修改", OutputType.INFO)
697
- os.system(f"git add .")
698
- os.system(f"git commit -m '{feature}'")
699
- self._save_edit_record(feature, patches)
700
- self._index_project()
427
+
428
+ # 只添加已经在 git 控制下的修改文件
429
+ os.system("git add -u")
430
+
431
+ # 然后获取 git diff
432
+ git_diff = os.popen("git diff --cached").read()
433
+
434
+ # 自动生成commit信息,传入feature
435
+ commit_message = self._generate_commit_message(git_diff, feature)
436
+
437
+ # 显示并确认commit信息
438
+ PrettyOutput.print(f"自动生成的commit信息: {commit_message}", OutputType.INFO)
439
+ user_confirm = input("是否使用该commit信息?(y/n) [y]: ") or "y"
440
+
441
+ if user_confirm.lower() != "y":
442
+ commit_message = input("请输入新的commit信息: ")
443
+
444
+ # 不需要再次 git add,因为已经添加过了
445
+ os.system(f"git commit -m '{commit_message}'")
446
+ self._save_edit_record(commit_message, git_diff)
701
447
 
702
448
  def _revert_changes(self) -> None:
703
449
  """回退所有修改"""
@@ -758,8 +504,8 @@ def main():
758
504
  # 循环处理需求
759
505
  while True:
760
506
  try:
761
- # 获取需求
762
- feature = get_multiline_input("请输入开发需求 (输入空行退出):")
507
+ # 获取需求,传入项目根目录
508
+ feature = get_multiline_input("请输入开发需求 (输入空行退出):", tool.root_dir)
763
509
 
764
510
  if not feature or feature == "__interrupt__":
765
511
  break
@@ -772,9 +518,9 @@ def main():
772
518
  PrettyOutput.print(result["stdout"], OutputType.SUCCESS)
773
519
  else:
774
520
  if result["stderr"]:
775
- PrettyOutput.print(result["stderr"], OutputType.ERROR)
521
+ PrettyOutput.print(result["stderr"], OutputType.WARNING)
776
522
  if result["error"]:
777
- PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.ERROR)
523
+ PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.WARNING)
778
524
 
779
525
  except KeyboardInterrupt:
780
526
  print("\n用户中断执行")
@@ -787,3 +533,122 @@ def main():
787
533
 
788
534
  if __name__ == "__main__":
789
535
  exit(main())
536
+
537
+ class FilePathCompleter(Completer):
538
+ """文件路径自动完成器"""
539
+
540
+ def __init__(self, root_dir: str):
541
+ self.root_dir = root_dir
542
+ self._file_list = None
543
+
544
+ def _get_files(self) -> List[str]:
545
+ """获取git管理的文件列表"""
546
+ if self._file_list is None:
547
+ try:
548
+ # 切换到项目根目录
549
+ old_cwd = os.getcwd()
550
+ os.chdir(self.root_dir)
551
+
552
+ # 获取git管理的文件列表
553
+ self._file_list = os.popen("git ls-files").read().splitlines()
554
+
555
+ # 恢复工作目录
556
+ os.chdir(old_cwd)
557
+ except Exception as e:
558
+ PrettyOutput.print(f"获取文件列表失败: {str(e)}", OutputType.WARNING)
559
+ self._file_list = []
560
+ return self._file_list
561
+
562
+ def get_completions(self, document, complete_event):
563
+ """获取补全建议"""
564
+ text_before_cursor = document.text_before_cursor
565
+
566
+ # 检查是否刚输入了@
567
+ if text_before_cursor.endswith('@'):
568
+ # 显示所有文件
569
+ for path in self._get_files():
570
+ yield Completion(path, start_position=0)
571
+ return
572
+
573
+ # 检查之前是否有@,并获取@后的搜索词
574
+ at_pos = text_before_cursor.rfind('@')
575
+ if at_pos == -1:
576
+ return
577
+
578
+ search = text_before_cursor[at_pos + 1:].lower().strip()
579
+
580
+ # 提供匹配的文件建议
581
+ for path in self._get_files():
582
+ path_lower = path.lower()
583
+ if (search in path_lower or # 直接包含
584
+ search in os.path.basename(path_lower) or # 文件名包含
585
+ any(fnmatch.fnmatch(path_lower, f'*{s}*') for s in search.split())): # 通配符匹配
586
+ # 计算正确的start_position
587
+ yield Completion(path, start_position=-(len(search)))
588
+
589
+ class SmartCompleter(Completer):
590
+ """智能自动完成器,组合词语和文件路径补全"""
591
+
592
+ def __init__(self, word_completer: WordCompleter, file_completer: FilePathCompleter):
593
+ self.word_completer = word_completer
594
+ self.file_completer = file_completer
595
+
596
+ def get_completions(self, document, complete_event):
597
+ """获取补全建议"""
598
+ # 如果当前行以@结尾,使用文件补全
599
+ if document.text_before_cursor.strip().endswith('@'):
600
+ yield from self.file_completer.get_completions(document, complete_event)
601
+ else:
602
+ # 否则使用词语补全
603
+ yield from self.word_completer.get_completions(document, complete_event)
604
+
605
+ def get_multiline_input(prompt_text: str, root_dir: str = None) -> str:
606
+ """获取多行输入,支持文件路径自动完成功能
607
+
608
+ Args:
609
+ prompt_text: 提示文本
610
+ root_dir: 项目根目录,用于文件补全
611
+
612
+ Returns:
613
+ str: 用户输入的文本
614
+ """
615
+ # 创建文件补全器
616
+ file_completer = FilePathCompleter(root_dir or os.getcwd())
617
+
618
+ # 创建提示样式
619
+ style = Style.from_dict({
620
+ 'prompt': 'ansicyan bold',
621
+ 'input': 'ansiwhite',
622
+ })
623
+
624
+ # 创建会话
625
+ session = PromptSession(
626
+ completer=file_completer,
627
+ style=style,
628
+ multiline=False,
629
+ enable_history_search=True,
630
+ complete_while_typing=True
631
+ )
632
+
633
+ # 显示初始提示文本
634
+ print(f"\n{prompt_text}")
635
+
636
+ # 创建提示符
637
+ prompt = FormattedText([
638
+ ('class:prompt', ">>> ")
639
+ ])
640
+
641
+ # 获取输入
642
+ lines = []
643
+ try:
644
+ while True:
645
+ line = session.prompt(prompt).strip()
646
+ if not line: # 空行表示输入结束
647
+ break
648
+ lines.append(line)
649
+ except KeyboardInterrupt:
650
+ return "__interrupt__"
651
+ except EOFError:
652
+ pass
653
+
654
+ return "\n".join(lines)