jarvis-ai-assistant 0.1.91__py3-none-any.whl → 0.1.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. See the package's registry page for more details.

jarvis/jarvis_rag/main.py CHANGED
@@ -1,13 +1,10 @@
1
1
  import os
2
- import hashlib
3
2
  import numpy as np
4
3
  import faiss
5
4
  from typing import List, Tuple, Optional, Dict
6
- from sentence_transformers import SentenceTransformer
7
5
  import pickle
8
- from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_max_context_length, load_embedding_model, load_rerank_model
6
+ from jarvis.utils import OutputType, PrettyOutput, get_file_md5, get_max_context_length, load_embedding_model, load_rerank_model
9
7
  from jarvis.utils import load_env_from_file
10
- import tiktoken
11
8
  from dataclasses import dataclass
12
9
  from tqdm import tqdm
13
10
  import fitz # PyMuPDF for PDF files
@@ -16,12 +13,16 @@ from pathlib import Path
16
13
  from jarvis.models.registry import PlatformRegistry
17
14
  import shutil
18
15
  from datetime import datetime
16
+ import lzma # 添加 lzma 导入
17
+ from concurrent.futures import ThreadPoolExecutor
18
+ from threading import Lock
19
19
 
20
20
  @dataclass
21
21
  class Document:
22
22
  """文档类,用于存储文档内容和元数据"""
23
23
  content: str # 文档内容
24
24
  metadata: Dict # 元数据(文件路径、位置等)
25
+ md5: str = "" # 文件MD5值,用于增量更新检测
25
26
 
26
27
  class FileProcessor:
27
28
  """文件处理器基类"""
@@ -163,7 +164,9 @@ class RAGTool:
163
164
  # 初始化缓存和索引
164
165
  self.cache_path = os.path.join(self.data_dir, "cache.pkl")
165
166
  self.documents: List[Document] = []
166
- self.index = None
167
+ self.index = None # 用于搜索的IVF索引
168
+ self.flat_index = None # 用于存储原始向量
169
+ self.file_md5_cache = {} # 用于存储文件的MD5值
167
170
 
168
171
  # 加载缓存
169
172
  self._load_cache()
@@ -175,17 +178,23 @@ class RAGTool:
175
178
  DocxProcessor()
176
179
  ]
177
180
 
181
+ # 添加线程相关配置
182
+ self.thread_count = int(os.environ.get("JARVIS_THREAD_COUNT", os.cpu_count() or 4))
183
+ self.vector_lock = Lock() # 用于保护向量列表的并发访问
184
+
178
185
  def _load_cache(self):
179
186
  """加载缓存数据"""
180
187
  if os.path.exists(self.cache_path):
181
188
  try:
182
- with open(self.cache_path, 'rb') as f:
189
+ with lzma.open(self.cache_path, 'rb') as f:
183
190
  cache_data = pickle.load(f)
184
191
  self.documents = cache_data["documents"]
185
192
  vectors = cache_data["vectors"]
193
+ self.file_md5_cache = cache_data.get("file_md5_cache", {}) # 加载MD5缓存
186
194
 
187
195
  # 重建索引
188
- self._build_index(vectors)
196
+ if vectors is not None:
197
+ self._build_index(vectors)
189
198
  PrettyOutput.print(f"加载了 {len(self.documents)} 个文档片段",
190
199
  output_type=OutputType.INFO)
191
200
  except Exception as e:
@@ -193,16 +202,18 @@ class RAGTool:
193
202
  output_type=OutputType.WARNING)
194
203
  self.documents = []
195
204
  self.index = None
205
+ self.flat_index = None
206
+ self.file_md5_cache = {}
196
207
 
197
208
  def _save_cache(self, vectors: np.ndarray):
198
209
  """优化缓存保存"""
199
210
  try:
200
- # 添加版本号和时间戳
201
211
  cache_data = {
202
212
  "version": "1.0",
203
213
  "timestamp": datetime.now().isoformat(),
204
214
  "documents": self.documents,
205
- "vectors": vectors,
215
+ "vectors": vectors.copy() if vectors is not None else None, # 创建数组的副本
216
+ "file_md5_cache": dict(self.file_md5_cache), # 创建字典的副本
206
217
  "metadata": {
207
218
  "vector_dim": self.vector_dim,
208
219
  "total_docs": len(self.documents),
@@ -210,9 +221,12 @@ class RAGTool:
210
221
  }
211
222
  }
212
223
 
213
- # 使用压缩存储
214
- with open(self.cache_path, 'wb') as f:
215
- pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
224
+ # 先将数据序列化为字节流
225
+ data = pickle.dumps(cache_data, protocol=pickle.HIGHEST_PROTOCOL)
226
+
227
+ # 然后使用 LZMA 压缩字节流
228
+ with lzma.open(self.cache_path, 'wb') as f:
229
+ f.write(data)
216
230
 
217
231
  # 创建备份
218
232
  backup_path = f"{self.cache_path}.backup"
@@ -223,22 +237,29 @@ class RAGTool:
223
237
  except Exception as e:
224
238
  PrettyOutput.print(f"保存缓存失败: {str(e)}",
225
239
  output_type=OutputType.ERROR)
240
+ raise
226
241
 
227
242
  def _build_index(self, vectors: np.ndarray):
228
243
  """构建FAISS索引"""
229
- # 添加IVF索引以提高大规模检索性能
244
+ if vectors.shape[0] == 0:
245
+ self.index = None
246
+ self.flat_index = None
247
+ return
248
+
249
+ # 创建扁平索引存储原始向量,用于重建
250
+ self.flat_index = faiss.IndexFlatIP(self.vector_dim)
251
+ self.flat_index.add(vectors)
252
+
253
+ # 创建IVF索引用于快速搜索
230
254
  nlist = max(4, int(vectors.shape[0] / 1000)) # 每1000个向量一个聚类中心
231
255
  quantizer = faiss.IndexFlatIP(self.vector_dim)
232
256
  self.index = faiss.IndexIVFFlat(quantizer, self.vector_dim, nlist, faiss.METRIC_INNER_PRODUCT)
233
257
 
234
- if vectors.shape[0] > 0:
235
- # 训练IVF索引
236
- self.index.train(vectors)
237
- self.index.add(vectors)
238
- # 设置搜索时探测的聚类数
239
- self.index.nprobe = min(nlist, 10)
240
- else:
241
- self.index = None
258
+ # 训练并添加向量
259
+ self.index.train(vectors)
260
+ self.index.add(vectors)
261
+ # 设置搜索时探测的聚类数
262
+ self.index.nprobe = min(nlist, 10)
242
263
 
243
264
  def _split_text(self, text: str) -> List[str]:
244
265
  """使用更智能的分块策略"""
@@ -302,16 +323,58 @@ class RAGTool:
302
323
  show_progress_bar=False)
303
324
  return np.array(embedding, dtype=np.float32)
304
325
 
305
- def _process_file(self, file_path: str) -> List[Document]:
306
- """处理单个文件
326
+ def _get_embedding_batch(self, texts: List[str]) -> np.ndarray:
327
+ """批量获取文本的向量表示
307
328
 
308
329
  Args:
309
- file_path: 文件路径
330
+ texts: 文本列表
310
331
 
311
332
  Returns:
312
- 文档对象列表
333
+ np.ndarray: 向量表示数组
313
334
  """
314
335
  try:
336
+ embeddings = self.embedding_model.encode(texts,
337
+ normalize_embeddings=True,
338
+ show_progress_bar=False,
339
+ batch_size=32) # 使用批处理提高效率
340
+ return np.array(embeddings, dtype=np.float32)
341
+ except Exception as e:
342
+ PrettyOutput.print(f"获取向量表示失败: {str(e)}",
343
+ output_type=OutputType.ERROR)
344
+ return np.zeros((len(texts), self.vector_dim), dtype=np.float32)
345
+
346
+ def _process_document_batch(self, documents: List[Document]) -> List[np.ndarray]:
347
+ """处理一批文档的向量化
348
+
349
+ Args:
350
+ documents: 文档列表
351
+
352
+ Returns:
353
+ List[np.ndarray]: 向量列表
354
+ """
355
+ texts = []
356
+ for doc in documents:
357
+ # 组合文档信息
358
+ combined_text = f"""
359
+ 文件: {doc.metadata['file_path']}
360
+ 内容: {doc.content}
361
+ """
362
+ texts.append(combined_text)
363
+
364
+ return self._get_embedding_batch(texts)
365
+
366
+ def _process_file(self, file_path: str) -> List[Document]:
367
+ """处理单个文件"""
368
+ try:
369
+ # 计算文件MD5
370
+ current_md5 = get_file_md5(file_path)
371
+ if not current_md5:
372
+ return []
373
+
374
+ # 检查文件是否需要重新处理
375
+ if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
376
+ return []
377
+
315
378
  # 查找合适的处理器
316
379
  processor = None
317
380
  for p in self.file_processors:
@@ -320,18 +383,14 @@ class RAGTool:
320
383
  break
321
384
 
322
385
  if not processor:
323
- PrettyOutput.print(f"跳过不支持的文件: {file_path}",
324
- output_type=OutputType.WARNING)
386
+ # 如果找不到合适的处理器,则返回一个空的文档
325
387
  return []
326
388
 
327
389
  # 提取文本内容
328
390
  content = processor.extract_text(file_path)
329
391
  if not content.strip():
330
- PrettyOutput.print(f"文件内容为空: {file_path}",
331
- output_type=OutputType.WARNING)
332
392
  return []
333
393
 
334
-
335
394
  # 分割文本
336
395
  chunks = self._split_text(content)
337
396
 
@@ -345,10 +404,13 @@ class RAGTool:
345
404
  "file_type": Path(file_path).suffix.lower(),
346
405
  "chunk_index": i,
347
406
  "total_chunks": len(chunks)
348
- }
407
+ },
408
+ md5=current_md5
349
409
  )
350
410
  documents.append(doc)
351
-
411
+
412
+ # 更新MD5缓存
413
+ self.file_md5_cache[file_path] = current_md5
352
414
  return documents
353
415
 
354
416
  except Exception as e:
@@ -361,43 +423,117 @@ class RAGTool:
361
423
  # 获取所有文件
362
424
  all_files = []
363
425
  for root, _, files in os.walk(dir):
364
- # 忽略特定目录
365
426
  if any(ignored in root for ignored in ['.git', '__pycache__', 'node_modules']) or \
366
427
  any(part.startswith('.jarvis-') for part in root.split(os.sep)):
367
428
  continue
368
429
  for file in files:
369
- # 忽略 .jarvis- 开头的文件
370
430
  if file.startswith('.jarvis-'):
371
431
  continue
372
432
 
373
433
  file_path = os.path.join(root, file)
374
- # 跳过大文件
375
434
  if os.path.getsize(file_path) > 100 * 1024 * 1024: # 100MB
376
435
  PrettyOutput.print(f"跳过大文件: {file_path}",
377
436
  output_type=OutputType.WARNING)
378
437
  continue
379
438
  all_files.append(file_path)
380
439
 
381
- # 处理所有文件
382
- self.documents = []
383
- for file_path in tqdm(all_files, desc="处理文件"):
384
- docs = self._process_file(file_path)
385
- self.documents.extend(docs)
440
+ # 清理已删除文件的缓存
441
+ deleted_files = set(self.file_md5_cache.keys()) - set(all_files)
442
+ for file_path in deleted_files:
443
+ del self.file_md5_cache[file_path]
444
+ # 移除相关的文档
445
+ self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
386
446
 
387
- # 获取所有文档的向量表示
388
- vectors = []
389
- for doc in tqdm(self.documents, desc="生成向量"):
390
- vector = self._get_embedding(doc.content)
391
- vectors.append(vector)
447
+ # 检查文件变化
448
+ files_to_process = []
449
+ unchanged_files = []
450
+
451
+ with tqdm(total=len(all_files), desc="检查文件状态") as pbar:
452
+ for file_path in all_files:
453
+ current_md5 = get_file_md5(file_path)
454
+ if current_md5: # 只处理能成功计算MD5的文件
455
+ if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
456
+ # 文件未变化,记录但不重新处理
457
+ unchanged_files.append(file_path)
458
+ else:
459
+ # 新文件或已修改的文件
460
+ files_to_process.append(file_path)
461
+ pbar.update(1)
462
+
463
+ # 保留未变化文件的文档
464
+ unchanged_documents = [doc for doc in self.documents
465
+ if doc.metadata['file_path'] in unchanged_files]
466
+
467
+ # 处理新文件和修改的文件
468
+ new_documents = []
469
+ if files_to_process:
470
+ with tqdm(total=len(files_to_process), desc="处理文件") as pbar:
471
+ for file_path in files_to_process:
472
+ try:
473
+ docs = self._process_file(file_path)
474
+ if len(docs) > 0:
475
+ new_documents.extend(docs)
476
+ except Exception as e:
477
+ PrettyOutput.print(f"处理文件失败 {file_path}: {str(e)}",
478
+ output_type=OutputType.ERROR)
479
+ pbar.update(1)
480
+
481
+ # 更新文档列表
482
+ self.documents = unchanged_documents + new_documents
483
+
484
+ if not self.documents:
485
+ PrettyOutput.print("没有需要处理的文档", output_type=OutputType.WARNING)
486
+ return
487
+
488
+ # 只对新文档进行向量化
489
+ if new_documents:
490
+ PrettyOutput.print(f"开始处理 {len(new_documents)} 个新文档",
491
+ output_type=OutputType.INFO)
492
+
493
+ # 使用线程池并发处理向量化
494
+ batch_size = 32
495
+ new_vectors = []
496
+
497
+ with tqdm(total=len(new_documents), desc="生成向量") as pbar:
498
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
499
+ for i in range(0, len(new_documents), batch_size):
500
+ batch = new_documents[i:i + batch_size]
501
+ future = executor.submit(self._process_document_batch, batch)
502
+ batch_vectors = future.result()
503
+
504
+ with self.vector_lock:
505
+ new_vectors.extend(batch_vectors)
506
+
507
+ pbar.update(len(batch))
508
+
509
+ # 合并新旧向量
510
+ if self.flat_index is not None:
511
+ # 获取未变化文档的向量
512
+ unchanged_vectors = []
513
+ for doc in unchanged_documents:
514
+ # 从现有索引中提取向量
515
+ doc_idx = next((i for i, d in enumerate(self.documents)
516
+ if d.metadata['file_path'] == doc.metadata['file_path']), None)
517
+ if doc_idx is not None:
518
+ # 从扁平索引中重建向量
519
+ vector = np.zeros((1, self.vector_dim), dtype=np.float32)
520
+ self.flat_index.reconstruct(doc_idx, vector.ravel())
521
+ unchanged_vectors.append(vector)
522
+
523
+ if unchanged_vectors:
524
+ unchanged_vectors = np.vstack(unchanged_vectors)
525
+ vectors = np.vstack([unchanged_vectors, np.vstack(new_vectors)])
526
+ else:
527
+ vectors = np.vstack(new_vectors)
528
+ else:
529
+ vectors = np.vstack(new_vectors)
392
530
 
393
- if vectors:
394
- vectors = np.vstack(vectors)
395
531
  # 构建索引
396
532
  self._build_index(vectors)
397
533
  # 保存缓存
398
534
  self._save_cache(vectors)
399
-
400
- PrettyOutput.print(f"成功索引了 {len(self.documents)} 个文档片段",
535
+
536
+ PrettyOutput.print(f"成功索引了 {len(self.documents)} 个文档片段 (新增/修改: {len(new_documents)}, 未变化: {len(unchanged_documents)})",
401
537
  output_type=OutputType.SUCCESS)
402
538
 
403
539
  def search(self, query: str, top_k: int = 30) -> List[Tuple[Document, float]]:
@@ -4,6 +4,8 @@ import os
4
4
  import sys
5
5
  import readline
6
6
  from typing import Optional
7
+ from yaspin import yaspin
8
+ from yaspin.spinners import Spinners
7
9
 
8
10
  from jarvis.models.registry import PlatformRegistry
9
11
  from jarvis.utils import PrettyOutput, OutputType, load_env_from_file
@@ -11,7 +13,7 @@ from jarvis.utils import PrettyOutput, OutputType, load_env_from_file
11
13
  def execute_command(command: str) -> None:
12
14
  """显示命令并允许用户编辑,回车执行,Ctrl+C取消"""
13
15
  try:
14
- print("生成的命令 (可以编辑,回车执行,Ctrl+C取消):")
16
+ print("\n生成的命令 (可以编辑,回车执行,Ctrl+C取消):")
15
17
  # 预填充输入行
16
18
  readline.set_startup_hook(lambda: readline.insert_text(command))
17
19
  try:
@@ -68,14 +70,19 @@ find . -name "*.py"
68
70
  prefix = f"当前路径: {current_path}\n"
69
71
  prefix += f"当前shell: {shell}\n"
70
72
 
71
- # 处理请求
72
- result = model.chat(prefix + request)
73
-
74
- # 提取命令 - 简化处理逻辑,因为现在应该只返回纯命令
75
- if result and isinstance(result, str):
76
- return result.strip()
77
-
78
- return None
73
+ # 使用yaspin显示Thinking状态
74
+ with yaspin(Spinners.dots, text="Thinking", color="yellow") as spinner:
75
+ # 处理请求
76
+ result = model.chat(prefix + request)
77
+
78
+ # 提取命令
79
+ if result and isinstance(result, str):
80
+ command = result.strip()
81
+ spinner.ok("✓")
82
+ return command
83
+
84
+ spinner.fail("✗")
85
+ return None
79
86
 
80
87
  except Exception as e:
81
88
  PrettyOutput.print(f"处理请求时发生错误: {str(e)}", OutputType.ERROR)
jarvis/main.py CHANGED
@@ -54,6 +54,15 @@ def load_tasks() -> dict:
54
54
  PrettyOutput.print("Warning: .jarvis file should contain a dictionary of task_name: task_description", OutputType.ERROR)
55
55
  except Exception as e:
56
56
  PrettyOutput.print(f"Error loading .jarvis file: {str(e)}", OutputType.ERROR)
57
+
58
+ # 读取方法论
59
+ method_path = os.path.expanduser("~/.jarvis_methodology")
60
+ if os.path.exists(method_path):
61
+ with open(method_path, "r", encoding="utf-8") as f:
62
+ methodology = yaml.safe_load(f)
63
+ if isinstance(methodology, dict):
64
+ for name, desc in methodology.items():
65
+ tasks[f"执行方法论:{str(name)}\n {str(desc)}" ] = str(desc)
57
66
 
58
67
  return tasks
59
68
 
jarvis/models/ai8.py CHANGED
@@ -20,7 +20,7 @@ class AI8Model(BasePlatform):
20
20
  """Initialize model"""
21
21
  super().__init__()
22
22
  self.system_message = ""
23
- self.conversation = None
23
+ self.conversation = {}
24
24
  self.files = []
25
25
  self.models = {} # 存储模型信息
26
26
 
@@ -112,6 +112,7 @@ class AI8Model(BasePlatform):
112
112
  "name": name,
113
113
  "data": f"data:image/png;base64,{base64_data}"
114
114
  })
115
+ return self.files
115
116
 
116
117
  def set_system_message(self, message: str):
117
118
  """Set system message"""
@@ -138,7 +139,7 @@ class AI8Model(BasePlatform):
138
139
 
139
140
  payload = {
140
141
  "text": message,
141
- "sessionId": self.conversation['id'],
142
+ "sessionId": self.conversation['id'] if self.conversation else None,
142
143
  "files": []
143
144
  }
144
145
 
@@ -307,6 +308,6 @@ class AI8Model(BasePlatform):
307
308
  return list(self.models.keys())
308
309
 
309
310
  except Exception as e:
310
- PrettyOutput.print(f"获取模型列表异常: {str(e)}", OutputType.ERROR)
311
+ PrettyOutput.print(f"获取模型列表异常: {str(e)}", OutputType.WARNING)
311
312
  return []
312
313
 
jarvis/models/ollama.py CHANGED
@@ -29,15 +29,15 @@ class OllamaPlatform(BasePlatform):
29
29
  PrettyOutput.print("1. 安装 Ollama: https://ollama.ai", OutputType.INFO)
30
30
  PrettyOutput.print("2. 下载模型:", OutputType.INFO)
31
31
  PrettyOutput.print(f" ollama pull {self.model_name}", OutputType.INFO)
32
- raise Exception("No available models found")
32
+ PrettyOutput.print("Ollama没有可用的模型", OutputType.WARNING)
33
33
 
34
34
  except requests.exceptions.ConnectionError:
35
- PrettyOutput.print("\nOllama 服务未启动或无法连接", OutputType.ERROR)
35
+ PrettyOutput.print("\nOllama 服务未启动或无法连接", OutputType.WARNING)
36
36
  PrettyOutput.print("请确保已经:", OutputType.INFO)
37
37
  PrettyOutput.print("1. 安装了 Ollama: https://ollama.ai", OutputType.INFO)
38
38
  PrettyOutput.print("2. 启动了 Ollama 服务", OutputType.INFO)
39
39
  PrettyOutput.print("3. 服务地址配置正确 (默认: http://localhost:11434)", OutputType.INFO)
40
- raise Exception("Ollama service is not available")
40
+
41
41
 
42
42
  self.messages = []
43
43
  self.system_message = ""
jarvis/models/openai.py CHANGED
@@ -69,9 +69,9 @@ class OpenAIModel(BasePlatform):
69
69
 
70
70
  response = self.client.chat.completions.create(
71
71
  model=self.model_name, # 使用配置的模型名称
72
- messages=self.messages,
72
+ messages=self.messages, # type: ignore
73
73
  stream=True
74
- )
74
+ ) # type: ignore
75
75
 
76
76
  full_response = ""
77
77
 
jarvis/models/oyi.py CHANGED
@@ -23,7 +23,7 @@ class OyiModel(BasePlatform):
23
23
  self.messages = []
24
24
  self.system_message = ""
25
25
  self.conversation = None
26
- self.upload_files = []
26
+ self.files = []
27
27
  self.first_chat = True
28
28
 
29
29
  self.token = os.getenv("OYI_API_KEY")
@@ -122,7 +122,7 @@ class OyiModel(BasePlatform):
122
122
  }
123
123
 
124
124
  payload = {
125
- "topicId": self.conversation['result']['id'],
125
+ "topicId": self.conversation['result']['id'] if self.conversation else None,
126
126
  "messages": self.messages,
127
127
  "content": message,
128
128
  "contentFiles": []
@@ -130,8 +130,8 @@ class OyiModel(BasePlatform):
130
130
 
131
131
  # 如果有上传的文件,添加到请求中
132
132
  if self.first_chat:
133
- if self.upload_files:
134
- for file_data in self.upload_files:
133
+ if self.files:
134
+ for file_data in self.files:
135
135
  file_info = {
136
136
  "contentType": 1, # 1 表示图片
137
137
  "fileUrl": file_data['result']['url'],
@@ -140,7 +140,7 @@ class OyiModel(BasePlatform):
140
140
  }
141
141
  payload["contentFiles"].append(file_info)
142
142
  # 清空已使用的文件列表
143
- self.upload_files = []
143
+ self.files = []
144
144
  message = self.system_message + "\n" + message
145
145
  payload["content"] = message
146
146
  self.first_chat = False
@@ -195,7 +195,7 @@ class OyiModel(BasePlatform):
195
195
  """Reset model state"""
196
196
  self.messages = []
197
197
  self.conversation = None
198
- self.upload_files = []
198
+ self.files = []
199
199
  self.first_chat = True
200
200
 
201
201
  def delete_chat(self) -> bool:
@@ -251,7 +251,7 @@ class OyiModel(BasePlatform):
251
251
  model_info = self.models.get(self.model_name)
252
252
  if not model_info or not model_info.get('uploadFile', False):
253
253
  PrettyOutput.print(f"当前模型 {self.model_name} 不支持文件上传", OutputType.WARNING)
254
- return None
254
+ return []
255
255
 
256
256
  headers = {
257
257
  'Authorization': f'Bearer {self.token}',
@@ -283,18 +283,18 @@ class OyiModel(BasePlatform):
283
283
  if response.status_code == 200:
284
284
  data = response.json()
285
285
  if data.get('code') == 200:
286
- self.upload_files.append(data)
287
- return data
286
+ self.files.append(data)
288
287
  else:
289
288
  PrettyOutput.print(f"文件上传失败: {data.get('message')}", OutputType.ERROR)
290
- return None
289
+ return []
291
290
  else:
292
291
  PrettyOutput.print(f"文件上传失败: {response.status_code}", OutputType.ERROR)
293
- return None
292
+ return []
294
293
 
294
+ return self.files
295
295
  except Exception as e:
296
296
  PrettyOutput.print(f"文件上传异常: {str(e)}", OutputType.ERROR)
297
- return None
297
+ return []
298
298
 
299
299
  def get_available_models(self) -> List[str]:
300
300
  """获取可用的模型列表
@@ -364,5 +364,5 @@ class OyiModel(BasePlatform):
364
364
  return sorted(models)
365
365
 
366
366
  except Exception as e:
367
- PrettyOutput.print(f"获取模型列表异常: {str(e)}", OutputType.ERROR)
367
+ PrettyOutput.print(f"获取模型列表异常: {str(e)}", OutputType.WARNING)
368
368
  return []
jarvis/tools/ask_user.py CHANGED
@@ -34,8 +34,7 @@ class AskUserTool:
34
34
  PrettyOutput.print(question, OutputType.SYSTEM)
35
35
 
36
36
  # 获取用户输入
37
- PrettyOutput.print("\n请输入您的回答(输入空行结束):", OutputType.INPUT)
38
- user_response = get_multiline_input()
37
+ user_response = get_multiline_input("请输入您的回答(输入空行结束)")
39
38
 
40
39
  if user_response == "__interrupt__":
41
40
  return {
jarvis/tools/coder.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ from typing import Dict, Any, Optional
3
+ from jarvis.jarvis_coder.main import JarvisCoder
4
+ from jarvis.utils import PrettyOutput, OutputType
5
+
6
class CoderTool:
    """Tool wrapper that lets the agent run JarvisCoder to modify an existing code base.

    A JarvisCoder instance is created lazily and reused across calls only while the
    target directory and language stay the same; a request for a different
    ``dir``/``language`` builds a fresh instance instead of silently reusing the
    previously initialised project.
    """

    name = "coder"
    description = "分析并修改现有代码,用于实现新功能、修复bug、重构代码等。能理解代码上下文并进行精确的代码编辑。"
    parameters = {
        "feature": {
            "type": "string",
            "description": "要实现的功能描述或需要修改的内容,例如:'添加日志功能'、'修复内存泄漏'、'优化性能'等",
            "required": True
        },
        "dir": {
            "type": "string",
            "description": "项目根目录,默认为当前目录",
            "required": False
        },
        "language": {
            "type": "string",
            "description": "项目的主要编程语言,默认为python",
            "required": False
        }
    }

    def __init__(self):
        # Lazily created JarvisCoder, plus the (absolute work dir, language) pair
        # it was built for, so a request for another project gets a new instance.
        self._coder = None
        self._coder_key = None

    def _init_coder(self, dir: Optional[str] = None, language: Optional[str] = "python") -> None:
        """Ensure ``self._coder`` is a JarvisCoder for ``dir`` and ``language``.

        Args:
            dir: Project root directory; defaults to the current working directory.
            language: Main programming language of the project.
        """
        work_dir = dir or os.getcwd()
        key = (os.path.abspath(work_dir), language)
        # Rebuild whenever the target project changes; previously the first
        # instance was kept forever and later dir/language arguments were ignored.
        if self._coder is None or self._coder_key != key:
            self._coder = JarvisCoder(work_dir, language)
            self._coder_key = key

    def execute(self, args: Dict) -> Dict[str, Any]:
        """Run a code-modification request.

        Args:
            args: Tool arguments:
                feature: description of the change to implement (required).
                dir: optional project root, defaults to the current directory.
                language: optional main language, defaults to "python".

        Returns:
            Dict[str, Any]: the result of ``JarvisCoder.execute`` on success, or
            ``{"success": False, "stdout": "", "stderr": ..., "error": exc}`` on failure.
        """
        feature = args.get("feature")
        dir = args.get("dir")
        # An explicit None (or empty string) also falls back to the documented default.
        language = args.get("language") or "python"

        # Sentinel so ``finally`` never touches an unset or stale directory if
        # os.getcwd() itself fails.
        original_dir = None
        try:
            original_dir = os.getcwd()
            if not feature:
                # Previously str(None) == "None" was sent to the coder as the feature.
                raise ValueError("缺少必需参数: feature")
            self._init_coder(dir, language)
            result = self._coder.execute(str(feature))  # type: ignore
            return result
        except Exception as e:
            PrettyOutput.print(f"代码修改失败: {str(e)}", OutputType.ERROR)
            return {
                "success": False,
                "stdout": "",
                "stderr": f"执行失败: {str(e)}",
                "error": e
            }
        finally:
            # The coder may change the process working directory (this tool has
            # always restored it); put it back so later tools run where they expect.
            if original_dir is not None:
                os.chdir(original_dir)