jarvis-ai-assistant 0.1.63__py3-none-any.whl → 0.1.74__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

@@ -1,16 +1,14 @@
1
- from concurrent.futures import ThreadPoolExecutor, as_completed
2
- import hashlib
3
1
  import os
4
2
  import re
5
- import sqlite3
6
3
  import threading
7
4
  import time
8
- from typing import Dict, Any, List, Optional, Tuple
5
+ from typing import Dict, Any, List, Tuple
9
6
 
10
7
  import yaml
11
8
  from jarvis.models.base import BasePlatform
12
- from jarvis.utils import OutputType, PrettyOutput, get_multiline_input, load_env_from_file
9
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_multiline_input, load_env_from_file
13
10
  from jarvis.models.registry import PlatformRegistry
11
+ from jarvis.jarvis_codebase.main import CodeBase
14
12
  from prompt_toolkit import PromptSession
15
13
  from prompt_toolkit.completion import WordCompleter, Completer, Completion
16
14
  from prompt_toolkit.formatted_text import FormattedText
@@ -23,12 +21,15 @@ index_lock = threading.Lock()
23
21
  class JarvisCoder:
24
22
  def __init__(self, root_dir: str, language: str):
25
23
  """初始化代码修改工具"""
24
+
25
+ self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM") or os.environ.get("JARVIS_PLATFORM")
26
+ self.model = os.environ.get("JARVIS_CODEGEN_MODEL") or os.environ.get("JARVIS_MODEL")
26
27
 
27
- self.root_dir = root_dir
28
- self.platform = os.environ.get("JARVIS_CODEGEN_PLATFORM")
29
- self.model = os.environ.get("JARVIS_CODEGEN_MODEL")
30
28
 
31
- self.root_dir = self._find_git_root_dir(self.root_dir)
29
+ if not self.platform or not self.model:
30
+ raise ValueError("JARVIS_CODEGEN_PLATFORM or JARVIS_CODEGEN_MODEL is not set")
31
+
32
+ self.root_dir = find_git_root(root_dir)
32
33
  if not self.root_dir:
33
34
  self.root_dir = root_dir
34
35
 
@@ -46,10 +47,6 @@ class JarvisCoder:
46
47
  if not os.path.exists(self.jarvis_dir):
47
48
  os.makedirs(self.jarvis_dir)
48
49
 
49
- self.index_db_path = os.path.join(self.jarvis_dir, "index.db")
50
- if not os.path.exists(self.index_db_path):
51
- self._create_index_db()
52
-
53
50
  self.record_dir = os.path.join(self.jarvis_dir, "record")
54
51
  if not os.path.exists(self.record_dir):
55
52
  os.makedirs(self.record_dir)
@@ -70,6 +67,9 @@ class JarvisCoder:
70
67
  os.system(f"git add .")
71
68
  os.system(f"git commit -m 'commit before code edit'")
72
69
 
70
+ # 4. 初始化代码库
71
+ self._codebase = CodeBase(self.root_dir)
72
+
73
73
  def _new_model(self):
74
74
  """获取大模型"""
75
75
  model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
@@ -114,341 +114,6 @@ class JarvisCoder:
114
114
  time.sleep(delay)
115
115
  delay *= 2 # 指数退避
116
116
 
117
- def _get_key_info(self, file_path: str, content: str) -> Optional[Dict[str, Any]]:
118
- """获取文件的关键信息
119
-
120
- Args:
121
- file_path: 文件路径
122
- content: 文件内容
123
-
124
- Returns:
125
- Optional[Dict[str, Any]]: 文件信息,包含文件描述
126
- """
127
- model = self._new_model() # 创建新的模型实例
128
- model.set_suppress_output(True)
129
-
130
- prompt = f"""你是一个资深程序员,请根据文件内容,生成文件的关键信息,要求如下,除了代码,不要输出任何内容:
131
-
132
- 1. 文件路径: {file_path}
133
- 2. 文件内容:(<CONTENT_START>和<CONTENT_END>之间的部分)
134
- <CONTENT_START>
135
- {content}
136
- <CONTENT_END>
137
- 3. 关键信息: 请生成这个文件的主要功能和作用描述,包含的特征符号(函数和类、变量等),不超过100字
138
- """
139
- try:
140
- return model.chat(prompt)
141
- except Exception as e:
142
- PrettyOutput.print(f"解析文件信息失败: {str(e)}", OutputType.ERROR)
143
- return None
144
- finally:
145
- # 确保清理模型资源
146
- try:
147
- model.delete_chat()
148
- except:
149
- pass
150
-
151
-
152
-
153
- def _get_file_md5(self, file_path: str) -> str:
154
- """获取文件MD5"""
155
- return hashlib.md5(open(file_path, "rb").read()).hexdigest()
156
-
157
-
158
- def _create_index_db(self):
159
- """创建索引数据库"""
160
- with index_lock:
161
- if not os.path.exists(self.index_db_path):
162
- PrettyOutput.print("Index database does not exist, creating...", OutputType.INFO)
163
- index_db = sqlite3.connect(self.index_db_path)
164
- index_db.execute(
165
- "CREATE TABLE files (file_path TEXT PRIMARY KEY, file_md5 TEXT, file_description TEXT)")
166
- index_db.commit()
167
- index_db.close()
168
- PrettyOutput.print("Index database created", OutputType.SUCCESS)
169
- # commit
170
- os.chdir(self.root_dir)
171
- os.system(f"git add .gitignore -f")
172
- os.system(f"git commit -m 'add index database'")
173
-
174
-
175
- def _find_file_by_md5(self, file_md5: str) -> Optional[str]:
176
- """根据文件MD5查找文件路径"""
177
- with index_lock:
178
- index_db = sqlite3.connect(self.index_db_path)
179
- cursor = index_db.cursor()
180
- cursor.execute(
181
- "SELECT file_path FROM files WHERE file_md5 = ?", (file_md5,))
182
- result = cursor.fetchone()
183
- index_db.close()
184
- return result[0] if result else None
185
-
186
-
187
- def _update_file_path(self, file_path: str, file_md5: str):
188
- """更新文件路径"""
189
- with index_lock:
190
- index_db = sqlite3.connect(self.index_db_path)
191
- cursor = index_db.cursor()
192
- cursor.execute(
193
- "UPDATE files SET file_path = ? WHERE file_md5 = ?", (file_path, file_md5))
194
- index_db.commit()
195
- index_db.close()
196
-
197
-
198
- def _insert_info(self, file_path: str, file_md5: str, file_description: str):
199
- """插入文件信息"""
200
- with index_lock:
201
- index_db = sqlite3.connect(self.index_db_path)
202
- cursor = index_db.cursor()
203
- cursor.execute("DELETE FROM files WHERE file_path = ?", (file_path,))
204
- cursor.execute("INSERT INTO files (file_path, file_md5, file_description) VALUES (?, ?, ?)",
205
- (file_path, file_md5, file_description))
206
- index_db.commit()
207
- index_db.close()
208
-
209
- def _is_text_file(self, file_path: str) -> bool:
210
- """判断文件是否是文本文件"""
211
- try:
212
- with open(file_path, 'rb') as f:
213
- # 读取文件前1024个字节
214
- chunk = f.read(1024)
215
- # 检查是否包含空字节
216
- if b'\x00' in chunk:
217
- return False
218
- # 尝试解码为文本
219
- chunk.decode('utf-8')
220
- return True
221
- except:
222
- return False
223
-
224
- def _index_project(self):
225
- """建立代码库索引"""
226
- import threading
227
- from concurrent.futures import ThreadPoolExecutor, as_completed
228
-
229
- git_files = os.popen("git ls-files").read().splitlines()
230
-
231
- index_db = sqlite3.connect(self.index_db_path)
232
- cursor = index_db.cursor()
233
- cursor.execute("SELECT file_path FROM files")
234
- db_files = [row[0] for row in cursor.fetchall()]
235
- for db_file in db_files:
236
- if not os.path.exists(db_file):
237
- cursor.execute("DELETE FROM files WHERE file_path = ?", (db_file,))
238
- PrettyOutput.print(f"删除不存在的文件记录: {db_file}", OutputType.INFO)
239
- index_db.commit()
240
- index_db.close()
241
-
242
- def process_file(file_path: str):
243
- """处理单个文件的索引任务"""
244
- if not self._is_text_file(file_path):
245
- return
246
-
247
- # 计算文件MD5
248
- file_md5 = self._get_file_md5(file_path)
249
-
250
- # 查找文件
251
- file_path_in_db = self._find_file_by_md5(file_md5)
252
- if file_path_in_db:
253
- PrettyOutput.print(
254
- f"文件 {file_path} 重复,跳过", OutputType.INFO)
255
- if file_path_in_db != file_path:
256
- self._update_file_path(file_path, file_md5)
257
- PrettyOutput.print(
258
- f"文件 {file_path} 重复,更新路径为 {file_path}", OutputType.INFO)
259
- return
260
-
261
- with open(file_path, "r", encoding="utf-8") as f:
262
- file_content = f.read()
263
- key_info = self._get_key_info(file_path, file_content)
264
- if not key_info:
265
- PrettyOutput.print(
266
- f"文件 {file_path} 索引失败", OutputType.INFO)
267
- return
268
-
269
- self._insert_info(file_path, file_md5, key_info)
270
- PrettyOutput.print(
271
- f"文件 {file_path} 已建立索引", OutputType.INFO)
272
-
273
-
274
- # 使用线程池处理文件索引
275
- with ThreadPoolExecutor(max_workers=10) as executor:
276
- futures = [executor.submit(process_file, file_path) for file_path in git_files]
277
- for future in as_completed(futures):
278
- try:
279
- future.result()
280
- except Exception as e:
281
- PrettyOutput.print(f"处理文件时发生错误: {str(e)}", OutputType.ERROR)
282
-
283
- PrettyOutput.print("项目索引完成", OutputType.INFO)
284
-
285
- def _get_files_from_db(self) -> List[Tuple[str, str]]:
286
- """从数据库获取所有文件信息
287
-
288
- Returns:
289
- List[Tuple[str, str]]: [(file_path, file_description), ...]
290
- """
291
- try:
292
- index_db = sqlite3.connect(self.index_db_path)
293
- cursor = index_db.cursor()
294
- cursor.execute("SELECT file_path, file_description FROM files")
295
- all_files = cursor.fetchall()
296
- index_db.close()
297
- return all_files
298
- except sqlite3.Error as e:
299
- PrettyOutput.print(f"数据库操作失败: {str(e)}", OutputType.ERROR)
300
- return []
301
-
302
- def _analyze_files_in_batches(self, all_files: List[Tuple[str, str]], feature: str, batch_size: int = 100) -> List[Dict]:
303
- """批量分析文件相关性
304
-
305
- Args:
306
- all_files: 所有文件列表
307
- feature: 需求描述
308
- batch_size: 批处理大小
309
-
310
- Returns:
311
- List[Dict]: 带评分的文件列表
312
- """
313
- batch_results = []
314
-
315
- with ThreadPoolExecutor(max_workers=10) as executor:
316
- futures = []
317
- for i in range(0, len(all_files), batch_size):
318
- batch_files = all_files[i:i + batch_size]
319
- prompt = self._create_batch_analysis_prompt(batch_files, feature)
320
- model = self._new_model()
321
- model.set_suppress_output(True)
322
- futures.append(executor.submit(self._call_model_with_retry, model, prompt))
323
-
324
- for future in as_completed(futures):
325
- success, response = future.result()
326
- if not success:
327
- continue
328
-
329
- batch_start = futures.index(future) * batch_size
330
- batch_end = min(batch_start + batch_size, len(all_files))
331
- current_batch = all_files[batch_start:batch_end]
332
-
333
- results = self._process_batch_response(response, current_batch)
334
- batch_results.extend(results)
335
-
336
- return batch_results
337
-
338
- def _create_batch_analysis_prompt(self, batch_files: List[Tuple[str, str]], feature: str) -> str:
339
- """创建批量分析的提示词
340
-
341
- Args:
342
- batch_files: 批次文件列表
343
- feature: 需求描述
344
-
345
- Returns:
346
- str: 提示词
347
- """
348
- prompt = """你是资深程序员,请根据需求描述,从以下文件路径中选出最相关的文件,按相关度从高到低排序。
349
-
350
- 相关度打分标准(0-9分):
351
- - 9分:文件名直接包含需求中的关键词,且文件功能与需求完全匹配
352
- - 7-8分:文件名包含需求相关词,或文件功能与需求高度相关
353
- - 5-6分:文件名暗示与需求有关,或文件功能与需求部分相关
354
- - 3-4分:文件可能需要小幅修改以配合需求
355
- - 1-2分:文件与需求关系较远,但可能需要少量改动
356
- - 0分:文件与需求完全无关
357
-
358
- 请输出yaml格式,仅输出以下格式内容:
359
- <RELEVANT_FILES_START>
360
- file1.py: 9
361
- file2.py: 7
362
- <RELEVANT_FILES_END>
363
-
364
- 文件列表:
365
- """
366
- for file_path, _ in batch_files:
367
- prompt += f"- {file_path}\n"
368
- prompt += f"\n需求描述: {feature}\n"
369
- prompt += "\n注意:\n1. 只输出最相关的文件,不超过5个\n2. 根据上述打分标准判断相关性\n3. 相关度必须是0-9的整数"
370
-
371
- return prompt
372
-
373
- def _process_batch_response(self, response: str, batch_files: List[Tuple[str, str]]) -> List[Dict]:
374
- """处理批量分析的响应
375
-
376
- Args:
377
- response: 模型响应
378
- batch_files: 批次文件列表
379
-
380
- Returns:
381
- List[Dict]: 处理后的文件列表
382
- """
383
- try:
384
- response = response.replace("<RELEVANT_FILES_START>", "").replace("<RELEVANT_FILES_END>", "")
385
- result = yaml.safe_load(response)
386
-
387
- batch_files_dict = {f[0]: f[1] for f in batch_files}
388
- results = []
389
- for file_path, score in result.items():
390
- if isinstance(file_path, str) and isinstance(score, int):
391
- score = max(0, min(9, score)) # Ensure score is between 0-9
392
- if file_path in batch_files_dict:
393
- results.append({
394
- "file_path": file_path,
395
- "file_description": batch_files_dict[file_path],
396
- "score": score
397
- })
398
- return results
399
- except Exception as e:
400
- PrettyOutput.print(f"处理批次文件失败: {str(e)}", OutputType.ERROR)
401
- return []
402
-
403
-
404
- def _process_content_response(self, response: str, top_files: List[Dict]) -> List[Dict]:
405
- """处理内容分析的响应"""
406
- try:
407
- response = response.replace("<FILE_RELATION_START>", "").replace("<FILE_RELATION_END>", "")
408
- file_relation = yaml.safe_load(response)
409
- if not file_relation:
410
- return top_files[:5]
411
-
412
- score = [[] for _ in range(10)] # 创建10个空列表,对应0-9分
413
- for file_id, relation in file_relation.items():
414
- id = int(file_id)
415
- relation = max(0, min(9, relation)) # 确保范围在0-9之间
416
- score[relation].append(top_files[id])
417
-
418
- files = []
419
- for scores in reversed(score): # 从高分到低分遍历
420
- files.extend(scores)
421
- if len(files) >= 5: # 直接取相关性最高的5个文件
422
- break
423
-
424
- return files[:5]
425
- except Exception as e:
426
- PrettyOutput.print(f"处理文件关系失败: {str(e)}", OutputType.ERROR)
427
- return top_files[:5]
428
-
429
- def _find_related_files(self, feature: str) -> List[Dict]:
430
- """根据需求描述,查找相关文件
431
-
432
- Args:
433
- feature: 需求描述
434
-
435
- Returns:
436
- List[Dict]: 相关文件列表
437
- """
438
- # 1. 从数据库获取所有文件
439
- all_files = self._get_files_from_db()
440
- if not all_files:
441
- return []
442
-
443
- # 2. 批量分析文件相关性
444
- batch_results = self._analyze_files_in_batches(all_files, feature)
445
-
446
- # 3. 排序并获取前5个文件
447
- batch_results.sort(key=lambda x: x["score"], reverse=True)
448
- return batch_results[:5]
449
-
450
-
451
-
452
117
  def _remake_patch(self, prompt: str) -> List[str]:
453
118
  success, response = self._call_model_with_retry(self.main_model, prompt, max_retries=5) # 增加重试次数
454
119
  if not success:
@@ -459,7 +124,7 @@ file2.py: 7
459
124
  return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
460
125
  for patch in patches if patch.strip()]
461
126
  except Exception as e:
462
- PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.ERROR)
127
+ PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.WARNING)
463
128
  return []
464
129
 
465
130
  def _make_patch(self, related_files: List[Dict], feature: str) -> List[str]:
@@ -473,21 +138,32 @@ file2.py: 7
473
138
  要替换的内容
474
139
  =======
475
140
  新的内容
476
- <<<<<<
141
+ >>>>>>
477
142
  <PATCH_END>
478
143
 
479
- 2. 如果是新文件,格式如下:
144
+ 2. 如果是新文件或者替换整个文件内容,格式如下:
480
145
  <PATCH_START>
481
146
  >>>>>> path/to/new/file
482
147
  =======
483
148
  新文件的完整内容
484
- <<<<<<
149
+ >>>>>>
150
+ <PATCH_END>
151
+
152
+ 3. 如果要删除文件中的某一段,格式如下:
153
+ <PATCH_START>
154
+ >>>>>> path/to/file
155
+ 要删除的内容
156
+ =======
157
+ >>>>>>
485
158
  <PATCH_END>
486
159
 
487
160
  文件列表如下:
488
161
  """
489
162
  for i, file in enumerate(related_files):
490
- prompt += f"""{i}. {file["file_path"]} : {file["file_description"]}\n"""
163
+ if len(prompt) > 30 * 1024:
164
+ PrettyOutput.print(f'避免上下文超限,丢弃低相关度文件:{file["file_path"]}', OutputType.WARNING)
165
+ continue
166
+ prompt += f"""{i}. {file["file_path"]}\n"""
491
167
  prompt += f"""文件内容:\n"""
492
168
  prompt += f"<FILE_CONTENT_START>\n"
493
169
  prompt += f'{file["file_content"]}\n'
@@ -499,6 +175,7 @@ file2.py: 7
499
175
  1、仅输出补丁内容,不要输出任何其他内容,每个补丁必须用<PATCH_START>和<PATCH_END>标记
500
176
  2、如果在大段代码中有零星修改,生成多个补丁
501
177
  3、要替换的内容,一定要与文件内容完全一致,不要有任何多余或者缺失的内容
178
+ 4、每个patch不超过20行,超出20行,请生成多个patch
502
179
  """
503
180
 
504
181
  success, response = self._call_model_with_retry(self.main_model, prompt)
@@ -511,7 +188,7 @@ file2.py: 7
511
188
  return [patch.replace('<PATCH_START>', '').replace('<PATCH_END>', '').strip()
512
189
  for patch in patches if patch.strip()]
513
190
  except Exception as e:
514
- PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.ERROR)
191
+ PrettyOutput.print(f"解析patch失败: {str(e)}", OutputType.WARNING)
515
192
  return []
516
193
 
517
194
  def _apply_patch(self, related_files: List[Dict], patches: List[str]) -> Tuple[bool, str]:
@@ -550,7 +227,7 @@ file2.py: 7
550
227
  return False, "\n".join(error_info)
551
228
 
552
229
  old_content = parts[0]
553
- new_content = parts[1].split("<<<<<<")[0]
230
+ new_content = parts[1].split(">>>>>>")[0]
554
231
 
555
232
  # 处理新文件
556
233
  if not old_content:
@@ -601,12 +278,12 @@ file2.py: 7
601
278
 
602
279
  return True, ""
603
280
 
604
- def _save_edit_record(self, feature: str, patches: List[str]) -> None:
281
+ def _save_edit_record(self, commit_message: str, git_diff: str) -> None:
605
282
  """保存代码修改记录
606
283
 
607
284
  Args:
608
- feature: 需求描述
609
- patches: 补丁列表
285
+ commit_message: 提交信息
286
+ git_diff: git diff --cached的输出
610
287
  """
611
288
 
612
289
  # 获取下一个序号
@@ -619,8 +296,8 @@ file2.py: 7
619
296
  # 创建记录文件
620
297
  record = {
621
298
  "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
622
- "feature": feature,
623
- "patches": patches
299
+ "commit_message": commit_message,
300
+ "git_diff": git_diff
624
301
  }
625
302
 
626
303
  record_path = os.path.join(self.record_dir, f"{next_num:04d}.yaml")
@@ -629,28 +306,30 @@ file2.py: 7
629
306
 
630
307
  PrettyOutput.print(f"已保存修改记录: {record_path}", OutputType.SUCCESS)
631
308
 
632
- def _find_git_root_dir(self, root_dir: str) -> str:
633
- """查找git根目录"""
634
- while not os.path.exists(os.path.join(root_dir, ".git")):
635
- root_dir = os.path.dirname(root_dir)
636
- if root_dir == "/":
637
- return None
638
- return root_dir
309
+
639
310
 
640
311
 
641
312
  def _prepare_execution(self) -> None:
642
313
  """准备执行环境"""
643
314
  self.main_model = self._new_model()
644
- self._index_project()
315
+ self._codebase.generate_codebase()
316
+
645
317
 
646
318
  def _load_related_files(self, feature: str) -> List[Dict]:
647
319
  """加载相关文件内容"""
648
- related_files = self._find_related_files(feature)
649
- for file in related_files:
650
- PrettyOutput.print(f"Related file: {file['file_path']}", OutputType.INFO)
651
- with open(file["file_path"], "r", encoding="utf-8") as f:
652
- file["file_content"] = f.read()
653
- return related_files
320
+ ret = []
321
+ # 确保索引数据库已生成
322
+ if not self._codebase.is_index_generated():
323
+ PrettyOutput.print("检测到索引数据库未生成,正在生成...", OutputType.WARNING)
324
+ self._codebase.generate_codebase()
325
+
326
+ related_files = self._codebase.search_similar(feature)
327
+ for file, score, _ in related_files:
328
+ PrettyOutput.print(f"相关文件: {file} 相关度: {score:.3f}", OutputType.SUCCESS)
329
+ with open(file, "r", encoding="utf-8") as f:
330
+ content = f.read()
331
+ ret.append({"file_path": file, "file_content": content})
332
+ return ret
654
333
 
655
334
  def _handle_patch_application(self, related_files: List[Dict], patches: List[str], feature: str) -> Dict[str, Any]:
656
335
  """处理补丁应用流程"""
@@ -667,7 +346,7 @@ file2.py: 7
667
346
  if success:
668
347
  user_confirm = input("是否确认修改?(y/n)")
669
348
  if user_confirm.lower() == "y":
670
- self._finalize_changes(feature, patches)
349
+ self._finalize_changes(feature)
671
350
  return {
672
351
  "success": True,
673
352
  "stdout": f"已完成功能开发{feature}",
@@ -697,13 +376,74 @@ file2.py: 7
697
376
  """
698
377
  patches = self._remake_patch(retry_prompt)
699
378
 
700
- def _finalize_changes(self, feature: str, patches: List[str]) -> None:
379
+
380
+
381
+
382
+ def _generate_commit_message(self, git_diff: str, feature: str) -> str:
383
+ """根据git diff和功能描述生成commit信息
384
+
385
+ Args:
386
+ git_diff: git diff --cached的输出
387
+ feature: 用户的功能描述
388
+
389
+ Returns:
390
+ str: 生成的commit信息
391
+ """
392
+
393
+ # 生成提示词
394
+ prompt = f"""你是一个经验丰富的程序员,请根据以下代码变更和功能描述生成简洁明了的commit信息:
395
+
396
+ 功能描述:
397
+ {feature}
398
+
399
+ 代码变更:
400
+ """
401
+ # 添加git diff内容
402
+ prompt += f"Git Diff:\n{git_diff}\n\n"
403
+
404
+ prompt += """
405
+ 请遵循以下规则:
406
+ 1. 使用英文编写
407
+ 2. 采用常规的commit message格式:<type>(<scope>): <subject>
408
+ 3. 保持简洁,不超过50个字符
409
+ 4. 准确描述代码变更的主要内容
410
+ 5. 优先考虑功能描述和git diff中的变更内容
411
+ """
412
+
413
+ # 使用normal模型生成commit信息
414
+ model = PlatformRegistry().get_global_platform_registry().create_platform(self.platform)
415
+ model.set_model_name(self.model)
416
+ model.set_suppress_output(True)
417
+ success, response = self._call_model_with_retry(model, prompt)
418
+ if not success:
419
+ return "Update code changes"
420
+
421
+ # 清理响应内容
422
+ return response.strip().split("\n")[0]
423
+
424
+ def _finalize_changes(self, feature: str) -> None:
701
425
  """完成修改并提交"""
702
426
  PrettyOutput.print("修改确认成功,提交修改", OutputType.INFO)
703
- os.system(f"git add .")
704
- os.system(f"git commit -m '{feature}'")
705
- self._save_edit_record(feature, patches)
706
- self._index_project()
427
+
428
+ # 只添加已经在 git 控制下的修改文件
429
+ os.system("git add -u")
430
+
431
+ # 然后获取 git diff
432
+ git_diff = os.popen("git diff --cached").read()
433
+
434
+ # 自动生成commit信息,传入feature
435
+ commit_message = self._generate_commit_message(git_diff, feature)
436
+
437
+ # 显示并确认commit信息
438
+ PrettyOutput.print(f"自动生成的commit信息: {commit_message}", OutputType.INFO)
439
+ user_confirm = input("是否使用该commit信息?(y/n) [y]: ") or "y"
440
+
441
+ if user_confirm.lower() != "y":
442
+ commit_message = input("请输入新的commit信息: ")
443
+
444
+ # 不需要再次 git add,因为已经添加过了
445
+ os.system(f"git commit -m '{commit_message}'")
446
+ self._save_edit_record(commit_message, git_diff)
707
447
 
708
448
  def _revert_changes(self) -> None:
709
449
  """回退所有修改"""
@@ -778,9 +518,9 @@ def main():
778
518
  PrettyOutput.print(result["stdout"], OutputType.SUCCESS)
779
519
  else:
780
520
  if result["stderr"]:
781
- PrettyOutput.print(result["stderr"], OutputType.ERROR)
521
+ PrettyOutput.print(result["stderr"], OutputType.WARNING)
782
522
  if result["error"]:
783
- PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.ERROR)
523
+ PrettyOutput.print(f"错误类型: {type(result['error']).__name__}", OutputType.WARNING)
784
524
 
785
525
  except KeyboardInterrupt:
786
526
  print("\n用户中断执行")
jarvis/main.py CHANGED
@@ -117,7 +117,6 @@ def main():
117
117
  PlatformRegistry.get_global_platform_registry().set_global_platform_name(platform)
118
118
 
119
119
  if args.model:
120
- PrettyOutput.print(f"用户传入了模型参数,更换模型: {args.model}", OutputType.USER)
121
120
  os.environ["JARVIS_MODEL"] = args.model
122
121
 
123
122
  try:
@@ -126,7 +125,6 @@ def main():
126
125
 
127
126
  # 如果用户传入了模型参数,则更换当前模型为用户指定的模型
128
127
  if args.model:
129
- PrettyOutput.print(f"用户传入了模型参数,更换模型: {args.model}", OutputType.USER)
130
128
  agent.model.set_model_name(args.model)
131
129
 
132
130
  # 欢迎信息
jarvis/models/ai8.py CHANGED
@@ -64,11 +64,10 @@ class AI8Model(BasePlatform):
64
64
 
65
65
  PrettyOutput.print("使用AI8_MODEL环境变量配置模型", OutputType.SUCCESS)
66
66
 
67
- self.model_name = os.getenv("AI8_MODEL") or os.getenv("JARVIS_MODEL") or "deepseek-chat"
67
+ self.model_name = os.getenv("JARVIS_MODEL") or "deepseek-chat"
68
68
  if self.model_name not in self.models:
69
69
  PrettyOutput.print(f"警告: 当前选择的模型 {self.model_name} 不在可用列表中", OutputType.WARNING)
70
70
 
71
- PrettyOutput.print(f"当前使用模型: {self.model_name}", OutputType.SYSTEM)
72
71
 
73
72
  def set_model_name(self, model_name: str):
74
73
  """设置模型名称"""
jarvis/models/openai.py CHANGED
@@ -33,7 +33,6 @@ class OpenAIModel(BasePlatform):
33
33
  self.base_url = os.getenv("OPENAI_API_BASE", "https://api.deepseek.com")
34
34
  self.model_name = os.getenv("OPENAI_MODEL_NAME") or os.getenv("JARVIS_MODEL") or "deepseek-chat"
35
35
 
36
- PrettyOutput.print(f"当前使用模型: {self.model_name}", OutputType.SYSTEM)
37
36
 
38
37
  self.client = OpenAI(
39
38
  api_key=self.api_key,