jarvis-ai-assistant 0.1.102__py3-none-any.whl → 0.1.104__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of jarvis-ai-assistant might be problematic.

Files changed (55)
  1. jarvis/__init__.py +1 -1
  2. jarvis/agent.py +138 -117
  3. jarvis/jarvis_code_agent/code_agent.py +234 -0
  4. jarvis/{jarvis_coder → jarvis_code_agent}/file_select.py +19 -22
  5. jarvis/jarvis_code_agent/patch.py +120 -0
  6. jarvis/jarvis_code_agent/relevant_files.py +97 -0
  7. jarvis/jarvis_codebase/main.py +871 -0
  8. jarvis/jarvis_platform/main.py +5 -3
  9. jarvis/jarvis_rag/main.py +818 -0
  10. jarvis/jarvis_smart_shell/main.py +2 -2
  11. jarvis/models/ai8.py +3 -1
  12. jarvis/models/kimi.py +36 -30
  13. jarvis/models/ollama.py +17 -11
  14. jarvis/models/openai.py +15 -12
  15. jarvis/models/oyi.py +24 -7
  16. jarvis/models/registry.py +1 -25
  17. jarvis/tools/__init__.py +0 -6
  18. jarvis/tools/ask_codebase.py +96 -0
  19. jarvis/tools/ask_user.py +1 -9
  20. jarvis/tools/chdir.py +2 -37
  21. jarvis/tools/code_review.py +210 -0
  22. jarvis/tools/create_code_test_agent.py +115 -0
  23. jarvis/tools/create_ctags_agent.py +164 -0
  24. jarvis/tools/create_sub_agent.py +2 -2
  25. jarvis/tools/execute_shell.py +2 -2
  26. jarvis/tools/file_operation.py +2 -2
  27. jarvis/tools/find_in_codebase.py +78 -0
  28. jarvis/tools/git_commiter.py +68 -0
  29. jarvis/tools/methodology.py +3 -3
  30. jarvis/tools/rag.py +141 -0
  31. jarvis/tools/read_code.py +116 -0
  32. jarvis/tools/read_webpage.py +1 -1
  33. jarvis/tools/registry.py +47 -31
  34. jarvis/tools/search.py +8 -6
  35. jarvis/tools/select_code_files.py +4 -4
  36. jarvis/utils.py +375 -85
  37. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/METADATA +107 -32
  38. jarvis_ai_assistant-0.1.104.dist-info/RECORD +50 -0
  39. jarvis_ai_assistant-0.1.104.dist-info/entry_points.txt +11 -0
  40. jarvis/jarvis_code_agent/main.py +0 -200
  41. jarvis/jarvis_coder/git_utils.py +0 -123
  42. jarvis/jarvis_coder/patch_handler.py +0 -340
  43. jarvis/jarvis_github/main.py +0 -232
  44. jarvis/tools/create_code_sub_agent.py +0 -56
  45. jarvis/tools/execute_code_modification.py +0 -70
  46. jarvis/tools/find_files.py +0 -119
  47. jarvis/tools/generate_tool.py +0 -174
  48. jarvis/tools/thinker.py +0 -151
  49. jarvis_ai_assistant-0.1.102.dist-info/RECORD +0 -46
  50. jarvis_ai_assistant-0.1.102.dist-info/entry_points.txt +0 -6
  51. /jarvis/{jarvis_coder → jarvis_codebase}/__init__.py +0 -0
  52. /jarvis/{jarvis_github → jarvis_rag}/__init__.py +0 -0
  53. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/LICENSE +0 -0
  54. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/WHEEL +0 -0
  55. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/top_level.txt +0 -0
jarvis/jarvis_codebase/main.py (new file)
@@ -0,0 +1,871 @@
+ import hashlib
+ import os
+ import numpy as np
+ import faiss
+ from typing import List, Tuple, Optional, Dict
+
+ import yaml
+ from jarvis.models.registry import PlatformRegistry
+ import concurrent.futures
+ from threading import Lock
+ from concurrent.futures import ThreadPoolExecutor
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_single_line_input, get_thread_count, load_embedding_model, load_rerank_model, user_confirm
+ from jarvis.utils import init_env
+ import argparse
+ import pickle
+ import lzma  # added lzma import
+ from tqdm import tqdm
+ import re
+
+ class CodeBase:
+     def __init__(self, root_dir: str):
+         init_env()
+         self.root_dir = root_dir
+         os.chdir(self.root_dir)
+         self.thread_count = get_thread_count()
+         self.max_context_length = get_max_context_length()
+         self.index = None
+
+         # Initialize the data directory
+         self.data_dir = os.path.join(self.root_dir, ".jarvis/codebase")
+         self.cache_dir = os.path.join(self.data_dir, "cache")
+         if not os.path.exists(self.cache_dir):
+             os.makedirs(self.cache_dir)
+
+         # Initialize the embedding model
+         try:
+             self.embedding_model = load_embedding_model()
+             test_text = """This is a test text"""
+             self.embedding_model.encode([test_text],
+                                         convert_to_tensor=True,
+                                         normalize_embeddings=True)
+             PrettyOutput.print("Model loaded successfully", output_type=OutputType.SUCCESS)
+         except Exception as e:
+             PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
+             raise
+
+         self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
+         self.git_file_list = self.get_git_file_list()
+         self.platform_registry = PlatformRegistry.get_global_platform_registry()
+
+         # Initialize the cache and index
+         self.vector_cache = {}
+         self.file_paths = []
+
+         # Load all cache files
+         self._load_all_cache()
+
+     def get_git_file_list(self):
+         """Get the list of files in the git repository, excluding the .jarvis directory"""
+         files = os.popen("git ls-files").read().splitlines()
+         # Filter out files under the .jarvis directory
+         return [f for f in files if not f.startswith(".jarvis")]
+
+     def is_text_file(self, file_path: str):
+         with open(file_path, "r", encoding="utf-8") as f:
+             try:
+                 f.read()
+                 return True
+             except UnicodeDecodeError:
+                 return False
+
+     def make_description(self, file_path: str, content: str) -> str:
+         model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
+         if self.thread_count > 1:
+             model.set_suppress_output(True)
+         else:
+             PrettyOutput.print(f"Make description for {file_path} ...", output_type=OutputType.PROGRESS)
+         prompt = f"""Please analyze the following code file and generate a detailed description. The description should include:
+ 1. Overall file functionality description
+ 2. Description of each global variable, function, type definition, class, method, and other code elements
+
+ Please use concise and professional language, emphasizing technical functionality to facilitate subsequent code retrieval.
+ File path: {file_path}
+ Code content:
+ {content}
+ """
+         response = model.chat_until_success(prompt)
+         return response
+
+     def export(self):
+         """Export the current index data to standard output"""
+         for file_path, data in self.vector_cache.items():
+             print(f"## {file_path}")
+             print(f"- path: {file_path}")
+             print(f"- description: {data['description']}")
+
+     def _get_cache_path(self, file_path: str) -> str:
+         """Get the cache file path for a source file
+
+         Args:
+             file_path: Source file path
+
+         Returns:
+             str: Cache file path
+         """
+         # Normalize the file path:
+         # 1. Strip the leading ./ or /
+         # 2. Replace / with --
+         # 3. Append the .cache suffix
+         clean_path = file_path.lstrip('./').lstrip('/')
+         cache_name = clean_path.replace('/', '--') + '.cache'
+         return os.path.join(self.cache_dir, cache_name)
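The naming scheme above flattens each source path into a single cache file name. A minimal standalone sketch of the mapping (the cache_dir value is hypothetical); note that str.lstrip strips a *set* of characters, so lstrip('./') also removes the leading dot of a dotfile such as .gitignore:

import os

def cache_path_for(cache_dir: str, file_path: str) -> str:
    # lstrip('./') strips every leading '.' and '/' character, which
    # covers "./src/a.py", "/src/a.py", and "src/a.py" alike
    clean = file_path.lstrip('./').lstrip('/')
    return os.path.join(cache_dir, clean.replace('/', '--') + '.cache')

assert cache_path_for('.jarvis/codebase/cache', './src/utils/io.py') \
    == '.jarvis/codebase/cache/src--utils--io.py.cache'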
+
+     def _load_all_cache(self):
+         """Load all cache files"""
+         try:
+             # Clear the existing cache and file paths
+             self.vector_cache = {}
+             self.file_paths = []
+             vectors = []
+
+             for cache_file in os.listdir(self.cache_dir):
+                 if not cache_file.endswith('.cache'):
+                     continue
+
+                 cache_path = os.path.join(self.cache_dir, cache_file)
+                 try:
+                     with lzma.open(cache_path, 'rb') as f:
+                         cache_data = pickle.load(f)
+                     file_path = cache_data["path"]
+                     self.vector_cache[file_path] = cache_data
+                     self.file_paths.append(file_path)
+                     vectors.append(cache_data["vector"])
+                 except Exception as e:
+                     PrettyOutput.print(f"Failed to load cache file {cache_file}: {str(e)}",
+                                        output_type=OutputType.WARNING)
+                     continue
+
+             if vectors:
+                 # Rebuild the index
+                 vectors_array = np.vstack(vectors)
+                 hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
+                 hnsw_index.hnsw.efConstruction = 40
+                 hnsw_index.hnsw.efSearch = 16
+                 self.index = faiss.IndexIDMap(hnsw_index)
+                 self.index.add_with_ids(vectors_array, np.array(range(len(vectors))))  # type: ignore
+
+                 PrettyOutput.print(f"Loaded {len(self.vector_cache)} vector caches and rebuilt the index",
+                                    output_type=OutputType.INFO)
+             else:
+                 self.index = None
+                 PrettyOutput.print("No valid cache files found", output_type=OutputType.WARNING)
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to load cache directory: {str(e)}",
+                                output_type=OutputType.WARNING)
+             self.vector_cache = {}
+             self.file_paths = []
+             self.index = None
+
+     def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
+         """Cache the vector representation of a file"""
+         try:
+             with open(file_path, "rb") as f:
+                 file_md5 = hashlib.md5(f.read()).hexdigest()
+         except Exception as e:
+             PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
+                                output_type=OutputType.ERROR)
+             file_md5 = ""
+
+         # Prepare the cache data
+         cache_data = {
+             "path": file_path,           # file path
+             "md5": file_md5,             # file MD5
+             "description": description,  # file description
+             "vector": vector             # vector
+         }
+
+         # Update the in-memory cache
+         self.vector_cache[file_path] = cache_data
+
+         # Save to a standalone cache file
+         cache_path = self._get_cache_path(file_path)
+         try:
+             with lzma.open(cache_path, 'wb') as f:
+                 pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
+         except Exception as e:
+             PrettyOutput.print(f"Failed to save cache for {file_path}: {str(e)}",
+                                output_type=OutputType.ERROR)
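Each cache record is a plain dict, pickled and lzma-compressed, one file per source file. A minimal round-trip under the same layout (the path, hash, and 384-dimensional vector are hypothetical values):

import lzma
import pickle
import numpy as np

record = {
    "path": "src/app.py",
    "md5": "d41d8cd98f00b204e9800998ecf8427e",
    "description": "Application entry point",
    "vector": np.zeros(384, dtype=np.float32),
}
with lzma.open("src--app.py.cache", "wb") as f:
    pickle.dump(record, f, protocol=pickle.HIGHEST_PROTOCOL)
with lzma.open("src--app.py.cache", "rb") as f:
    restored = pickle.load(f)
assert restored["md5"] == record["md5"]  # record survives compression intact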
+
+     def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
+         """Get the vector representation of a file from the cache"""
+         if file_path not in self.vector_cache:
+             return None
+
+         # Check if the file has been modified
+         try:
+             with open(file_path, "rb") as f:
+                 current_md5 = hashlib.md5(f.read()).hexdigest()
+         except Exception as e:
+             PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return None
+
+         cached_data = self.vector_cache[file_path]
+         if cached_data["md5"] != current_md5:
+             return None
+
+         # Check if the description has changed
+         if cached_data["description"] != description:
+             return None
+
+         return cached_data["vector"]
+
+     def get_embedding(self, text: str) -> np.ndarray:
+         """Use the transformers model to get the vector representation of text"""
+         # Truncate long text
+         max_length = 512  # or another suitable length
+         text = ' '.join(text.split()[:max_length])
+
+         # Get the embedding vector
+         embedding = self.embedding_model.encode(text,
+                                                 normalize_embeddings=True,  # L2 normalization
+                                                 show_progress_bar=False)
+         vector = np.array(embedding, dtype=np.float32)
+         return vector
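The truncation above counts whitespace-separated words, not model tokens, so the encoder's own tokenizer may still truncate further. A small illustration:

text = "tok " * 1000                      # 1000 whitespace-separated words
truncated = ' '.join(text.split()[:512])
print(len(truncated.split()))             # 512 -- words survive, not tokenizer tokens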
+
+     def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
+         """Vectorize the file content and description"""
+         try:
+             # Try to get the vector from the cache first
+             cached_vector = self.get_cached_vector(file_path, description)
+             if cached_vector is not None:
+                 return cached_vector
+
+             # Read the file content and combine information
+             content = open(file_path, "r", encoding="utf-8").read()[:self.max_context_length]  # Limit the file content length
+
+             # Combine file information, including file content
+             combined_text = f"""
+ File path: {file_path}
+ Description: {description}
+ Content: {content}
+ """
+             vector = self.get_embedding(combined_text)
+
+             # Save to cache
+             self.cache_vector(file_path, vector, description)
+             return vector
+         except Exception as e:
+             PrettyOutput.print(f"Error vectorizing file {file_path}: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return np.zeros(self.vector_dim, dtype=np.float32)  # type: ignore
+
+     def clean_cache(self) -> bool:
+         """Clean expired cache records"""
+         try:
+             files_to_delete = []
+             for file_path in list(self.vector_cache.keys()):
+                 if not os.path.exists(file_path):
+                     files_to_delete.append(file_path)
+                     cache_path = self._get_cache_path(file_path)
+                     try:
+                         os.remove(cache_path)
+                     except Exception:
+                         pass
+
+             for file_path in files_to_delete:
+                 del self.vector_cache[file_path]
+                 if file_path in self.file_paths:
+                     self.file_paths.remove(file_path)
+
+             return bool(files_to_delete)
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to clean cache: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return False
+
+     def process_file(self, file_path: str):
+         """Process a single file"""
+         try:
+             # Skip non-existent files
+             if not os.path.exists(file_path):
+                 return None
+
+             if not self.is_text_file(file_path):
+                 return None
+
+             md5 = get_file_md5(file_path)
+
+             content = open(file_path, "r", encoding="utf-8").read()
+
+             # Check if the file has already been processed and the content has not changed
+             if file_path in self.vector_cache:
+                 if self.vector_cache[file_path].get("md5") == md5:
+                     return None
+
+             description = self.make_description(file_path, content)
+             vector = self.vectorize_file(file_path, description)
+
+             # Save to cache, using the actual file path as the key
+             self.vector_cache[file_path] = {
+                 "vector": vector,
+                 "description": description,
+                 "md5": md5
+             }
+
+             return file_path
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return None
+
+     def build_index(self):
+         """Build a faiss index from the vector cache"""
+         try:
+             if not self.vector_cache:
+                 self.index = None
+                 return
+
+             # Create the underlying HNSW index
+             hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
+             hnsw_index.hnsw.efConstruction = 40
+             hnsw_index.hnsw.efSearch = 16
+
+             # Wrap the HNSW index with IndexIDMap
+             self.index = faiss.IndexIDMap(hnsw_index)
+
+             vectors = []
+             ids = []
+             self.file_paths = []  # Reset the file path list
+
+             for i, (file_path, data) in enumerate(self.vector_cache.items()):
+                 if "vector" not in data:
+                     PrettyOutput.print(f"Invalid cache data for {file_path}: missing vector",
+                                        output_type=OutputType.WARNING)
+                     continue
+
+                 vector = data["vector"]
+                 if not isinstance(vector, np.ndarray):
+                     PrettyOutput.print(f"Invalid vector type for {file_path}: {type(vector)}",
+                                        output_type=OutputType.WARNING)
+                     continue
+
+                 vectors.append(vector.reshape(1, -1))
+                 ids.append(i)
+                 self.file_paths.append(file_path)
+
+             if vectors:
+                 vectors = np.vstack(vectors)
+                 if len(vectors) != len(ids):
+                     PrettyOutput.print(f"Vector count mismatch: {len(vectors)} vectors vs {len(ids)} ids",
+                                        output_type=OutputType.ERROR)
+                     self.index = None
+                     return
+
+                 try:
+                     self.index.add_with_ids(vectors, np.array(ids))  # type: ignore
+                     PrettyOutput.print(f"Successfully built index with {len(vectors)} vectors",
+                                        output_type=OutputType.SUCCESS)
+                 except Exception as e:
+                     PrettyOutput.print(f"Failed to add vectors to index: {str(e)}",
+                                        output_type=OutputType.ERROR)
+                     self.index = None
+             else:
+                 PrettyOutput.print("No valid vectors found, index not built",
+                                    output_type=OutputType.WARNING)
+                 self.index = None
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to build index: {str(e)}",
+                                output_type=OutputType.ERROR)
+             self.index = None
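For reference, the same index construction in isolation: an HNSW graph with 16 neighbors per node, wrapped in an IndexIDMap so vectors can be addressed by the integer ids assigned during the cache scan. The 384-dimensional random vectors here are hypothetical stand-ins for the file embeddings:

import faiss
import numpy as np

dim = 384  # hypothetical embedding dimension
base = faiss.IndexHNSWFlat(dim, 16)   # HNSW graph, 16 neighbors per node
base.hnsw.efConstruction = 40         # build-time search breadth
base.hnsw.efSearch = 16               # query-time search breadth
index = faiss.IndexIDMap(base)        # allows caller-assigned integer ids

vecs = np.random.rand(100, dim).astype(np.float32)
faiss.normalize_L2(vecs)              # mirror normalize_embeddings=True
index.add_with_ids(vecs, np.arange(100, dtype=np.int64))

distances, ids = index.search(vecs[:1], 5)
print(ids[0], distances[0])           # nearest ids and their squared L2 distances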
+
+     def gen_vector_db_from_cache(self):
+         """Generate a vector database from the cache"""
+         self.build_index()
+         self._load_all_cache()
+
+     def generate_codebase(self, force: bool = False):
+         """Generate the codebase index
+
+         Args:
+             force: Whether to force a rebuild of the index without asking the user
+         """
+         try:
+             # Update the git file list
+             self.git_file_list = self.get_git_file_list()
+
+             # Check file changes
+             PrettyOutput.print("Checking file changes...", output_type=OutputType.INFO)
+             changes_detected = False
+             new_files = []
+             modified_files = []
+             deleted_files = []
+
+             # Check deleted files
+             files_to_delete = []
+             for file_path in list(self.vector_cache.keys()):
+                 if file_path not in self.git_file_list:
+                     deleted_files.append(file_path)
+                     files_to_delete.append(file_path)
+                     changes_detected = True
+
+             # Check new and modified files
+             from rich.progress import Progress
+             with Progress() as progress:
+                 task = progress.add_task("Check file status", total=len(self.git_file_list))
+                 for file_path in self.git_file_list:
+                     if not os.path.exists(file_path) or not self.is_text_file(file_path):
+                         progress.advance(task)
+                         continue
+
+                     try:
+                         current_md5 = get_file_md5(file_path)
+
+                         if file_path not in self.vector_cache:
+                             new_files.append(file_path)
+                             changes_detected = True
+                         elif self.vector_cache[file_path].get("md5") != current_md5:
+                             modified_files.append(file_path)
+                             changes_detected = True
+                     except Exception as e:
+                         PrettyOutput.print(f"Failed to check file {file_path}: {str(e)}",
+                                            output_type=OutputType.ERROR)
+                     progress.advance(task)
+
+             # If changes are detected, display them and ask the user
+             if changes_detected:
+                 output_lines = ["Detected the following changes:"]
+                 if new_files:
+                     output_lines.append("New files:")
+                     output_lines.extend(f"  {f}" for f in new_files)
+                 if modified_files:
+                     output_lines.append("Modified files:")
+                     output_lines.extend(f"  {f}" for f in modified_files)
+                 if deleted_files:
+                     output_lines.append("Deleted files:")
+                     output_lines.extend(f"  {f}" for f in deleted_files)
+
+                 PrettyOutput.print("\n".join(output_lines), output_type=OutputType.WARNING)
+
+                 # If force is True, continue directly
+                 if not force:
+                     if not user_confirm("Rebuild the index?", False):
+                         PrettyOutput.print("Cancelled rebuilding the index", output_type=OutputType.INFO)
+                         return
+
+                 # Clean deleted files
+                 for file_path in files_to_delete:
+                     del self.vector_cache[file_path]
+                 if files_to_delete:
+                     PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} files",
+                                        output_type=OutputType.INFO)
+
+                 # Process new and modified files
+                 files_to_process = new_files + modified_files
+                 processed_files = []
+
+                 with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
+                     # Use a thread pool to process files
+                     with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
+                         # Submit all tasks
+                         future_to_file = {
+                             executor.submit(self.process_file, file): file
+                             for file in files_to_process
+                         }
+
+                         # Process completed tasks
+                         for future in concurrent.futures.as_completed(future_to_file):
+                             file = future_to_file[future]
+                             try:
+                                 result = future.result()
+                                 if result:
+                                     processed_files.append(result)
+                             except Exception as e:
+                                 PrettyOutput.print(f"Failed to process file {file}: {str(e)}",
+                                                    output_type=OutputType.ERROR)
+                             pbar.update(1)
+
+                 if processed_files:
+                     PrettyOutput.print("Rebuilding the vector database...", output_type=OutputType.INFO)
+                     self.gen_vector_db_from_cache()
+                     PrettyOutput.print(f"Successfully generated the index for {len(processed_files)} files",
+                                        output_type=OutputType.SUCCESS)
+             else:
+                 PrettyOutput.print("No file changes detected, no need to rebuild the index", output_type=OutputType.INFO)
+
+         except Exception as e:
+             # Try to save the cache when an exception occurs
+             try:
+                 self._load_all_cache()
+             except Exception as save_error:
+                 PrettyOutput.print(f"Failed to save cache: {str(save_error)}",
+                                    output_type=OutputType.ERROR)
+             raise e  # Re-raise the original exception
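The fan-out above is the standard submit/as_completed pattern: one future per file, results collected as workers finish, exceptions re-raised at the result() call. A self-contained sketch with a stand-in worker (process here is a hypothetical placeholder for process_file):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

def process(path: str) -> Optional[str]:
    # stand-in for CodeBase.process_file: path on success, None when skipped
    return path if path.endswith(".py") else None

paths = ["a.py", "b.txt", "c.py"]
done = []
with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_path = {executor.submit(process, p): p for p in paths}
    for future in as_completed(future_to_path):
        path = future_to_path[future]
        try:
            result = future.result()  # re-raises any exception from the worker
            if result:
                done.append(result)
        except Exception as e:
            print(f"Failed to process {path}: {e}")
print(sorted(done))  # ['a.py', 'c.py']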
+
+     def _text_search_score(self, content: str, keywords: List[str]) -> float:
+         """Calculate the matching score between the text content and the keywords
+
+         Args:
+             content: Text content
+             keywords: List of keywords
+
+         Returns:
+             float: Matching score (0-1)
+         """
+         if not keywords:
+             return 0.0
+
+         content = content.lower()
+         matched_keywords = set()
+
+         for keyword in keywords:
+             keyword = keyword.lower()
+             if keyword in content:
+                 matched_keywords.add(keyword)
+
+         # Calculate the matching score
+         score = len(matched_keywords) / len(keywords)
+         return score
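A worked example of the score: the fraction of keywords that occur case-insensitively in the content, so three keywords with two hits yield 2/3:

content = "def load_cache(path): ...".lower()
keywords = ["cache", "vector", "load"]
matched = {k.lower() for k in keywords if k.lower() in content}
print(f"{len(matched)}/{len(keywords)} = {len(matched) / len(keywords):.3f}")
# "cache" and "load" match, "vector" does not: 2/3 = 0.667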
+
+     def pick_results(self, query: str, initial_results: List[str]) -> List[str]:
+         """Use a large model to pick the search results
+
+         Args:
+             query: Search query
+             initial_results: Initial results list of file paths
+
+         Returns:
+             List[str]: The picked results list, each item is a file path
+         """
+         if not initial_results:
+             return []
+
+         try:
+             PrettyOutput.print(f"Picking results for query: {query}", output_type=OutputType.INFO)
+
+             # Maximum content length per batch
+             max_batch_length = self.max_context_length - 1000  # Reserve space for the prompt
+             max_file_length = max_batch_length // 3  # Limit individual file size
+
+             # Process files in batches
+             all_selected_files = set()
+             current_batch = []
+             current_length = 0
+
+             for path in initial_results:
+                 try:
+                     content = open(path, "r", encoding="utf-8").read()
+                     # Truncate large files
+                     if len(content) > max_file_length:
+                         PrettyOutput.print(f"Truncating large file: {path}", OutputType.WARNING)
+                         content = content[:max_file_length] + "\n... (content truncated)"
+
+                     file_info = f"File: {path}\nContent: {content}\n\n"
+                     file_length = len(file_info)
+
+                     # If adding this file would exceed the batch limit
+                     if current_length + file_length > max_batch_length:
+                         # Process the current batch
+                         if current_batch:
+                             selected = self._process_batch(query, current_batch)
+                             all_selected_files.update(selected)
+                         # Start a new batch
+                         current_batch = [file_info]
+                         current_length = file_length
+                     else:
+                         current_batch.append(file_info)
+                         current_length += file_length
+
+                 except Exception as e:
+                     PrettyOutput.print(f"Failed to read file {path}: {str(e)}", OutputType.ERROR)
+                     continue
+
+             # Process the final batch
+             if current_batch:
+                 selected = self._process_batch(query, current_batch)
+                 all_selected_files.update(selected)
+
+             # Convert the set to a list, maintaining the original order
+             final_results = [path for path in initial_results if path in all_selected_files]
+             return final_results
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to pick: {str(e)}", OutputType.ERROR)
+             return initial_results
+
+     def _process_batch(self, query: str, files_info: List[str]) -> List[str]:
+         """Process a batch of files
+
+         Args:
+             query: Search query
+             files_info: List of file information strings
+
+         Returns:
+             List[str]: Selected file paths from this batch
+         """
+         prompt = f"""Please analyze the following code files and determine which files are most relevant to the given query. Consider file paths and code content to make your judgment.
+
+ Query: {query}
+
+ Available files:
+ {''.join(files_info)}
+
+ Please output a YAML list of relevant file paths, ordered by relevance (most relevant first). Only include files that are truly relevant to the query.
+ Output format:
+ <FILES>
+ - path/to/file1.py
+ - path/to/file2.py
+ </FILES>
+
+ Note: Only include files that have a strong connection to the query."""
+
+         # Use a large model to evaluate
+         model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
+         response = model.chat_until_success(prompt)
+
+         # Parse the response
+         files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
+         if not files_match:
+             return []
+
+         # Extract the file list
+         try:
+             selected_files = yaml.safe_load(files_match.group(1))
+             return selected_files if selected_files else []
+         except Exception as e:
+             PrettyOutput.print(f"Failed to parse response: {str(e)}", OutputType.ERROR)
+             return []
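The parser above requires a newline immediately after the <FILES> tag; if the model omits that exact shape, the regex fails and the batch contributes nothing. A minimal round-trip with a well-formed response (the file names are illustrative):

import re
import yaml

response = """Here are the relevant files:
<FILES>
- jarvis/jarvis_codebase/main.py
- jarvis/utils.py
</FILES>"""

match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
selected = yaml.safe_load(match.group(1)) if match else []
print(selected)  # ['jarvis/jarvis_codebase/main.py', 'jarvis/utils.py']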
+
+     def _generate_query_variants(self, query: str) -> List[str]:
+         """Generate different expressions of the query
+
+         Args:
+             query: Original query
+
+         Returns:
+             List[str]: The query variants list
+         """
+         model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
+         prompt = f"""Please generate 3 different expressions based on the following query; each expression should fully convey the meaning of the original query. These expressions will be used for code search, so maintain professionalism and accuracy.
+ Original query: {query}
+
+ Please output the 3 expressions directly, separated by two line breaks, without numbering or other markers.
+ """
+         variants = model.chat_until_success(prompt).strip().split('\n\n')
+         variants.append(query)  # Add the original query
+         return variants
+
+     def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
+         """Use vector search to find related files
+
+         Args:
+             query_variants: The query variants list
+             top_k: The number of results to return
+
+         Returns:
+             Dict[str, Tuple[str, float, str]]: A mapping from file path to (file path, score, description)
+         """
+         results = {}
+         for query in query_variants:
+             query_vector = self.get_embedding(query)
+             query_vector = query_vector.reshape(1, -1)
+
+             distances, indices = self.index.search(query_vector, top_k)  # type: ignore
+
+             for i, distance in zip(indices[0], distances[0]):
+                 if i == -1:
+                     continue
+
+                 similarity = 1.0 / (1.0 + float(distance))
+                 file_path = self.file_paths[i]
+                 # Record each file once, keeping the first score seen
+                 # (later variants do not override it)
+                 if file_path not in results:
+                     if similarity > 0.5:
+                         data = self.vector_cache[file_path]
+                         results[file_path] = (file_path, similarity, data["description"])
+
+         return results
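The score conversion maps a faiss distance d to 1/(1+d). faiss reports squared Euclidean distances for METRIC_L2 indexes, so under the assumption of unit-normalized embeddings d = 2 - 2*cos, and the similarity > 0.5 cut-off used here corresponds to cosine similarity above 0.5:

def similarity_from_distance(d: float) -> float:
    return 1.0 / (1.0 + d)

# For unit vectors: d = ||a - b||^2 = 2 - 2*cos(a, b)
for cos in (1.0, 0.75, 0.5, 0.0):
    d = 2 - 2 * cos
    print(cos, round(similarity_from_distance(d), 3))
# cos=1.0 -> 1.0, cos=0.75 -> 0.667, cos=0.5 -> 0.5, cos=0.0 -> 0.333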
+
+     def search_similar(self, query: str, top_k: int = 30) -> List[str]:
+         """Search related files"""
+         try:
+             self.generate_codebase()
+             if self.index is None:
+                 return []
+
+             # Generate the query variants
+             query_variants = self._generate_query_variants(query)
+
+             # Perform vector search
+             vector_results = self._vector_search(query_variants, top_k)
+
+             results = list(vector_results.values())
+             results.sort(key=lambda x: x[1], reverse=True)
+
+             # Take the top top_k results for reordering
+             initial_results = results[:top_k]
+
+             # If no results are found, return directly
+             if not initial_results:
+                 return []
+
+             # Filter low-scoring results
+             initial_results = [(path, score, desc) for path, score, desc in initial_results if score >= 0.5]
+
+             message = "Found related files:\n"
+             for path, score, _ in initial_results:
+                 message += f"File: {path} Similarity: {score:.3f}\n"
+             PrettyOutput.print(message.rstrip(), output_type=OutputType.INFO, lang="markdown")
+
+             # Reorder the preliminary results
+             return self.pick_results(query, [path for path, _, _ in initial_results])
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to search: {str(e)}", output_type=OutputType.ERROR)
+             return []
+
+     def ask_codebase(self, query: str, top_k: int = 20) -> str:
+         """Query the codebase"""
+         results = self.search_similar(query, top_k)
+         if not results:
+             PrettyOutput.print("No related files found", output_type=OutputType.WARNING)
+             return ""
+
+         message = "Found related files:\n"
+         for path in results:
+             message += f"File: {path}\n"
+         PrettyOutput.print(message.rstrip(), output_type=OutputType.SUCCESS, lang="markdown")
+
+         prompt = f"""You are a code expert, please answer the user's question based on the following file information:
+ """
+         for path in results:
+             try:
+                 if len(prompt) > self.max_context_length:
+                     PrettyOutput.print(f"To avoid context overflow, discarding low-relevance file: {path}", OutputType.WARNING)
+                     continue
+                 content = open(path, "r", encoding="utf-8").read()
+                 prompt += f"""
+ File path: {path}
+ File content:
+ {content}
+ ========================================
+ """
+             except Exception as e:
+                 PrettyOutput.print(f"Failed to read file {path}: {str(e)}",
+                                    output_type=OutputType.ERROR)
+                 continue
+
+         prompt += f"""
+ User question: {query}
+
+ Please answer the user's question in Chinese using professional language. If the provided file content is insufficient to answer the user's question, please inform the user. Never make up information.
+ """
+         model = PlatformRegistry.get_global_platform_registry().get_codegen_platform()
+         response = model.chat_until_success(prompt)
+         return response
+
+     def is_index_generated(self) -> bool:
+         """Check if the index has been generated"""
+         try:
+             # 1. Check basic preconditions
+             if not self.vector_cache or not self.file_paths:
+                 return False
+
+             if not hasattr(self, 'index') or self.index is None:
+                 return False
+
+             # 2. Check that the index is usable
+             # Create a test vector
+             test_vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
+             try:
+                 self.index.search(test_vector, 1)  # type: ignore
+             except Exception:
+                 return False
+
+             # 3. Verify consistency between the vector cache and the file paths
+             if len(self.vector_cache) != len(self.file_paths):
+                 return False
+
+             # 4. Verify all cache files
+             for file_path in self.file_paths:
+                 if file_path not in self.vector_cache:
+                     return False
+
+                 cache_path = self._get_cache_path(file_path)
+                 if not os.path.exists(cache_path):
+                     return False
+
+                 cache_data = self.vector_cache[file_path]
+                 if not isinstance(cache_data.get("vector"), np.ndarray):
+                     return False
+
+             return True
+
+         except Exception as e:
+             PrettyOutput.print(f"Error checking index status: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return False
+
+
+ def main():
+     parser = argparse.ArgumentParser(description='Codebase management and search tool')
+     subparsers = parser.add_subparsers(dest='command', help='Available commands')
+
+     # Generate command
+     generate_parser = subparsers.add_parser('generate', help='Generate codebase index')
+     generate_parser.add_argument('--force', action='store_true', help='Force rebuild index')
+
+     # Search command
+     search_parser = subparsers.add_parser('search', help='Search similar code files')
+     search_parser.add_argument('query', type=str, help='Search query')
+     search_parser.add_argument('--top-k', type=int, default=20, help='Number of results to return (default: 20)')
+
+     # Ask command
+     ask_parser = subparsers.add_parser('ask', help='Ask a question about the codebase')
+     ask_parser.add_argument('question', type=str, help='Question to ask')
+     ask_parser.add_argument('--top-k', type=int, default=20, help='Number of results to use (default: 20)')
+
+     # Export command
+     export_parser = subparsers.add_parser('export', help='Export current index data')
+
+     args = parser.parse_args()
+
+     current_dir = find_git_root()
+     codebase = CodeBase(current_dir)
+
+     if args.command == 'export':
+         codebase.export()
+         return
+
+     # If the index has not been generated and this is not the generate command, ask the user to generate it first
+     if not codebase.is_index_generated() and args.command != 'generate':
+         PrettyOutput.print("The index has not been generated yet; run the 'generate' command first", output_type=OutputType.WARNING)
+         return
+
+     if args.command == 'generate':
+         try:
+             codebase.generate_codebase(force=args.force)
+             PrettyOutput.print("Codebase generation completed", output_type=OutputType.SUCCESS)
+         except Exception as e:
+             PrettyOutput.print(f"Error during codebase generation: {str(e)}", output_type=OutputType.ERROR)
+
+     elif args.command == 'search':
+         results = codebase.search_similar(args.query, args.top_k)
+         if not results:
+             PrettyOutput.print("No similar files found", output_type=OutputType.WARNING)
+             return
+
+         output = "Search Results:\n"
+         for path in results:
+             output += f"- {path}\n"
+         PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")
+
+     elif args.command == 'ask':
+         response = codebase.ask_codebase(args.question, args.top_k)
+         output = f"Answer:\n{response}"
+         PrettyOutput.print(output, output_type=OutputType.INFO)
+
+     else:
+         parser.print_help()
+
+
+ if __name__ == "__main__":
+     exit(main())
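Putting it together, the class can be driven through this CLI (e.g. python -m jarvis.jarvis_codebase.main generate --force, assuming the module is invoked directly; the wheel's entry_points.txt also registers console scripts, whose names are not shown in this diff) or programmatically. A sketch of the programmatic path, using only the methods defined above:

from jarvis.jarvis_codebase.main import CodeBase
from jarvis.utils import find_git_root

codebase = CodeBase(find_git_root())
codebase.generate_codebase(force=True)  # build or refresh the index without prompting
for path in codebase.search_similar("faiss index construction", top_k=10):
    print(path)
print(codebase.ask_codebase("Where is the vector cache invalidated?"))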