jarvis-ai-assistant 0.1.100__py3-none-any.whl → 0.1.102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of jarvis-ai-assistant might be problematic.

@@ -1,875 +0,0 @@
-import hashlib
-import os
-import numpy as np
-import faiss
-from typing import List, Tuple, Optional, Dict
-
-import yaml
-from jarvis.models.registry import PlatformRegistry
-import concurrent.futures
-from threading import Lock
-from concurrent.futures import ThreadPoolExecutor
-from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_file_md5, get_max_context_length, get_single_line_input, get_thread_count, load_embedding_model, load_rerank_model
-from jarvis.utils import init_env
-import argparse
-import pickle
-import lzma  # used for compressed cache files
-from tqdm import tqdm
-import re
-
-class CodeBase:
-    def __init__(self, root_dir: str):
-        init_env()
-        self.root_dir = root_dir
-        os.chdir(self.root_dir)
-        self.thread_count = get_thread_count()
-        self.max_context_length = get_max_context_length()
-        self.index = None
-
-        # Initialize the data directory
-        self.data_dir = os.path.join(self.root_dir, ".jarvis/codebase")
-        self.cache_dir = os.path.join(self.data_dir, "cache")
-        if not os.path.exists(self.cache_dir):
-            os.makedirs(self.cache_dir)
-
-        # Initialize the embedding model
-        try:
-            self.embedding_model = load_embedding_model()
-            test_text = """This is a test text"""
-            self.embedding_model.encode([test_text],
-                                        convert_to_tensor=True,
-                                        normalize_embeddings=True)
-            PrettyOutput.print("Model loaded successfully", output_type=OutputType.SUCCESS)
-        except Exception as e:
-            PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
-            raise
-
-        self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
-        self.git_file_list = self.get_git_file_list()
-        self.platform_registry = PlatformRegistry.get_global_platform_registry()
-
-        # Initialize the cache and index
-        self.vector_cache = {}
-        self.file_paths = []
-
-        # Load all cache files
-        self._load_all_cache()
-
-    def get_git_file_list(self):
-        """Get the list of files in the git repository, excluding Jarvis's own data files"""
-        files = os.popen("git ls-files").read().splitlines()
-        # Filter out Jarvis data files
-        return [f for f in files if not f.startswith(".jarvis-")]
-
-    def is_text_file(self, file_path: str):
-        with open(file_path, "r", encoding="utf-8") as f:
-            try:
-                f.read()
-                return True
-            except UnicodeDecodeError:
-                return False
-
-    def make_description(self, file_path: str, content: str) -> str:
-        model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
-        if self.thread_count > 1:
-            model.set_suppress_output(True)
-        else:
-            PrettyOutput.print(f"Make description for {file_path} ...", output_type=OutputType.PROGRESS)
-        prompt = f"""Please analyze the following code file and generate a detailed description. The description should include:
-1. Overall file functionality description
-2. A description of each global variable, function, type definition, class, method, and other code elements
-
-Please use concise and professional language, emphasizing technical functionality to facilitate subsequent code retrieval.
-File path: {file_path}
-Code content:
-{content}
-"""
-        response = model.chat_until_success(prompt)
-        return response
-
-    def export(self):
-        """Export the current index data to standard output"""
-        for file_path, data in self.vector_cache.items():
-            print(f"## {file_path}")
-            print(f"- path: {file_path}")
-            print(f"- description: {data['description']}")
-
-    def _get_cache_path(self, file_path: str) -> str:
-        """Get cache file path for a source file
-
-        Args:
-            file_path: Source file path
-
-        Returns:
-            str: Cache file path
-        """
-        # Normalize the file path:
-        # 1. Remove a leading ./ or /
-        # 2. Replace / with --
-        # 3. Append a .cache suffix
-        clean_path = re.sub(r'^(\./|/)+', '', file_path)  # lstrip('./') would strip characters, not the prefix
-        cache_name = clean_path.replace('/', '--') + '.cache'
-        return os.path.join(self.cache_dir, cache_name)
-
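As a sanity check on the path mangling, here is a standalone copy of the same logic; `to_cache_name` is a hypothetical helper written for illustration, not part of the package:

```python
import re

def to_cache_name(file_path: str) -> str:
    # Mirror of _get_cache_path's mangling: strip a leading ./ or /,
    # flatten directory separators into --, append the .cache suffix.
    clean = re.sub(r'^(\./|/)+', '', file_path)
    return clean.replace('/', '--') + '.cache'

assert to_cache_name("./src/utils/io.py") == "src--utils--io.py.cache"
```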
-    def _load_all_cache(self):
-        """Load all cache files"""
-        try:
-            # Clear the existing cache and file path list
-            self.vector_cache = {}
-            self.file_paths = []
-            vectors = []
-
-            for cache_file in os.listdir(self.cache_dir):
-                if not cache_file.endswith('.cache'):
-                    continue
-
-                cache_path = os.path.join(self.cache_dir, cache_file)
-                try:
-                    with lzma.open(cache_path, 'rb') as f:
-                        cache_data = pickle.load(f)
-                    file_path = cache_data["path"]
-                    self.vector_cache[file_path] = cache_data
-                    self.file_paths.append(file_path)
-                    vectors.append(cache_data["vector"])
-                except Exception as e:
-                    PrettyOutput.print(f"Failed to load cache file {cache_file}: {str(e)}",
-                                       output_type=OutputType.WARNING)
-                    continue
-
-            if vectors:
-                # Rebuild the index
-                vectors_array = np.vstack(vectors)
-                hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
-                hnsw_index.hnsw.efConstruction = 40
-                hnsw_index.hnsw.efSearch = 16
-                self.index = faiss.IndexIDMap(hnsw_index)
-                self.index.add_with_ids(vectors_array, np.arange(len(vectors)))  # type: ignore
-
-                PrettyOutput.print(f"Loaded {len(self.vector_cache)} cached vectors and rebuilt the index",
-                                   output_type=OutputType.INFO)
-            else:
-                self.index = None
-                PrettyOutput.print("No valid cache files found", output_type=OutputType.WARNING)
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to load cache directory: {str(e)}",
-                               output_type=OutputType.WARNING)
-            self.vector_cache = {}
-            self.file_paths = []
-            self.index = None
-
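For context, the index construction above follows the standard FAISS pattern of wrapping an HNSW graph in an IndexIDMap so entries carry stable integer ids (IndexHNSWFlat alone does not support add_with_ids). A minimal sketch with made-up dimensions and data:

```python
import faiss
import numpy as np

dim = 8
vectors = np.random.rand(5, dim).astype(np.float32)

hnsw = faiss.IndexHNSWFlat(dim, 16)    # 16 = graph neighbors per node (M)
hnsw.hnsw.efConstruction = 40          # build-time search breadth
hnsw.hnsw.efSearch = 16                # query-time search breadth
index = faiss.IndexIDMap(hnsw)
index.add_with_ids(vectors, np.arange(len(vectors)))

distances, ids = index.search(vectors[:1], 3)  # 3 nearest neighbors of vector 0
print(ids[0], distances[0])
```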
-    def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
-        """Cache the vector representation of a file"""
-        try:
-            with open(file_path, "rb") as f:
-                file_md5 = hashlib.md5(f.read()).hexdigest()
-        except Exception as e:
-            PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            file_md5 = ""
-
-        # Prepare the cache data
-        cache_data = {
-            "path": file_path,            # file path
-            "md5": file_md5,              # file content MD5
-            "description": description,   # file description
-            "vector": vector              # embedding vector
-        }
-
-        # Update the in-memory cache
-        self.vector_cache[file_path] = cache_data
-
-        # Save to a dedicated cache file
-        cache_path = self._get_cache_path(file_path)
-        try:
-            with lzma.open(cache_path, 'wb') as f:
-                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
-        except Exception as e:
-            PrettyOutput.print(f"Failed to save cache for {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-
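The on-disk format is just an lzma-compressed pickle of that dict; a self-contained round trip (file name and values are illustrative):

```python
import lzma
import pickle
import numpy as np

cache_data = {
    "path": "src/app.py",              # illustrative entry
    "md5": "0" * 32,
    "description": "demo description",
    "vector": np.zeros(4, dtype=np.float32),
}

with lzma.open("/tmp/src--app.py.cache", "wb") as f:
    pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)

with lzma.open("/tmp/src--app.py.cache", "rb") as f:
    restored = pickle.load(f)

assert restored["md5"] == cache_data["md5"]
```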
-    def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
-        """Get the vector representation of a file from the cache"""
-        if file_path not in self.vector_cache:
-            return None
-
-        # Check if the file has been modified
-        try:
-            with open(file_path, "rb") as f:
-                current_md5 = hashlib.md5(f.read()).hexdigest()
-        except Exception as e:
-            PrettyOutput.print(f"Failed to calculate MD5 for {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return None
-
-        cached_data = self.vector_cache[file_path]
-        if cached_data["md5"] != current_md5:
-            return None
-
-        # Check if the description has changed
-        if cached_data["description"] != description:
-            return None
-
-        return cached_data["vector"]
-
-    def get_embedding(self, text: str) -> np.ndarray:
-        """Use the transformers model to get the vector representation of text"""
-        # Truncate long text (by whitespace-separated words)
-        max_length = 512  # or another suitable length
-        text = ' '.join(text.split()[:max_length])
-
-        # Get the embedding vector
-        embedding = self.embedding_model.encode(text,
-                                                normalize_embeddings=True,  # L2 normalization
-                                                show_progress_bar=False)
-        vector = np.array(embedding, dtype=np.float32)
-        return vector
-
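load_embedding_model() is not shown in this diff; assuming it returns a sentence-transformers model, the call above behaves like this standalone sketch (the model name here is an arbitrary assumption):

```python
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("all-MiniLM-L6-v2")  # stand-in model
embedding = model.encode("File path: src/app.py Description: demo",
                         normalize_embeddings=True,
                         show_progress_bar=False)
vector = np.asarray(embedding, dtype=np.float32)
print(vector.shape, float(np.linalg.norm(vector)))  # norm ~1.0 after L2 normalization
```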
-    def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
-        """Vectorize the file content and description"""
-        try:
-            # Try to get the vector from the cache first
-            cached_vector = self.get_cached_vector(file_path, description)
-            if cached_vector is not None:
-                return cached_vector
-
-            # Read the file content, limited to the maximum context length
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()[:self.max_context_length]
-
-            # Combine the file information, including the file content
-            combined_text = f"""
-File path: {file_path}
-Description: {description}
-Content: {content}
-"""
-            vector = self.get_embedding(combined_text)
-
-            # Save to cache
-            self.cache_vector(file_path, vector, description)
-            return vector
-        except Exception as e:
-            PrettyOutput.print(f"Error vectorizing file {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return np.zeros(self.vector_dim, dtype=np.float32)  # type: ignore
-
-    def clean_cache(self) -> bool:
-        """Clean expired cache records"""
-        try:
-            files_to_delete = []
-            for file_path in list(self.vector_cache.keys()):
-                if not os.path.exists(file_path):
-                    files_to_delete.append(file_path)
-                    cache_path = self._get_cache_path(file_path)
-                    try:
-                        os.remove(cache_path)
-                    except Exception:
-                        pass
-
-            for file_path in files_to_delete:
-                del self.vector_cache[file_path]
-                if file_path in self.file_paths:
-                    self.file_paths.remove(file_path)
-
-            return bool(files_to_delete)
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to clean cache: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return False
-
-    def process_file(self, file_path: str):
-        """Process a single file"""
-        try:
-            # Skip non-existent files
-            if not os.path.exists(file_path):
-                return None
-
-            if not self.is_text_file(file_path):
-                return None
-
-            md5 = get_file_md5(file_path)
-
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
-
-            # Skip files that have already been processed and are unchanged
-            if file_path in self.vector_cache:
-                if self.vector_cache[file_path].get("md5") == md5:
-                    return None
-
-            description = self.make_description(file_path, content)
-            vector = self.vectorize_file(file_path, description)
-
-            # Save to the cache, keyed by the actual file path
-            # ("path" is included so the entry matches the on-disk format)
-            self.vector_cache[file_path] = {
-                "path": file_path,
-                "vector": vector,
-                "description": description,
-                "md5": md5
-            }
-
-            return file_path
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return None
-
-    def build_index(self):
-        """Build a faiss index from the vector cache"""
-        try:
-            if not self.vector_cache:
-                self.index = None
-                return
-
-            # Create the underlying HNSW index
-            hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
-            hnsw_index.hnsw.efConstruction = 40
-            hnsw_index.hnsw.efSearch = 16
-
-            # Wrap the HNSW index with IndexIDMap
-            self.index = faiss.IndexIDMap(hnsw_index)
-
-            vectors = []
-            ids = []
-            self.file_paths = []  # Reset the file path list
-
-            for file_path, data in self.vector_cache.items():
-                if "vector" not in data:
-                    PrettyOutput.print(f"Invalid cache data for {file_path}: missing vector",
-                                       output_type=OutputType.WARNING)
-                    continue
-
-                vector = data["vector"]
-                if not isinstance(vector, np.ndarray):
-                    PrettyOutput.print(f"Invalid vector type for {file_path}: {type(vector)}",
-                                       output_type=OutputType.WARNING)
-                    continue
-
-                vectors.append(vector.reshape(1, -1))
-                ids.append(len(self.file_paths))  # id must match the position in self.file_paths
-                self.file_paths.append(file_path)
-
-            if vectors:
-                vectors = np.vstack(vectors)
-                if len(vectors) != len(ids):
-                    PrettyOutput.print(f"Vector count mismatch: {len(vectors)} vectors vs {len(ids)} ids",
-                                       output_type=OutputType.ERROR)
-                    self.index = None
-                    return
-
-                try:
-                    self.index.add_with_ids(vectors, np.array(ids))  # type: ignore
-                    PrettyOutput.print(f"Successfully built index with {len(vectors)} vectors",
-                                       output_type=OutputType.SUCCESS)
-                except Exception as e:
-                    PrettyOutput.print(f"Failed to add vectors to index: {str(e)}",
-                                       output_type=OutputType.ERROR)
-                    self.index = None
-            else:
-                PrettyOutput.print("No valid vectors found, index not built",
-                                   output_type=OutputType.WARNING)
-                self.index = None
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to build index: {str(e)}",
-                               output_type=OutputType.ERROR)
-            self.index = None
-
-    def gen_vector_db_from_cache(self):
-        """Generate the vector database from the cache"""
-        self.build_index()
-        self._load_all_cache()
-
-
-    def generate_codebase(self, force: bool = False):
-        """Generate the codebase index
-        Args:
-            force: Whether to force rebuild the index without asking the user
-        """
-        try:
-            # Update the git file list
-            self.git_file_list = self.get_git_file_list()
-
-            # Check for file changes
-            PrettyOutput.print("\nChecking file changes...", output_type=OutputType.INFO)
-            changes_detected = False
-            new_files = []
-            modified_files = []
-            deleted_files = []
-
-            # Check for deleted files
-            files_to_delete = []
-            for file_path in list(self.vector_cache.keys()):
-                if file_path not in self.git_file_list:
-                    deleted_files.append(file_path)
-                    files_to_delete.append(file_path)
-                    changes_detected = True
-
-            # Check for new and modified files
-            with tqdm(total=len(self.git_file_list), desc="Checking file status") as pbar:
-                for file_path in self.git_file_list:
-                    if not os.path.exists(file_path) or not self.is_text_file(file_path):
-                        pbar.update(1)
-                        continue
-
-                    try:
-                        current_md5 = get_file_md5(file_path)
-
-                        if file_path not in self.vector_cache:
-                            new_files.append(file_path)
-                            changes_detected = True
-                        elif self.vector_cache[file_path].get("md5") != current_md5:
-                            modified_files.append(file_path)
-                            changes_detected = True
-                    except Exception as e:
-                        PrettyOutput.print(f"Failed to check file {file_path}: {str(e)}",
-                                           output_type=OutputType.ERROR)
-                    pbar.update(1)
-
-            # If changes are detected, display them and ask the user
-            if changes_detected:
-                PrettyOutput.print("\nDetected the following changes:", output_type=OutputType.WARNING)
-                if new_files:
-                    PrettyOutput.print("\nNew files:", output_type=OutputType.INFO)
-                    for f in new_files:
-                        PrettyOutput.print(f"  {f}", output_type=OutputType.INFO)
-                if modified_files:
-                    PrettyOutput.print("\nModified files:", output_type=OutputType.INFO)
-                    for f in modified_files:
-                        PrettyOutput.print(f"  {f}", output_type=OutputType.INFO)
-                if deleted_files:
-                    PrettyOutput.print("\nDeleted files:", output_type=OutputType.INFO)
-                    for f in deleted_files:
-                        PrettyOutput.print(f"  {f}", output_type=OutputType.INFO)
-
-                # If force is True, continue directly
-                if not force:
-                    # Ask the user whether to continue
-                    while True:
-                        response = get_single_line_input("\nRebuild the index? [y/N]").lower().strip()
-                        if response in ['y', 'yes']:
-                            break
-                        elif response in ['', 'n', 'no']:
-                            PrettyOutput.print("Cancelled index rebuild", output_type=OutputType.INFO)
-                            return
-                        else:
-                            PrettyOutput.print("Please enter y or n", output_type=OutputType.WARNING)
-
-                # Clean up deleted files
-                for file_path in files_to_delete:
-                    del self.vector_cache[file_path]
-                if files_to_delete:
-                    PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} files",
-                                       output_type=OutputType.INFO)
-
-                # Process new and modified files
-                files_to_process = new_files + modified_files
-                processed_files = []
-
-                with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
-                    # Use a thread pool to process files
-                    with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
-                        # Submit all tasks
-                        future_to_file = {
-                            executor.submit(self.process_file, file): file
-                            for file in files_to_process
-                        }
-
-                        # Collect completed tasks
-                        for future in concurrent.futures.as_completed(future_to_file):
-                            file = future_to_file[future]
-                            try:
-                                result = future.result()
-                                if result:
-                                    processed_files.append(result)
-                            except Exception as e:
-                                PrettyOutput.print(f"Failed to process file {file}: {str(e)}",
-                                                   output_type=OutputType.ERROR)
-                            pbar.update(1)
-
-                if processed_files:
-                    PrettyOutput.print("\nRebuilding the vector database...", output_type=OutputType.INFO)
-                    self.gen_vector_db_from_cache()
-                    PrettyOutput.print(f"Successfully generated the index for {len(processed_files)} files",
-                                       output_type=OutputType.SUCCESS)
-            else:
-                PrettyOutput.print("No file changes detected, no need to rebuild the index", output_type=OutputType.INFO)
-
-        except Exception as e:
-            # Try to reload the cache so the in-memory state stays consistent
-            try:
-                self._load_all_cache()
-            except Exception as reload_error:
-                PrettyOutput.print(f"Failed to reload cache: {str(reload_error)}",
-                                   output_type=OutputType.ERROR)
-            raise e  # Re-raise the original exception
-
-
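The fan-out/fan-in used above is the usual ThreadPoolExecutor pattern: submit everything, then consume futures as they complete, ticking the progress bar once per file. Reduced to a skeleton with a stand-in worker:

```python
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def process(item):  # stand-in for CodeBase.process_file
    return item.upper()

items = ["a.py", "b.py", "c.py"]
processed = []
with tqdm(total=len(items), desc="Processing files") as pbar:
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_item = {executor.submit(process, item): item for item in items}
        for future in concurrent.futures.as_completed(future_to_item):
            try:
                result = future.result()
                if result:
                    processed.append(result)
            except Exception as e:
                print(f"failed: {future_to_item[future]}: {e}")
            pbar.update(1)
```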
-    def _text_search_score(self, content: str, keywords: List[str]) -> float:
-        """Calculate the matching score between the text content and the keywords
-
-        Args:
-            content: Text content
-            keywords: List of keywords
-
-        Returns:
-            float: Matching score (0-1)
-        """
-        if not keywords:
-            return 0.0
-
-        content = content.lower()
-        matched_keywords = set()
-
-        for keyword in keywords:
-            keyword = keyword.lower()
-            if keyword in content:
-                matched_keywords.add(keyword)
-
-        # Calculate the matching score
-        score = len(matched_keywords) / len(keywords)
-        return score
-
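A worked example of the score: with three keywords and two of them present in the content, the result is 2/3.

```python
content = "Builds a FAISS index from the vector cache"
keywords = ["index", "faiss", "embedding"]
matched = {kw.lower() for kw in keywords if kw.lower() in content.lower()}
print(len(matched) / len(keywords))  # 0.666...
```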
-    def pick_results(self, query: str, initial_results: List[str]) -> List[str]:
-        """Use a large model to pick among the search results
-
-        Args:
-            query: Search query
-            initial_results: Initial list of file paths
-
-        Returns:
-            List[str]: The picked results, each item a file path
-        """
-        if not initial_results:
-            return []
-
-        try:
-            PrettyOutput.print(f"Picking results for query: {query}", output_type=OutputType.INFO)
-
-            # Maximum content length per batch
-            max_batch_length = self.max_context_length - 1000  # Reserve space for the prompt
-            max_file_length = max_batch_length // 3  # Limit individual file size
-
-            # Process files in batches
-            all_selected_files = set()
-            current_batch = []
-            current_length = 0
-
-            for path in initial_results:
-                try:
-                    with open(path, "r", encoding="utf-8") as f:
-                        content = f.read()
-                    # Truncate large files
-                    if len(content) > max_file_length:
-                        PrettyOutput.print(f"Truncating large file: {path}", OutputType.WARNING)
-                        content = content[:max_file_length] + "\n... (content truncated)"
-
-                    file_info = f"File: {path}\nContent: {content}\n\n"
-                    file_length = len(file_info)
-
-                    # If adding this file would exceed the batch limit
-                    if current_length + file_length > max_batch_length:
-                        # Process the current batch
-                        if current_batch:
-                            selected = self._process_batch(query, current_batch)
-                            all_selected_files.update(selected)
-                        # Start a new batch
-                        current_batch = [file_info]
-                        current_length = file_length
-                    else:
-                        current_batch.append(file_info)
-                        current_length += file_length
-
-                except Exception as e:
-                    PrettyOutput.print(f"Failed to read file {path}: {str(e)}", OutputType.ERROR)
-                    continue
-
-            # Process the final batch
-            if current_batch:
-                selected = self._process_batch(query, current_batch)
-                all_selected_files.update(selected)
-
-            # Convert the set to a list, preserving the original order
-            final_results = [path for path in initial_results if path in all_selected_files]
-            return final_results
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to pick: {str(e)}", OutputType.ERROR)
-            return initial_results
-
-    def _process_batch(self, query: str, files_info: List[str]) -> List[str]:
-        """Process a batch of files
-
-        Args:
-            query: Search query
-            files_info: List of file information strings
-
-        Returns:
-            List[str]: Selected file paths from this batch
-        """
-        prompt = f"""Please analyze the following code files and determine which files are most relevant to the given query. Consider file paths and code content to make your judgment.
-
-Query: {query}
-
-Available files:
-{''.join(files_info)}
-
-Please output a YAML list of relevant file paths, ordered by relevance (most relevant first). Only include files that are truly relevant to the query.
-Output format:
-<FILES>
-- path/to/file1.py
-- path/to/file2.py
-</FILES>
-
-Note: Only include files that have a strong connection to the query."""
-
-        # Use a large model to evaluate
-        model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
-        response = model.chat_until_success(prompt)
-
-        # Parse the response (yaml is already imported at module level)
-        files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
-        if not files_match:
-            return []
-
-        # Extract the file list
-        try:
-            selected_files = yaml.safe_load(files_match.group(1))
-            return selected_files if selected_files else []
-        except Exception as e:
-            PrettyOutput.print(f"Failed to parse response: {str(e)}", OutputType.ERROR)
-            return []
-
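The <FILES> block is extracted with a non-greedy regex and then parsed as YAML; the same steps in isolation (the model response below is fabricated):

```python
import re
import yaml

response = """Reasoning about the query...
<FILES>
- src/app.py
- src/utils/io.py
</FILES>"""

match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
selected = yaml.safe_load(match.group(1)) if match else []
print(selected)  # ['src/app.py', 'src/utils/io.py']
```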
-    def _generate_query_variants(self, query: str) -> List[str]:
-        """Generate alternative phrasings of the query
-
-        Args:
-            query: Original query
-
-        Returns:
-            List[str]: The query variants
-        """
-        model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
-        prompt = f"""Please generate 3 different expressions of the following query; each should fully convey the meaning of the original. These expressions will be used for code search, so keep them professional and accurate.
-Original query: {query}
-
-Please output the 3 expressions directly, separated by two line breaks, without numbering or other markers.
-"""
-        variants = model.chat_until_success(prompt).strip().split('\n\n')
-        variants.append(query)  # Add the original query
-        return variants
-
-    def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
-        """Use vector search to find related files
-
-        Args:
-            query_variants: The query variants list
-            top_k: The number of results to return
-
-        Returns:
-            Dict[str, Tuple[str, float, str]]: Mapping from file path to (file path, score, description)
-        """
-        results = {}
-        for query in query_variants:
-            query_vector = self.get_embedding(query)
-            query_vector = query_vector.reshape(1, -1)
-
-            distances, indices = self.index.search(query_vector, top_k)  # type: ignore
-
-            for i, distance in zip(indices[0], distances[0]):
-                if i == -1:
-                    continue
-
-                similarity = 1.0 / (1.0 + float(distance))
-                file_path = self.file_paths[i]
-                # Keep the highest similarity score across query variants
-                if similarity > 0.5:
-                    previous = results.get(file_path)
-                    if previous is None or similarity > previous[1]:
-                        data = self.vector_cache[file_path]
-                        results[file_path] = (file_path, similarity, data["description"])
-
-        return results
-
-
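Because the embeddings are L2-normalized, the squared L2 distance FAISS reports equals 2·(1 − cosine similarity), so 1 / (1 + distance) maps identical vectors to 1.0 and decays monotonically as vectors diverge. A two-dimensional check:

```python
import numpy as np

a = np.array([1.0, 0.0], dtype=np.float32)   # unit-norm vectors
b = np.array([0.6, 0.8], dtype=np.float32)

distance = float(np.sum((a - b) ** 2))       # 2 * (1 - 0.6) = 0.8
print(1.0 / (1.0 + distance))                # ~0.556
```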
-    def search_similar(self, query: str, top_k: int = 30) -> List[str]:
-        """Search for related files"""
-        try:
-            if self.index is None:
-                return []
-            # Generate the query variants
-            query_variants = self._generate_query_variants(query)
-
-            # Perform the vector search
-            vector_results = self._vector_search(query_variants, top_k)
-
-            results = list(vector_results.values())
-            results.sort(key=lambda x: x[1], reverse=True)
-
-            # Keep the top_k highest-scoring results for re-ranking
-            initial_results = results[:top_k]
-
-            # If no results are found, return directly
-            if not initial_results:
-                return []
-
-            # Filter out low-scoring results
-            initial_results = [(path, score, desc) for path, score, desc in initial_results if score >= 0.5]
-
-            for path, score, desc in initial_results:
-                PrettyOutput.print(f"File: {path} Similarity: {score:.3f}", output_type=OutputType.INFO)
-
-            # Re-rank the preliminary results
-            return self.pick_results(query, [path for path, _, _ in initial_results])
-
-        except Exception as e:
-            PrettyOutput.print(f"Failed to search: {str(e)}", output_type=OutputType.ERROR)
-            return []
-
-    def ask_codebase(self, query: str, top_k: int = 20) -> str:
-        """Query the codebase"""
-        results = self.search_similar(query, top_k)
-        if not results:
-            PrettyOutput.print("No related files found", output_type=OutputType.WARNING)
-            return ""
-
-        PrettyOutput.print("Found related files:", output_type=OutputType.SUCCESS)
-        for path in results:
-            PrettyOutput.print(f"File: {path}",
-                               output_type=OutputType.INFO)
-
-        prompt = """You are a code expert. Please answer the user's question based on the following file information:
-"""
-        for path in results:
-            try:
-                if len(prompt) > self.max_context_length:
-                    PrettyOutput.print(f"Avoiding context overflow, discarding low-relevance file: {path}", OutputType.WARNING)
-                    continue
-                with open(path, "r", encoding="utf-8") as f:
-                    content = f.read()
-                prompt += f"""
-File path: {path}
-File content:
-{content}
-========================================
-"""
-            except Exception as e:
-                PrettyOutput.print(f"Failed to read file {path}: {str(e)}",
-                                   output_type=OutputType.ERROR)
-                continue
-
-        prompt += f"""
-User question: {query}
-
-Please answer the user's question in Chinese using professional language. If the provided file content is insufficient to answer the user's question, please inform the user. Never make up information.
-"""
-        model = PlatformRegistry.get_global_platform_registry().get_codegen_platform()
-        response = model.chat_until_success(prompt)
-        return response
-
-    def is_index_generated(self) -> bool:
-        """Check if the index has been generated"""
-        try:
-            # 1. Check the basic preconditions
-            if not self.vector_cache or not self.file_paths:
-                return False
-
-            if not hasattr(self, 'index') or self.index is None:
-                return False
-
-            # 2. Check that the index is usable
-            # Probe it with a test vector
-            test_vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
-            try:
-                self.index.search(test_vector, 1)  # type: ignore
-            except Exception:
-                return False
-
-            # 3. Verify that the vector cache and file path list are consistent
-            if len(self.vector_cache) != len(self.file_paths):
-                return False
-
-            # 4. Verify all cache files
-            for file_path in self.file_paths:
-                if file_path not in self.vector_cache:
-                    return False
-
-                cache_path = self._get_cache_path(file_path)
-                if not os.path.exists(cache_path):
-                    return False
-
-                cache_data = self.vector_cache[file_path]
-                if not isinstance(cache_data.get("vector"), np.ndarray):
-                    return False
-
-            return True
-
-        except Exception as e:
-            PrettyOutput.print(f"Error checking index status: {str(e)}",
-                               output_type=OutputType.ERROR)
-            return False
-
-
-def main():
-    parser = argparse.ArgumentParser(description='Codebase management and search tool')
-    subparsers = parser.add_subparsers(dest='command', help='Available commands')
-
-    # Generate command
-    generate_parser = subparsers.add_parser('generate', help='Generate codebase index')
-    generate_parser.add_argument('--force', action='store_true', help='Force rebuild index')
-
-    # Search command
-    search_parser = subparsers.add_parser('search', help='Search similar code files')
-    search_parser.add_argument('query', type=str, help='Search query')
-    search_parser.add_argument('--top-k', type=int, default=20, help='Number of results to return (default: 20)')
-
-    # Ask command
-    ask_parser = subparsers.add_parser('ask', help='Ask a question about the codebase')
-    ask_parser.add_argument('question', type=str, help='Question to ask')
-    ask_parser.add_argument('--top-k', type=int, default=20, help='Number of results to use (default: 20)')
-
-    # Export command
-    subparsers.add_parser('export', help='Export current index data')
-    args = parser.parse_args()
-
-    current_dir = find_git_root()
-    codebase = CodeBase(current_dir)
-
-    if args.command == 'export':
-        codebase.export()
-        return
-
-    # If the index has not been generated and this is not the generate command, tell the user to generate it first
-    if not codebase.is_index_generated() and args.command != 'generate':
-        PrettyOutput.print("The index has not been generated yet; run the 'generate' command first", output_type=OutputType.WARNING)
-        return
-
-    if args.command == 'generate':
-        try:
-            codebase.generate_codebase(force=args.force)
-            PrettyOutput.print("\nCodebase generation completed", output_type=OutputType.SUCCESS)
-        except Exception as e:
-            PrettyOutput.print(f"Error during codebase generation: {str(e)}", output_type=OutputType.ERROR)
-
-    elif args.command == 'search':
-        results = codebase.search_similar(args.query, args.top_k)
-        if not results:
-            PrettyOutput.print("No similar files found", output_type=OutputType.WARNING)
-            return
-
-        PrettyOutput.print("\nSearch Results:", output_type=OutputType.INFO)
-        for path in results:
-            PrettyOutput.print("\n" + "="*50, output_type=OutputType.INFO)
-            PrettyOutput.print(f"File: {path}", output_type=OutputType.INFO)
-
-    elif args.command == 'ask':
-        response = codebase.ask_codebase(args.question, args.top_k)
-        PrettyOutput.print("\nAnswer:", output_type=OutputType.INFO)
-        PrettyOutput.print(response, output_type=OutputType.INFO)
-
-    else:
-        parser.print_help()
-
-
-if __name__ == "__main__":
-    exit(main())
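For reference, a programmatic equivalent of the CLI flow; the import path is a guess, since this diff does not show the module's location within the package:

```python
from jarvis.jarvis_codebase import CodeBase  # hypothetical import path

codebase = CodeBase("/path/to/git/repo")     # must be a git repository
codebase.generate_codebase(force=True)       # index without the interactive prompt
for path in codebase.search_similar("where is the FAISS index built", top_k=10):
    print(path)
print(codebase.ask_codebase("How is the vector cache invalidated?"))
```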