jarvis-ai-assistant 0.1.98__py3-none-any.whl → 0.1.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of jarvis-ai-assistant might be problematic.

Files changed (40)
  1. jarvis/__init__.py +1 -1
  2. jarvis/agent.py +199 -157
  3. jarvis/jarvis_code_agent/__init__.py +0 -0
  4. jarvis/jarvis_code_agent/main.py +203 -0
  5. jarvis/jarvis_codebase/main.py +412 -284
  6. jarvis/jarvis_coder/file_select.py +209 -0
  7. jarvis/jarvis_coder/git_utils.py +64 -2
  8. jarvis/jarvis_coder/main.py +11 -389
  9. jarvis/jarvis_coder/patch_handler.py +84 -14
  10. jarvis/jarvis_coder/plan_generator.py +49 -7
  11. jarvis/jarvis_rag/main.py +9 -9
  12. jarvis/jarvis_smart_shell/main.py +5 -7
  13. jarvis/models/base.py +6 -1
  14. jarvis/models/ollama.py +2 -2
  15. jarvis/models/registry.py +3 -6
  16. jarvis/tools/ask_user.py +6 -6
  17. jarvis/tools/codebase_qa.py +5 -7
  18. jarvis/tools/create_code_sub_agent.py +55 -0
  19. jarvis/tools/{sub_agent.py → create_sub_agent.py} +4 -1
  20. jarvis/tools/execute_code_modification.py +72 -0
  21. jarvis/tools/{file_ops.py → file_operation.py} +13 -14
  22. jarvis/tools/find_related_files.py +86 -0
  23. jarvis/tools/methodology.py +25 -25
  24. jarvis/tools/rag.py +32 -32
  25. jarvis/tools/registry.py +72 -36
  26. jarvis/tools/search.py +1 -1
  27. jarvis/tools/select_code_files.py +64 -0
  28. jarvis/utils.py +153 -49
  29. {jarvis_ai_assistant-0.1.98.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/METADATA +1 -1
  30. jarvis_ai_assistant-0.1.99.dist-info/RECORD +52 -0
  31. {jarvis_ai_assistant-0.1.98.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/entry_points.txt +2 -1
  32. jarvis/main.py +0 -155
  33. jarvis/tools/coder.py +0 -69
  34. jarvis_ai_assistant-0.1.98.dist-info/RECORD +0 -47
  35. /jarvis/tools/{shell.py → execute_shell.py} +0 -0
  36. /jarvis/tools/{generator.py → generate_tool.py} +0 -0
  37. /jarvis/tools/{webpage.py → read_webpage.py} +0 -0
  38. {jarvis_ai_assistant-0.1.98.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/LICENSE +0 -0
  39. {jarvis_ai_assistant-0.1.98.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/WHEEL +0 -0
  40. {jarvis_ai_assistant-0.1.98.dist-info → jarvis_ai_assistant-0.1.99.dist-info}/top_level.txt +0 -0
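
The bulk of this release is a rework of jarvis/jarvis_codebase/main.py, shown in the diff below: the single compressed cache.pkl is replaced by one lzma-compressed pickle per source file under .jarvis-codebase/cache/, and the torch-based reranker is replaced by LLM-driven result picking. The new _get_cache_path maps a source path to its cache file name; here is a minimal standalone sketch of that mapping (the cache_path_for helper and the example paths are illustrative, not part of the package):

    import os

    # Sketch of the per-file cache naming used by _get_cache_path:
    # strip leading "." / "/" characters, replace "/" with "--", append ".cache".
    def cache_path_for(cache_dir: str, file_path: str) -> str:
        clean_path = file_path.lstrip('./').lstrip('/')  # note: lstrip strips characters, not a literal "./" prefix
        cache_name = clean_path.replace('/', '--') + '.cache'
        return os.path.join(cache_dir, cache_name)

    # Hypothetical example:
    # cache_path_for(".jarvis-codebase/cache", "src/utils/io.py")
    # -> ".jarvis-codebase/cache/src--utils--io.py.cache"
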
@@ -9,7 +9,7 @@ from jarvis.models.registry import PlatformRegistry
  import concurrent.futures
  from threading import Lock
  from concurrent.futures import ThreadPoolExecutor
- from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_thread_count, load_embedding_model, load_rerank_model
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_single_line_input, get_thread_count, load_embedding_model, load_rerank_model
  from jarvis.utils import load_env_from_file
  import argparse
  import pickle
@@ -28,59 +28,37 @@ class CodeBase:
 
  # 初始化数据目录
  self.data_dir = os.path.join(self.root_dir, ".jarvis-codebase")
- if not os.path.exists(self.data_dir):
- os.makedirs(self.data_dir)
+ self.cache_dir = os.path.join(self.data_dir, "cache")
+ if not os.path.exists(self.cache_dir):
+ os.makedirs(self.cache_dir)
 
- # 初始化嵌入模型,使用系统默认缓存目录
+ # 初始化嵌入模型
  try:
  self.embedding_model = load_embedding_model()
-
- # 强制完全加载所有模型组件
- test_text = """
- 这是一段测试文本,用于确保模型完全加载。
- 包含多行内容,以模拟实际使用场景。
- """
- # 预热模型,确保所有组件都被加载
+ test_text = """This is a test text"""
  self.embedding_model.encode([test_text],
  convert_to_tensor=True,
  normalize_embeddings=True)
- PrettyOutput.print("模型加载完成", output_type=OutputType.SUCCESS)
+ PrettyOutput.print("Model loaded successfully", output_type=OutputType.SUCCESS)
  except Exception as e:
- PrettyOutput.print(f"加载模型失败: {str(e)}", output_type=OutputType.ERROR)
+ PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
  raise
 
  self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
-
  self.git_file_list = self.get_git_file_list()
  self.platform_registry = PlatformRegistry.get_global_platform_registry()
 
  # 初始化缓存和索引
- self.cache_path = os.path.join(self.data_dir, "cache.pkl")
  self.vector_cache = {}
  self.file_paths = []
 
- # 加载缓存
- if os.path.exists(self.cache_path):
- try:
- with lzma.open(self.cache_path, 'rb') as f:
- cache_data = pickle.load(f)
- self.vector_cache = cache_data["vectors"]
- self.file_paths = cache_data["file_paths"]
- PrettyOutput.print(f"加载了 {len(self.vector_cache)} 个向量缓存",
- output_type=OutputType.INFO)
- # 从缓存重建索引
- self.build_index()
- except Exception as e:
- PrettyOutput.print(f"加载缓存失败: {str(e)}",
- output_type=OutputType.WARNING)
- self.vector_cache = {}
- self.file_paths = []
- self.index = None
+ # 加载所有缓存文件
+ self._load_all_cache()
 
  def get_git_file_list(self):
- """获取 git 仓库中的文件列表,排除 .jarvis-codebase 目录"""
+ """Get the list of files in the git repository, excluding the .jarvis-codebase directory"""
  files = os.popen("git ls-files").read().splitlines()
- # 过滤掉 .jarvis-codebase 目录下的文件
+ # Filter out files in the .jarvis-codebase directory
  return [f for f in files if not f.startswith(".jarvis-")]
 
  def is_text_file(self, file_path: str):
@@ -95,10 +73,11 @@ class CodeBase:
  model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
  if self.thread_count > 1:
  model.set_suppress_output(True)
+ else:
+ PrettyOutput.print(f"Make description for {file_path} ...", output_type=OutputType.PROGRESS)
  prompt = f"""Please analyze the following code file and generate a detailed description. The description should include:
- 1. Overall file functionality description, no more than 100 characters
- 2. One-sentence description (max 50 characters) for each global variable, function, type definition, class, method, and other code elements
- 3. 5 potential questions users might ask about this file
+ 1. Overall file functionality description
+ 2. description for each global variable, function, type definition, class, method, and other code elements
 
  Please use concise and professional language, emphasizing technical functionality to facilitate subsequent code retrieval.
  File path: {file_path}
@@ -109,42 +88,117 @@ Code content:
  return response
 
  def export(self):
- """导出当前索引数据到标准输出"""
+ """Export the current index data to standard output"""
  for file_path, data in self.vector_cache.items():
  print(f"## {file_path}")
  print(f"- path: {file_path}")
  print(f"- description: {data['description']}")
 
- def _save_cache(self):
- """保存缓存数据"""
+ def _get_cache_path(self, file_path: str) -> str:
+ """Get cache file path for a source file
+
+ Args:
+ file_path: Source file path
+
+ Returns:
+ str: Cache file path
+ """
+ # 处理文件路径:
+ # 1. 移除开头的 ./ 或 /
+ # 2. 将 / 替换为 --
+ # 3. 添加 .cache 后缀
+ clean_path = file_path.lstrip('./').lstrip('/')
+ cache_name = clean_path.replace('/', '--') + '.cache'
+ return os.path.join(self.cache_dir, cache_name)
+
+ def _load_all_cache(self):
+ """Load all cache files"""
  try:
- # 创建缓存数据的副本
- cache_data = {
- "vectors": dict(self.vector_cache), # 创建字典的副本
- "file_paths": list(self.file_paths) # 创建列表的副本
- }
+ # 清空现有缓存和文件路径
+ self.vector_cache = {}
+ self.file_paths = []
+ vectors = []
+
+ for cache_file in os.listdir(self.cache_dir):
+ if not cache_file.endswith('.cache'):
+ continue
+
+ cache_path = os.path.join(self.cache_dir, cache_file)
+ try:
+ with lzma.open(cache_path, 'rb') as f:
+ cache_data = pickle.load(f)
+ file_path = cache_data["path"]
+ self.vector_cache[file_path] = cache_data
+ self.file_paths.append(file_path)
+ vectors.append(cache_data["vector"])
+ except Exception as e:
+ PrettyOutput.print(f"Failed to load cache file {cache_file}: {str(e)}",
+ output_type=OutputType.WARNING)
+ continue
 
- # 使用 lzma 压缩存储
- with lzma.open(self.cache_path, 'wb') as f:
+ if vectors:
+ # 重建索引
+ vectors_array = np.vstack(vectors)
+ hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
+ hnsw_index.hnsw.efConstruction = 40
+ hnsw_index.hnsw.efSearch = 16
+ self.index = faiss.IndexIDMap(hnsw_index)
+ self.index.add_with_ids(vectors_array, np.array(range(len(vectors)))) # type: ignore
+
+ PrettyOutput.print(f"Loaded {len(self.vector_cache)} vector cache and rebuilt index",
+ output_type=OutputType.INFO)
+ else:
+ self.index = None
+ PrettyOutput.print("No valid cache files found", output_type=OutputType.WARNING)
+
+ except Exception as e:
+ PrettyOutput.print(f"Failed to load cache directory: {str(e)}",
+ output_type=OutputType.WARNING)
+ self.vector_cache = {}
+ self.file_paths = []
+ self.index = None
+
+ def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
+ """Cache the vector representation of a file"""
+ try:
+ with open(file_path, "rb") as f:
+ file_md5 = hashlib.md5(f.read()).hexdigest()
+ except Exception as e:
+ PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
+ output_type=OutputType.ERROR)
+ file_md5 = ""
+
+ # 准备缓存数据
+ cache_data = {
+ "path": file_path, # 保存文件路径
+ "md5": file_md5, # 保存文件MD5
+ "description": description, # 保存文件描述
+ "vector": vector # 保存向量
+ }
+
+ # 更新内存缓存
+ self.vector_cache[file_path] = cache_data
+
+ # 保存到单独的缓存文件
+ cache_path = self._get_cache_path(file_path)
+ try:
+ with lzma.open(cache_path, 'wb') as f:
  pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
- PrettyOutput.print(f"保存了 {len(self.vector_cache)} 个向量缓存",
- output_type=OutputType.INFO)
  except Exception as e:
- PrettyOutput.print(f"保存缓存失败: {str(e)}",
+ PrettyOutput.print(f"Failed to save cache for {file_path}: {str(e)}",
  output_type=OutputType.ERROR)
- raise # 抛出异常以便上层处理
 
  def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
- """从缓存获取文件的向量表示"""
+ """Get the vector representation of a file from the cache"""
  if file_path not in self.vector_cache:
  return None
 
- # 检查文件是否被修改
+ # Check if the file has been modified
  try:
  with open(file_path, "rb") as f:
  current_md5 = hashlib.md5(f.read()).hexdigest()
  except Exception as e:
- PrettyOutput.print(f"计算文件MD5失败 {file_path}: {str(e)}",
+ PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
  output_type=OutputType.ERROR)
  return None
 
@@ -152,63 +206,45 @@ Code content:
  if cached_data["md5"] != current_md5:
  return None
 
- # 检查描述是否变化
+ # Check if the description has changed
  if cached_data["description"] != description:
  return None
 
  return cached_data["vector"]
 
- def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
- """缓存文件的向量表示"""
- try:
- with open(file_path, "rb") as f:
- file_md5 = hashlib.md5(f.read()).hexdigest()
- except Exception as e:
- PrettyOutput.print(f"计算文件MD5失败 {file_path}: {str(e)}",
- output_type=OutputType.ERROR)
- file_md5 = ""
-
- # 只更新内存中的缓存
- self.vector_cache[file_path] = {
- "path": file_path, # 保存文件路径
- "md5": file_md5, # 保存文件MD5
- "description": description, # 保存文件描述
- "vector": vector # 保存向量
- }
-
  def get_embedding(self, text: str) -> np.ndarray:
- """使用 transformers 模型获取文本的向量表示"""
- # 对长文本进行截断
- max_length = 512 # 或其他合适的长度
+ """Use the transformers model to get the vector representation of text"""
+ # Truncate long text
+ max_length = 512 # Or other suitable length
  text = ' '.join(text.split()[:max_length])
 
- # 获取嵌入向量
+ # Get the embedding vector
  embedding = self.embedding_model.encode(text,
- normalize_embeddings=True, # L2归一化
+ normalize_embeddings=True, # L2 normalization
  show_progress_bar=False)
  vector = np.array(embedding, dtype=np.float32)
  return vector
 
  def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
- """将文件内容和描述向量化"""
+ """Vectorize the file content and description"""
  try:
- # 先尝试从缓存获取
+ # Try to get the vector from the cache first
  cached_vector = self.get_cached_vector(file_path, description)
  if cached_vector is not None:
  return cached_vector
 
- # 读取文件内容并组合信息
- content = open(file_path, "r", encoding="utf-8").read()[:self.max_context_length] # 限制文件内容长度
+ # Read the file content and combine information
+ content = open(file_path, "r", encoding="utf-8").read()[:self.max_context_length] # Limit the file content length
 
- # 组合文件信息,包含文件内容
+ # Combine file information, including file content
  combined_text = f"""
- {file_path}
- {description}
- {content}
+ File path: {file_path}
+ Description: {description}
+ Content: {content}
  """
  vector = self.get_embedding(combined_text)
 
- # 保存到缓存
+ # Save to cache
  self.cache_vector(file_path, vector, description)
  return vector
  except Exception as e:
@@ -217,36 +253,34 @@ Code content:
  return np.zeros(self.vector_dim, dtype=np.float32) # type: ignore
 
  def clean_cache(self) -> bool:
- """清理过期的缓存记录,返回是否有文件被删除"""
+ """Clean expired cache records"""
  try:
  files_to_delete = []
  for file_path in list(self.vector_cache.keys()):
- if file_path not in self.git_file_list:
- del self.vector_cache[file_path]
+ if not os.path.exists(file_path):
  files_to_delete.append(file_path)
-
- if files_to_delete:
- # 只在有文件被删除时保存缓存
- self._save_cache()
- PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
- output_type=OutputType.INFO)
- return True
- return False
+ cache_path = self._get_cache_path(file_path)
+ try:
+ os.remove(cache_path)
+ except Exception:
+ pass
+
+ for file_path in files_to_delete:
+ del self.vector_cache[file_path]
+ if file_path in self.file_paths:
+ self.file_paths.remove(file_path)
+
+ return bool(files_to_delete)
 
  except Exception as e:
- PrettyOutput.print(f"清理缓存失败: {str(e)}",
- output_type=OutputType.ERROR)
- # 发生异常时尝试保存当前状态
- try:
- self._save_cache()
- except:
- pass
+ PrettyOutput.print(f"Failed to clean cache: {str(e)}",
+ output_type=OutputType.ERROR)
  return False
 
  def process_file(self, file_path: str):
- """处理单个文件"""
+ """Process a single file"""
  try:
- # 跳过不存在的文件
+ # Skip non-existent files
  if not os.path.exists(file_path):
  return None
 
@@ -257,15 +291,15 @@ Code content:
 
  content = open(file_path, "r", encoding="utf-8").read()
 
- # 检查文件是否已经处理过且内容未变
+ # Check if the file has already been processed and the content has not changed
  if file_path in self.vector_cache:
  if self.vector_cache[file_path].get("md5") == md5:
  return None
 
- description = self.make_description(file_path, content) # 传入截取后的内容
+ description = self.make_description(file_path, content) # Pass the truncated content
  vector = self.vectorize_file(file_path, description)
 
- # 保存到缓存,使用实际文件路径作为键
+ # Save to cache, using the actual file path as the key
  self.vector_cache[file_path] = {
  "vector": vector,
  "description": description,
@@ -275,58 +309,94 @@ Code content:
  return file_path
 
  except Exception as e:
- PrettyOutput.print(f"处理文件失败 {file_path}: {str(e)}",
+ PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
  output_type=OutputType.ERROR)
  return None
 
  def build_index(self):
- """从向量缓存构建 faiss 索引"""
- # 创建底层 HNSW 索引
- hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
- hnsw_index.hnsw.efConstruction = 40
- hnsw_index.hnsw.efSearch = 16
-
- # 用 IndexIDMap 包装 HNSW 索引
- self.index = faiss.IndexIDMap(hnsw_index)
-
- vectors = []
- ids = []
- self.file_paths = [] # 重置文件路径列表
-
- for i, (file_path, data) in enumerate(self.vector_cache.items()):
- vectors.append(data["vector"].reshape(1, -1))
- ids.append(i)
- self.file_paths.append(file_path)
-
- if vectors:
- vectors = np.vstack(vectors)
- self.index.add_with_ids(vectors, np.array(ids)) # type: ignore
- else:
+ """Build a faiss index from the vector cache"""
+ try:
+ if not self.vector_cache:
+ self.index = None
+ return
+
+ # Create the underlying HNSW index
+ hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
+ hnsw_index.hnsw.efConstruction = 40
+ hnsw_index.hnsw.efSearch = 16
+
+ # Wrap the HNSW index with IndexIDMap
+ self.index = faiss.IndexIDMap(hnsw_index)
+
+ vectors = []
+ ids = []
+ self.file_paths = [] # Reset the file path list
+
+ for i, (file_path, data) in enumerate(self.vector_cache.items()):
+ if "vector" not in data:
+ PrettyOutput.print(f"Invalid cache data for {file_path}: missing vector",
+ output_type=OutputType.WARNING)
+ continue
+
+ vector = data["vector"]
+ if not isinstance(vector, np.ndarray):
+ PrettyOutput.print(f"Invalid vector type for {file_path}: {type(vector)}",
+ output_type=OutputType.WARNING)
+ continue
+
+ vectors.append(vector.reshape(1, -1))
+ ids.append(i)
+ self.file_paths.append(file_path)
+
+ if vectors:
+ vectors = np.vstack(vectors)
+ if len(vectors) != len(ids):
+ PrettyOutput.print(f"Vector count mismatch: {len(vectors)} vectors vs {len(ids)} ids",
+ output_type=OutputType.ERROR)
+ self.index = None
+ return
+
+ try:
+ self.index.add_with_ids(vectors, np.array(ids)) # type: ignore
+ PrettyOutput.print(f"Successfully built index with {len(vectors)} vectors",
+ output_type=OutputType.SUCCESS)
+ except Exception as e:
+ PrettyOutput.print(f"Failed to add vectors to index: {str(e)}",
+ output_type=OutputType.ERROR)
+ self.index = None
+ else:
+ PrettyOutput.print("No valid vectors found, index not built",
+ output_type=OutputType.WARNING)
+ self.index = None
+
+ except Exception as e:
+ PrettyOutput.print(f"Failed to build index: {str(e)}",
+ output_type=OutputType.ERROR)
  self.index = None
 
  def gen_vector_db_from_cache(self):
- """从缓存生成向量数据库"""
+ """Generate a vector database from the cache"""
  self.build_index()
- self._save_cache()
+ self._load_all_cache()
 
 
  def generate_codebase(self, force: bool = False):
- """生成代码库索引
+ """Generate the codebase index
  Args:
- force: 是否强制重建索引,不询问用户
+ force: Whether to force rebuild the index, without asking the user
  """
  try:
- # 更新 git 文件列表
+ # Update the git file list
  self.git_file_list = self.get_git_file_list()
 
- # 检查文件变化
- PrettyOutput.print("\n检查文件变化...", output_type=OutputType.INFO)
+ # Check file changes
+ PrettyOutput.print("\nCheck file changes...", output_type=OutputType.INFO)
  changes_detected = False
  new_files = []
  modified_files = []
  deleted_files = []
 
- # 检查删除的文件
+ # Check deleted files
  files_to_delete = []
  for file_path in list(self.vector_cache.keys()):
  if file_path not in self.git_file_list:
@@ -334,8 +404,8 @@ Code content:
  files_to_delete.append(file_path)
  changes_detected = True
 
- # 检查新增和修改的文件
- with tqdm(total=len(self.git_file_list), desc="检查文件状态") as pbar:
+ # Check new and modified files
+ with tqdm(total=len(self.git_file_list), desc="Check file status") as pbar:
  for file_path in self.git_file_list:
  if not os.path.exists(file_path) or not self.is_text_file(file_path):
  pbar.update(1)
@@ -351,60 +421,60 @@ Code content:
  modified_files.append(file_path)
  changes_detected = True
  except Exception as e:
- PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
+ PrettyOutput.print(f"Failed to check file {file_path}: {str(e)}",
  output_type=OutputType.ERROR)
  pbar.update(1)
 
- # 如果检测到变化,显示变化并询问用户
+ # If changes are detected, display changes and ask the user
  if changes_detected:
- PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
+ PrettyOutput.print("\nDetected the following changes:", output_type=OutputType.WARNING)
  if new_files:
- PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
+ PrettyOutput.print("\nNew files:", output_type=OutputType.INFO)
  for f in new_files:
  PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
  if modified_files:
- PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
+ PrettyOutput.print("\nModified files:", output_type=OutputType.INFO)
  for f in modified_files:
  PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
  if deleted_files:
- PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
+ PrettyOutput.print("\nDeleted files:", output_type=OutputType.INFO)
  for f in deleted_files:
  PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
 
- # 如果force为True,直接继续
+ # If force is True, continue directly
  if not force:
- # 询问用户是否继续
+ # Ask the user whether to continue
  while True:
- response = input("\n是否重建索引?[y/N] ").lower().strip()
+ response = get_single_line_input("\nRebuild the index? [y/N]").lower().strip()
  if response in ['y', 'yes']:
  break
  elif response in ['', 'n', 'no']:
- PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
+ PrettyOutput.print("Cancel rebuilding the index", output_type=OutputType.INFO)
  return
  else:
- PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)
+ PrettyOutput.print("Please input y or n", output_type=OutputType.WARNING)
 
- # 清理已删除的文件
+ # Clean deleted files
  for file_path in files_to_delete:
  del self.vector_cache[file_path]
  if files_to_delete:
- PrettyOutput.print(f"清理了 {len(files_to_delete)} 个文件的缓存",
+ PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} files",
  output_type=OutputType.INFO)
 
- # 处理新文件和修改的文件
+ # Process new and modified files
  files_to_process = new_files + modified_files
  processed_files = []
 
- with tqdm(total=len(files_to_process), desc="处理文件") as pbar:
- # 使用线程池处理文件
+ with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
+ # Use a thread pool to process files
  with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
- # 提交所有任务
+ # Submit all tasks
  future_to_file = {
  executor.submit(self.process_file, file): file
  for file in files_to_process
  }
 
- # 处理完成的任务
+ # Process completed tasks
  for future in concurrent.futures.as_completed(future_to_file):
  file = future_to_file[future]
  try:
@@ -412,37 +482,37 @@ Code content:
  if result:
  processed_files.append(result)
  except Exception as e:
- PrettyOutput.print(f"处理文件失败 {file}: {str(e)}",
+ PrettyOutput.print(f"Failed to process file {file}: {str(e)}",
  output_type=OutputType.ERROR)
  pbar.update(1)
 
  if processed_files:
- PrettyOutput.print("\n重新生成向量数据库...", output_type=OutputType.INFO)
+ PrettyOutput.print("\nRebuilding the vector database...", output_type=OutputType.INFO)
  self.gen_vector_db_from_cache()
- PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引",
+ PrettyOutput.print(f"Successfully generated the index for {len(processed_files)} files",
  output_type=OutputType.SUCCESS)
  else:
- PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
+ PrettyOutput.print("No file changes detected, no need to rebuild the index", output_type=OutputType.INFO)
 
  except Exception as e:
- # 发生异常时尝试保存缓存
+ # Try to save the cache when an exception occurs
  try:
- self._save_cache()
+ self._load_all_cache()
  except Exception as save_error:
- PrettyOutput.print(f"保存缓存失败: {str(save_error)}",
+ PrettyOutput.print(f"Failed to save cache: {str(save_error)}",
  output_type=OutputType.ERROR)
- raise e # 重新抛出原始异常
+ raise e # Re-raise the original exception
 
 
  def _text_search_score(self, content: str, keywords: List[str]) -> float:
- """计算文本内容与关键词的匹配分数
+ """Calculate the matching score between the text content and the keywords
 
  Args:
- content: 文本内容
- keywords: 关键词列表
+ content: Text content
+ keywords: List of keywords
 
  Returns:
- float: 匹配分数 (0-1)
+ float: Matching score (0-1)
  """
  if not keywords:
  return 0.0
@@ -455,89 +525,128 @@ Code content:
  if keyword in content:
  matched_keywords.add(keyword)
 
- # 计算匹配分数
+ # Calculate the matching score
  score = len(matched_keywords) / len(keywords)
  return score
 
- def rerank_results(self, query: str, initial_results: List[Tuple[str, float, str]]) -> List[Tuple[str, float]]:
- """使用多种策略对搜索结果重新排序"""
+ def pick_results(self, query: str, initial_results: List[str]) -> List[str]:
+ """Use a large model to pick the search results
+
+ Args:
+ query: Search query
+ initial_results: Initial results list of file paths
+
+ Returns:
+ List[str]: The picked results list, each item is a file path
+ """
  if not initial_results:
  return []
 
  try:
- import torch
+ PrettyOutput.print(f"Picking results for query: {query}", output_type=OutputType.INFO)
 
- # 加载模型和分词器
- model, tokenizer = load_rerank_model()
+ # Maximum content length per batch
+ max_batch_length = self.max_context_length - 1000 # Reserve space for prompt
+ max_file_length = max_batch_length // 3 # Limit individual file size
 
- # 准备数据
- pairs = []
+ # Process files in batches
+ all_selected_files = set()
+ current_batch = []
+ current_length = 0
 
- for path, _, desc in initial_results:
+ for path in initial_results:
  try:
- content = open(path, "r", encoding="utf-8").read()[:512] # 限制内容长度
+ content = open(path, "r", encoding="utf-8").read()
+ # Truncate large files
+ if len(content) > max_file_length:
+ PrettyOutput.print(f"Truncating large file: {path}", OutputType.WARNING)
+ content = content[:max_file_length] + "\n... (content truncated)"
 
- # 组合文件信息
- doc_content = f"File path: {path}\nDescription: {desc}\nContent: {content}"
- pairs.append([query, doc_content])
- except Exception as e:
- PrettyOutput.print(f"读取文件失败 {path}: {str(e)}",
- output_type=OutputType.ERROR)
- doc_content = f"File path: {path}\nDescription: {desc}"
- pairs.append([query, doc_content])
-
- # 使用更大的batch size提高处理速度
- batch_size = 16 # 根据GPU显存调整
- batch_scores = []
-
- with torch.no_grad():
- for i in range(0, len(pairs), batch_size):
- batch_pairs = pairs[i:i + batch_size]
- encoded = tokenizer(
- batch_pairs,
- padding=True,
- truncation=True,
- max_length=512,
- return_tensors='pt'
- )
+ file_info = f"File: {path}\nContent: {content}\n\n"
+ file_length = len(file_info)
 
- if torch.cuda.is_available():
- encoded = {k: v.cuda() for k, v in encoded.items()}
+ # If adding this file would exceed batch limit
+ if current_length + file_length > max_batch_length:
+ # Process current batch
+ if current_batch:
+ selected = self._process_batch(query, current_batch)
+ all_selected_files.update(selected)
+ # Start new batch
+ current_batch = [file_info]
+ current_length = file_length
+ else:
+ current_batch.append(file_info)
+ current_length += file_length
 
- outputs = model(**encoded)
- batch_scores.extend(outputs.logits.squeeze(-1).cpu().numpy())
-
- # 归一化分数到 0-1 范围
- if batch_scores:
- min_score = min(batch_scores)
- max_score = max(batch_scores)
- if max_score > min_score:
- batch_scores = [(s - min_score) / (max_score - min_score) for s in batch_scores]
-
- # 将重排序分数与原始分数结合
- scored_results = []
- for (path,_, desc), rerank_score in zip(initial_results, batch_scores):
- if rerank_score >= 0.5: # 只保留相关度较高的结果
- scored_results.append((path, rerank_score))
-
- # 按综合分数降序排序
- scored_results.sort(key=lambda x: x[1], reverse=True)
+ except Exception as e:
+ PrettyOutput.print(f"Failed to read file {path}: {str(e)}", OutputType.ERROR)
+ continue
 
- return scored_results
+ # Process final batch
+ if current_batch:
+ selected = self._process_batch(query, current_batch)
+ all_selected_files.update(selected)
 
+ # Convert set to list and maintain original order
+ final_results = [path for path in initial_results if path in all_selected_files]
+ return final_results
+
  except Exception as e:
- PrettyOutput.print(f"重排序失败: {str(e)}",
- output_type=OutputType.ERROR)
- return [(path, score) for path, score, _ in initial_results] # 发生错误时返回原始结果
+ PrettyOutput.print(f"Failed to pick: {str(e)}", OutputType.ERROR)
+ return initial_results
+
+ def _process_batch(self, query: str, files_info: List[str]) -> List[str]:
+ """Process a batch of files
+
+ Args:
+ query: Search query
+ files_info: List of file information strings
+
+ Returns:
+ List[str]: Selected file paths from this batch
+ """
+ prompt = f"""Please analyze the following code files and determine which files are most relevant to the given query. Consider file paths and code content to make your judgment.
+
+ Query: {query}
+
+ Available files:
+ {''.join(files_info)}
+
+ Please output a YAML list of relevant file paths, ordered by relevance (most relevant first). Only include files that are truly relevant to the query.
+ Output format:
+ <FILES>
+ - path/to/file1.py
+ - path/to/file2.py
+ </FILES>
+
+ Note: Only include files that have a strong connection to the query."""
+
+ # Use a large model to evaluate
+ model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
+ response = model.chat_until_success(prompt)
+
+ # Parse the response
+ import yaml
+ files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
+ if not files_match:
+ return []
+
+ # Extract the file list
+ try:
+ selected_files = yaml.safe_load(files_match.group(1))
+ return selected_files if selected_files else []
+ except Exception as e:
+ PrettyOutput.print(f"Failed to parse response: {str(e)}", OutputType.ERROR)
+ return []
 
  def _generate_query_variants(self, query: str) -> List[str]:
- """生成查询的不同表述变体
+ """Generate different expressions of the query
 
  Args:
- query: 原始查询
+ query: Original query
 
  Returns:
- List[str]: 查询变体列表
+ List[str]: The query variants list
  """
  model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
  prompt = f"""Please generate 3 different expressions based on the following query, each expression should fully convey the meaning of the original query. These expressions will be used for code search, maintain professionalism and accuracy.
@@ -546,18 +655,18 @@ Original query: {query}
  Please output 3 expressions directly, separated by two line breaks, without numbering or other markers.
  """
  variants = model.chat_until_success(prompt).strip().split('\n\n')
- variants.append(query) # 添加原始查询
+ variants.append(query) # Add the original query
  return variants
 
  def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
- """使用向量搜索查找相关文件
+ """Use vector search to find related files
 
  Args:
- query_variants: 查询变体列表
- top_k: 返回结果数量
+ query_variants: The query variants list
+ top_k: The number of results to return
 
  Returns:
- Dict[str, Tuple[str, float, str]]: 文件路径到(路径,分数,描述)的映射
+ Dict[str, Tuple[str, float, str]]: The mapping from file path to (file path, score, description)
  """
  results = {}
  for query in query_variants:
@@ -571,75 +680,78 @@ Please output 3 expressions directly, separated by two line breaks, without numb
  continue
 
  similarity = 1.0 / (1.0 + float(distance))
- if similarity >= 0.5:
- file_path = self.file_paths[i]
- # 使用最高的相似度分数
- if file_path not in results or similarity > results[file_path][1]:
+ file_path = self.file_paths[i]
+ # Use the highest similarity score
+ if file_path not in results:
+ if similarity > 0.5:
  data = self.vector_cache[file_path]
  results[file_path] = (file_path, similarity, data["description"])
 
  return results
 
 
- def search_similar(self, query: str, top_k: int = 30) -> List[Tuple[str, float]]:
- """搜索关联文件"""
+ def search_similar(self, query: str, top_k: int = 30) -> List[str]:
+ """Search related files"""
  try:
  if self.index is None:
  return []
- # 生成查询变体
+ # Generate the query variants
  query_variants = self._generate_query_variants(query)
 
- # 进行向量搜索
+ # Perform vector search
  vector_results = self._vector_search(query_variants, top_k)
 
  results = list(vector_results.values())
  results.sort(key=lambda x: x[1], reverse=True)
 
- # 取前 top_k 个结果进行重排序
+ # Take the top top_k results for reordering
  initial_results = results[:top_k]
 
- # 如果没有找到结果,直接返回
+ # If no results are found, return directly
  if not initial_results:
  return []
 
- # 过滤低分结果
+ # Filter low-scoring results
  initial_results = [(path, score, desc) for path, score, desc in initial_results if score >= 0.5]
+
+ for path, score, desc in initial_results:
+ PrettyOutput.print(f"File: {path} Similarity: {score:.3f}", output_type=OutputType.INFO)
 
- # 对初步结果进行重排序
- return self.rerank_results(query, initial_results)
+ # Reorder the preliminary results
+ return self.pick_results(query, [path for path, _, _ in initial_results])
 
  except Exception as e:
- PrettyOutput.print(f"搜索失败: {str(e)}", output_type=OutputType.ERROR)
+ PrettyOutput.print(f"Failed to search: {str(e)}", output_type=OutputType.ERROR)
  return []
 
  def ask_codebase(self, query: str, top_k: int=20) -> str:
- """查询代码库"""
+ """Query the codebase"""
  results = self.search_similar(query, top_k)
  if not results:
- PrettyOutput.print("没有找到关联的文件", output_type=OutputType.WARNING)
+ PrettyOutput.print("No related files found", output_type=OutputType.WARNING)
  return ""
 
- PrettyOutput.print(f"找到的关联文件: ", output_type=OutputType.SUCCESS)
- for path, score in results:
- PrettyOutput.print(f"文件: {path} 关联度: {score:.3f}",
+ PrettyOutput.print(f"Found related files: ", output_type=OutputType.SUCCESS)
+ for path in results:
+ PrettyOutput.print(f"File: {path}",
  output_type=OutputType.INFO)
 
- prompt = f"""你是一个代码专家,请根据以下文件信息回答用户的问题:
+ prompt = f"""You are a code expert, please answer the user's question based on the following file information:
  """
- for path, _ in results:
+ for path in results:
  try:
  if len(prompt) > self.max_context_length:
- PrettyOutput.print(f"避免上下文超限,丢弃低相关度文件:{path}", OutputType.WARNING)
+ PrettyOutput.print(f"Avoid context overflow, discard low-related file: {path}", OutputType.WARNING)
  continue
  content = open(path, "r", encoding="utf-8").read()
  prompt += f"""
- File path: {path}prompt
+ File path: {path}
  File content:
  {content}
  ========================================
  """
  except Exception as e:
- PrettyOutput.print(f"读取文件失败 {path}: {str(e)}",
+ PrettyOutput.print(f"Failed to read file {path}: {str(e)}",
  output_type=OutputType.ERROR)
  continue
 
@@ -653,29 +765,46 @@ Please answer the user's question in Chinese using professional language. If the
  return response
 
  def is_index_generated(self) -> bool:
- """检查索引是否已经生成"""
- # 检查缓存文件是否存在
- if not os.path.exists(self.cache_path):
- return False
-
- # 检查缓存是否有效
+ """Check if the index has been generated"""
  try:
- with lzma.open(self.cache_path, 'rb') as f:
- cache_data = pickle.load(f)
- if not cache_data.get("vectors") or not cache_data.get("file_paths"):
+ # 1. 检查基本条件
+ if not self.vector_cache or not self.file_paths:
+ return False
+
+ if not hasattr(self, 'index') or self.index is None:
+ return False
+
+ # 2. 检查索引是否可用
+ # 创建测试向量
+ test_vector = np.zeros((1, self.vector_dim), dtype=np.float32) # type: ignore
+ try:
+ self.index.search(test_vector, 1) # type: ignore
+ except Exception:
+ return False
+
+ # 3. 验证向量缓存和文件路径的一致性
+ if len(self.vector_cache) != len(self.file_paths):
+ return False
+
+ # 4. 验证所有缓存文件
+ for file_path in self.file_paths:
+ if file_path not in self.vector_cache:
  return False
- except Exception:
- return False
-
- # 检查索引是否已构建
- if not hasattr(self, 'index') or self.index is None:
- return False
-
- # 检查向量缓存和文件路径列表是否非空
- if not self.vector_cache or not self.file_paths:
+
+ cache_path = self._get_cache_path(file_path)
+ if not os.path.exists(cache_path):
+ return False
+
+ cache_data = self.vector_cache[file_path]
+ if not isinstance(cache_data.get("vector"), np.ndarray):
+ return False
+
+ return True
+
+ except Exception as e:
+ PrettyOutput.print(f"Error checking index status: {str(e)}",
+ output_type=OutputType.ERROR)
  return False
-
- return True
 
 
 
@@ -729,10 +858,9 @@ def main():
  return
 
  PrettyOutput.print("\nSearch Results:", output_type=OutputType.INFO)
- for path, score in results:
+ for path in results:
  PrettyOutput.print("\n" + "="*50, output_type=OutputType.INFO)
  PrettyOutput.print(f"File: {path}", output_type=OutputType.INFO)
- PrettyOutput.print(f"Similarity: {score:.3f}", output_type=OutputType.INFO)
 
  elif args.command == 'ask':
  response = codebase.ask_codebase(args.question, args.top_k)
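
One notable behavioral change visible above: rerank_results (a torch cross-encoder loaded via load_rerank_model) is removed, and search results are now filtered by _process_batch, which asks a chat model to return relevant paths inside a <FILES> block that is parsed as YAML. Below is a minimal sketch of that parsing step, factored out as a standalone function for illustration (parse_files_block is a hypothetical name; the package does this inline):

    import re
    import yaml
    from typing import List

    def parse_files_block(response: str) -> List[str]:
        """Extract the YAML list between <FILES> tags, mirroring _process_batch."""
        files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
        if not files_match:
            return []
        try:
            selected = yaml.safe_load(files_match.group(1))
            return selected if selected else []
        except yaml.YAMLError:
            return []

    # Hypothetical model response:
    print(parse_files_block("<FILES>\n- src/a.py\n- src/b.py\n</FILES>"))
    # -> ['src/a.py', 'src/b.py']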