jarvis-ai-assistant 0.1.131__py3-none-any.whl → 0.1.132__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +48 -29
  3. jarvis/jarvis_agent/patch.py +61 -43
  4. jarvis/jarvis_agent/shell_input_handler.py +1 -1
  5. jarvis/jarvis_code_agent/code_agent.py +87 -86
  6. jarvis/jarvis_dev/main.py +335 -626
  7. jarvis/jarvis_git_squash/main.py +10 -31
  8. jarvis/jarvis_multi_agent/__init__.py +19 -28
  9. jarvis/jarvis_platform/ai8.py +7 -32
  10. jarvis/jarvis_platform/base.py +2 -7
  11. jarvis/jarvis_platform/kimi.py +3 -144
  12. jarvis/jarvis_platform/ollama.py +54 -68
  13. jarvis/jarvis_platform/openai.py +0 -4
  14. jarvis/jarvis_platform/oyi.py +0 -75
  15. jarvis/jarvis_platform/yuanbao.py +264 -0
  16. jarvis/jarvis_rag/file_processors.py +138 -0
  17. jarvis/jarvis_rag/main.py +1305 -425
  18. jarvis/jarvis_tools/ask_codebase.py +205 -39
  19. jarvis/jarvis_tools/code_review.py +125 -99
  20. jarvis/jarvis_tools/execute_python_script.py +58 -0
  21. jarvis/jarvis_tools/execute_shell.py +13 -26
  22. jarvis/jarvis_tools/execute_shell_script.py +1 -1
  23. jarvis/jarvis_tools/file_analyzer.py +271 -0
  24. jarvis/jarvis_tools/file_operation.py +1 -1
  25. jarvis/jarvis_tools/find_caller.py +213 -0
  26. jarvis/jarvis_tools/find_symbol.py +211 -0
  27. jarvis/jarvis_tools/function_analyzer.py +248 -0
  28. jarvis/jarvis_tools/git_commiter.py +4 -4
  29. jarvis/jarvis_tools/methodology.py +89 -48
  30. jarvis/jarvis_tools/project_analyzer.py +220 -0
  31. jarvis/jarvis_tools/read_code.py +23 -2
  32. jarvis/jarvis_tools/read_webpage.py +195 -81
  33. jarvis/jarvis_tools/registry.py +132 -11
  34. jarvis/jarvis_tools/search_web.py +55 -10
  35. jarvis/jarvis_tools/tool_generator.py +6 -8
  36. jarvis/jarvis_utils/__init__.py +1 -0
  37. jarvis/jarvis_utils/config.py +67 -3
  38. jarvis/jarvis_utils/embedding.py +344 -45
  39. jarvis/jarvis_utils/git_utils.py +9 -1
  40. jarvis/jarvis_utils/input.py +7 -6
  41. jarvis/jarvis_utils/methodology.py +379 -7
  42. jarvis/jarvis_utils/output.py +5 -3
  43. jarvis/jarvis_utils/utils.py +59 -7
  44. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/METADATA +3 -2
  45. jarvis_ai_assistant-0.1.132.dist-info/RECORD +82 -0
  46. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/entry_points.txt +2 -0
  47. jarvis/jarvis_codebase/__init__.py +0 -0
  48. jarvis/jarvis_codebase/main.py +0 -1011
  49. jarvis/jarvis_tools/treesitter_analyzer.py +0 -331
  50. jarvis/jarvis_treesitter/README.md +0 -104
  51. jarvis/jarvis_treesitter/__init__.py +0 -20
  52. jarvis/jarvis_treesitter/database.py +0 -258
  53. jarvis/jarvis_treesitter/example.py +0 -115
  54. jarvis/jarvis_treesitter/grammar_builder.py +0 -182
  55. jarvis/jarvis_treesitter/language.py +0 -117
  56. jarvis/jarvis_treesitter/symbol.py +0 -31
  57. jarvis/jarvis_treesitter/tools_usage.md +0 -121
  58. jarvis_ai_assistant-0.1.131.dist-info/RECORD +0 -85
  59. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/LICENSE +0 -0
  60. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/WHEEL +0 -0
  61. {jarvis_ai_assistant-0.1.131.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/top_level.txt +0 -0
jarvis/jarvis_codebase/main.py (deleted)
@@ -1,1011 +0,0 @@
- import hashlib
- import os
- import numpy as np
- import faiss
- from typing import List, Tuple, Optional, Dict
-
- from yaspin import yaspin
-
- from jarvis.jarvis_platform.registry import PlatformRegistry
- import concurrent.futures
- from concurrent.futures import ThreadPoolExecutor
- import argparse
- import pickle
- import lzma  # add the lzma import
- from tqdm import tqdm
- import re
-
- from jarvis.jarvis_utils.config import get_max_token_count, get_thread_count
- from jarvis.jarvis_utils.embedding import get_embedding, load_embedding_model, get_context_token_count
- from jarvis.jarvis_utils.git_utils import find_git_root
- from jarvis.jarvis_utils.output import OutputType, PrettyOutput
- from jarvis.jarvis_utils.utils import get_file_md5, init_env, user_confirm
-
- class CodeBase:
-     def __init__(self, root_dir: str):
-         with yaspin(text="Initializing environment...", color="cyan") as spinner:
-             init_env()
-             spinner.text = "Environment initialized"
-             spinner.ok("✅")
-
-         self.root_dir = root_dir
-         os.chdir(self.root_dir)
-         self.thread_count = get_thread_count()
-         self.max_token_count = get_max_token_count()
-         self.index = None
-
-         # Initialize the data directory
-         with yaspin(text="Initializing data directory...", color="cyan") as spinner:
-             self.data_dir = os.path.join(self.root_dir, ".jarvis/codebase")
-             self.cache_dir = os.path.join(self.data_dir, "cache")
-             if not os.path.exists(self.cache_dir):
-                 os.makedirs(self.cache_dir)
-             spinner.text = "Data directory initialized"
-             spinner.ok("✅")
-
-         with yaspin("Initializing embedding model...", color="cyan") as spinner:
-             # Initialize the embedding model
-             try:
-                 self.embedding_model = load_embedding_model()
-                 test_text = """This is a test text"""
-                 self.embedding_model.encode([test_text],
-                                             convert_to_tensor=True,
-                                             normalize_embeddings=True)
-                 spinner.text = "Embedding model initialized"
-                 spinner.ok("✅")
-             except Exception as e:
-                 spinner.text = "Embedding model initialization failed"
-                 spinner.fail("❌")
-                 raise
-
-         self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
-         self.git_file_list = self.get_git_file_list()
-         self.platform_registry = PlatformRegistry.get_global_platform_registry()
-
-         # Initialize the cache and index
-         self.vector_cache = {}
-         self.file_paths = []
-
-         # Load all cache files
-         with spinner.hidden():
-             self._load_all_cache()
-
-     def get_git_file_list(self):
-         """Get the list of files in the git repository, excluding the .jarvis-codebase directory"""
-         files = os.popen("git ls-files").read().splitlines()
-         # Filter out files in the .jarvis-codebase directory
-         return [f for f in files if not f.startswith(".jarvis")]
-
-     def is_text_file(self, file_path: str):
-         try:
-             open(file_path, "r", encoding="utf-8", errors="ignore").read()
-             return True
-         except Exception:
-             return False
-
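Note on `get_git_file_list` above: `os.popen` swallows errors, so a failed `git` invocation silently yields an empty file list. A minimal sketch (not part of the package) of the same lookup via `subprocess`, which surfaces failures instead:

```python
import subprocess
from typing import List

def list_git_files(repo_root: str) -> List[str]:
    """Hypothetical helper: list tracked files, raising on git failure."""
    result = subprocess.run(
        ["git", "ls-files"],
        cwd=repo_root,          # run inside the repository
        capture_output=True,
        text=True,
        check=True,             # raise CalledProcessError on a non-zero exit
    )
    # Mirror the original filter: drop anything under .jarvis*
    return [f for f in result.stdout.splitlines() if not f.startswith(".jarvis")]
```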
-     def make_description(self, file_path: str, content: str) -> str:
-         model = PlatformRegistry.get_global_platform_registry().get_cheap_platform()
-         prompt = f"""Please analyze the following code file and produce a detailed description. The description should include:
- 1. An overview of the file's overall functionality
- 2. A description of every global variable, function, type definition, class, method, and other code element
-
- Use concise, professional language and emphasize technical functionality, so the text supports later code retrieval.
- File path: {file_path}
- Code content:
- {content}
- """
-         response = model.chat_until_success(prompt)
-         return response
-
-     def export(self):
-         """Export the current index data to standard output"""
-         for file_path, data in self.vector_cache.items():
-             print(f"## {file_path}")
-             print(f"- path: {file_path}")
-             print(f"- description: {data['description']}")
-
-     def _get_cache_path(self, file_path: str) -> str:
-         """Get cache file path for a source file
-
-         Args:
-             file_path: Source file path
-
-         Returns:
-             str: Cache file path
-         """
-         # Normalize the file path:
-         # 1. Remove the leading ./ or /
-         # 2. Replace / with --
-         # 3. Append the .cache suffix
-         clean_path = file_path.lstrip('./').lstrip('/')
-         cache_name = clean_path.replace('/', '--') + '.cache'
-         return os.path.join(self.cache_dir, cache_name)
-
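One caveat in the path normalization above: `str.lstrip('./')` strips a leading run of the *characters* `.` and `/`, not the literal prefix `./`, so a path such as `..config/foo.py` would also lose its leading dots. A prefix-safe variant (a sketch with a hypothetical helper name, not code from the package):

```python
import os

def cache_path_for(cache_dir: str, file_path: str) -> str:
    """Hypothetical helper: map a source path to a flat cache file name."""
    clean_path = file_path
    if clean_path.startswith("./"):
        clean_path = clean_path[2:]      # remove the literal "./" prefix only
    clean_path = clean_path.lstrip("/")  # safe here: a single character set
    return os.path.join(cache_dir, clean_path.replace("/", "--") + ".cache")
```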
-     def _load_all_cache(self):
-         """Load all cache files"""
-         with yaspin(text="Loading cache files...", color="cyan") as spinner:
-             try:
-                 # Clear the existing cache and file paths
-                 self.vector_cache = {}
-                 self.file_paths = []
-                 vectors = []
-
-                 for cache_file in os.listdir(self.cache_dir):
-                     if not cache_file.endswith('.cache'):
-                         continue
-
-                     cache_path = os.path.join(self.cache_dir, cache_file)
-                     try:
-                         with lzma.open(cache_path, 'rb') as f:
-                             cache_data = pickle.load(f)
-                             file_path = cache_data["path"]
-                             self.vector_cache[file_path] = cache_data
-                             self.file_paths.append(file_path)
-                             vectors.append(cache_data["vector"])
-                             spinner.write(f"✅ Loaded cache file {file_path}")
-                     except Exception as e:
-                         spinner.write(f"❌ Failed to load cache file {cache_file} {str(e)}")
-                         continue
-
-                 if vectors:
-                     # Rebuild the index
-                     vectors_array = np.vstack(vectors)
-                     hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
-                     hnsw_index.hnsw.efConstruction = 40
-                     hnsw_index.hnsw.efSearch = 16
-                     self.index = faiss.IndexIDMap(hnsw_index)
-                     self.index.add_with_ids(vectors_array, np.array(range(len(vectors))))  # type: ignore
-
-                     spinner.text = f"Loaded {len(self.vector_cache)} cached vectors and rebuilt the index"
-                     spinner.ok("✅")
-                 else:
-                     self.index = None
-                     spinner.text = "No valid cache files found"
-                     spinner.ok("✅")
-
-             except Exception as e:
-                 spinner.text = f"Failed to load the cache directory: {str(e)}"
-                 spinner.fail("❌")
-                 self.vector_cache = {}
-                 self.file_paths = []
-                 self.index = None
-
-     def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
-         """Cache the vector representation of a file"""
-         try:
-             with open(file_path, "rb") as f:
-                 file_md5 = hashlib.md5(f.read()).hexdigest()
-         except Exception as e:
-             PrettyOutput.print(f"Failed to compute the MD5 of {file_path}: {str(e)}",
-                                output_type=OutputType.ERROR)
-             file_md5 = ""
-
-         # Prepare the cache data
-         cache_data = {
-             "path": file_path,  # the file path
-             "md5": file_md5,  # the file MD5
-             "description": description,  # the file description
-             "vector": vector  # the vector
-         }
-
-         # Update the in-memory cache
-         self.vector_cache[file_path] = cache_data
-
-         # Save to a separate cache file
-         cache_path = self._get_cache_path(file_path)
-         try:
-             with lzma.open(cache_path, 'wb') as f:
-                 pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
-         except Exception as e:
-             PrettyOutput.print(f"Failed to save the cache for {file_path}: {str(e)}",
-                                output_type=OutputType.ERROR)
-
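The two methods above define the on-disk cache format: one lzma-compressed pickle per source file. A standalone round-trip sketch under that assumption (the file name and values are illustrative; pickle should only ever be loaded from files you wrote yourself):

```python
import lzma
import pickle
import numpy as np

record = {
    "path": "src/example.py",
    "md5": "d41d8cd98f00b204e9800998ecf8427e",
    "description": "demo entry",
    "vector": np.zeros(384, dtype=np.float32),
}

# Write: pickle the record, compressing it with lzma.
with lzma.open("src--example.py.cache", "wb") as f:
    pickle.dump(record, f, protocol=pickle.HIGHEST_PROTOCOL)

# Read it back; only unpickle files you created yourself.
with lzma.open("src--example.py.cache", "rb") as f:
    restored = pickle.load(f)

assert restored["md5"] == record["md5"]
```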
-     def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
-         """Get the vector representation of a file from the cache"""
-         if file_path not in self.vector_cache:
-             return None
-
-         # Check if the file has been modified
-         try:
-             with open(file_path, "rb") as f:
-                 current_md5 = hashlib.md5(f.read()).hexdigest()
-         except Exception as e:
-             PrettyOutput.print(f"Failed to compute the MD5 of {file_path}: {str(e)}",
-                                output_type=OutputType.ERROR)
-             return None
-
-         cached_data = self.vector_cache[file_path]
-         if cached_data["md5"] != current_md5:
-             return None
-
-         # Check if the description has changed
-         if cached_data["description"] != description:
-             return None
-
-         return cached_data["vector"]
-
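The MD5 checks above read the whole file into memory before hashing. For large files a chunked read keeps memory use flat; a sketch with a hypothetical helper:

```python
import hashlib

def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Hypothetical helper: hash a file in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns b"" at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```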
-     def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
-         """Vectorize the file content and description"""
-         try:
-             # Try to get the vector from the cache first
-             cached_vector = self.get_cached_vector(file_path, description)
-             if cached_vector is not None:
-                 return cached_vector
-
-             # Read the file content and combine information
-             content = open(file_path, "r", encoding="utf-8", errors="ignore").read()[:self.max_token_count]  # Limit the file content length
-
-             # Combine file information, including file content
-             combined_text = f"""
- File path: {file_path}
- Description: {description}
- Content: {content}
- """
-             vector = get_embedding(self.embedding_model, combined_text)
-
-             # Save to cache
-             self.cache_vector(file_path, vector, description)
-             return vector
-         except Exception as e:
-             PrettyOutput.print(f"Failed to vectorize {file_path}: {str(e)}",
-                                output_type=OutputType.ERROR)
-             return np.zeros(self.vector_dim, dtype=np.float32)  # type: ignore
-
-     def clean_cache(self) -> bool:
-         """Clean expired cache records"""
-         try:
-             files_to_delete = []
-             for file_path in list(self.vector_cache.keys()):
-                 if not os.path.exists(file_path):
-                     files_to_delete.append(file_path)
-                     cache_path = self._get_cache_path(file_path)
-                     try:
-                         os.remove(cache_path)
-                     except Exception:
-                         pass
-
-             for file_path in files_to_delete:
-                 del self.vector_cache[file_path]
-                 if file_path in self.file_paths:
-                     self.file_paths.remove(file_path)
-
-             return bool(files_to_delete)
-
-         except Exception as e:
-             PrettyOutput.print(f"Failed to clean the cache: {str(e)}",
-                                output_type=OutputType.ERROR)
-             return False
-
-     def process_file(self, file_path: str):
-         """Process a single file"""
-         try:
-             # Skip non-existent files
-             if not os.path.exists(file_path):
-                 return None
-
-             if not self.is_text_file(file_path):
-                 return None
-
-             md5 = get_file_md5(file_path)
-
-             content = open(file_path, "r", encoding="utf-8", errors="ignore").read()
-
-             # Check if the file has already been processed and the content has not changed
-             if file_path in self.vector_cache:
-                 if self.vector_cache[file_path].get("md5") == md5:
-                     return None
-
-             description = self.make_description(file_path, content)  # Pass the truncated content
-             vector = self.vectorize_file(file_path, description)
-
-             # Save to cache, using the actual file path as the key
-             self.vector_cache[file_path] = {
-                 "vector": vector,
-                 "description": description,
-                 "md5": md5
-             }
-
-             return file_path
-
-         except Exception as e:
-             PrettyOutput.print(f"Failed to process {file_path}: {str(e)}",
-                                output_type=OutputType.ERROR)
-             return None
-
-     def build_index(self):
-         """Build a faiss index from the vector cache"""
-         try:
-             if not self.vector_cache:
-                 self.index = None
-                 return
-
-             # Create the underlying HNSW index
-             hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
-             hnsw_index.hnsw.efConstruction = 40
-             hnsw_index.hnsw.efSearch = 16
-
-             # Wrap the HNSW index with IndexIDMap
-             self.index = faiss.IndexIDMap(hnsw_index)
-
-             vectors = []
-             ids = []
-             self.file_paths = []  # Reset the file path list
-
-             for i, (file_path, data) in enumerate(self.vector_cache.items()):
-                 if "vector" not in data:
-                     PrettyOutput.print(f"Invalid cache data for {file_path}: missing vector",
-                                        output_type=OutputType.WARNING)
-                     continue
-
-                 vector = data["vector"]
-                 if not isinstance(vector, np.ndarray):
-                     PrettyOutput.print(f"Invalid vector type for {file_path}: {type(vector)}",
-                                        output_type=OutputType.WARNING)
-                     continue
-
-                 vectors.append(vector.reshape(1, -1))
-                 ids.append(i)
-                 self.file_paths.append(file_path)
-
-             if vectors:
-                 vectors = np.vstack(vectors)
-                 if len(vectors) != len(ids):
-                     PrettyOutput.print(f"Vector count mismatch: {len(vectors)} vectors vs {len(ids)} IDs",
-                                        output_type=OutputType.WARNING)
-                     self.index = None
-                     return
-
-                 try:
-                     self.index.add_with_ids(vectors, np.array(ids))  # type: ignore
-                     PrettyOutput.print(f"Successfully built an index containing {len(vectors)} vectors",
-                                        output_type=OutputType.SUCCESS)
-                 except Exception as e:
-                     PrettyOutput.print(f"Failed to add vectors to the index: {str(e)}",
-                                        output_type=OutputType.ERROR)
-                     self.index = None
-             else:
-                 PrettyOutput.print("No valid vectors found, index not built",
-                                    output_type=OutputType.WARNING)
-                 self.index = None
-
-         except Exception as e:
-             PrettyOutput.print(f"Failed to build the index: {str(e)}",
-                                output_type=OutputType.ERROR)
-             self.index = None
-
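`build_index` (and `_load_all_cache` before it) pairs `faiss.IndexHNSWFlat` with an `IndexIDMap` wrapper so each vector keeps a stable integer id. A minimal self-contained sketch of that pattern; the dimension, data, and parameters here are illustrative, not the package's values:

```python
import faiss
import numpy as np

dim = 384                                   # assumed embedding width
vectors = np.random.rand(100, dim).astype(np.float32)
ids = np.arange(100, dtype=np.int64)        # add_with_ids expects int64 ids

hnsw = faiss.IndexHNSWFlat(dim, 16)         # 16 graph neighbors per node
hnsw.hnsw.efConstruction = 40               # build-time search breadth
hnsw.hnsw.efSearch = 16                     # query-time search breadth
index = faiss.IndexIDMap(hnsw)              # map positions to explicit ids
index.add_with_ids(vectors, ids)

distances, found_ids = index.search(vectors[:1], 5)
print(found_ids[0])                         # ids of the 5 nearest vectors
```

Higher `efSearch` trades query speed for recall, which is why the build-time value (40) exceeds the query-time value (16) here.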
-     def gen_vector_db_from_cache(self):
-         """Generate a vector database from the cache"""
-         self.build_index()
-         self._load_all_cache()
-
-
-     def generate_codebase(self, force: bool = False):
-         """Generate the codebase index
-         Args:
-             force: Whether to force rebuild the index, without asking the user
-         """
-         try:
-             # Clean up cache for non-existent files
-             files_to_delete = []
-             for cached_file in list(self.vector_cache.keys()):
-                 if not os.path.exists(cached_file) or not self.is_text_file(cached_file):
-                     files_to_delete.append(cached_file)
-                     cache_path = self._get_cache_path(cached_file)
-                     try:
-                         os.remove(cache_path)
-                     except Exception as e:
-                         PrettyOutput.print(f"Failed to delete the cache file {cached_file}: {str(e)}",
-                                            output_type=OutputType.WARNING)
-
-             if files_to_delete:
-                 for file_path in files_to_delete:
-                     del self.vector_cache[file_path]
-                 PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} non-existent files",
-                                    output_type=OutputType.INFO)
-
-             # Update the git file list
-             self.git_file_list = self.get_git_file_list()
-
-             # Check file changes
-             PrettyOutput.print("Checking for file changes...", output_type=OutputType.INFO)
-             changes_detected = False
-             new_files = []
-             modified_files = []
-             deleted_files = []
-
-             # Check deleted files
-             files_to_delete = []
-             for file_path in list(self.vector_cache.keys()):
-                 if file_path not in self.git_file_list:
-                     deleted_files.append(file_path)
-                     files_to_delete.append(file_path)
-                     changes_detected = True
-             # Check new and modified files
-             from rich.progress import Progress
-             with Progress() as progress:
-                 task = progress.add_task("Check file status", total=len(self.git_file_list))
-                 for file_path in self.git_file_list:
-                     if not os.path.exists(file_path) or not self.is_text_file(file_path):
-                         progress.advance(task)
-                         continue
-
-                     try:
-                         current_md5 = get_file_md5(file_path)
-
-                         if file_path not in self.vector_cache:
-                             new_files.append(file_path)
-                             changes_detected = True
-                         elif self.vector_cache[file_path].get("md5") != current_md5:
-                             modified_files.append(file_path)
-                             changes_detected = True
-                     except Exception as e:
-                         PrettyOutput.print(f"Failed to check {file_path}: {str(e)}",
-                                            output_type=OutputType.ERROR)
-                     progress.advance(task)
-
-             # If changes are detected, display changes and ask the user
-             if changes_detected:
-                 output_lines = ["The following changes were detected:"]
-                 if new_files:
-                     output_lines.append("New files:")
-                     output_lines.extend(f"  {f}" for f in new_files)
-                 if modified_files:
-                     output_lines.append("Modified files:")
-                     output_lines.extend(f"  {f}" for f in modified_files)
-                 if deleted_files:
-                     output_lines.append("Deleted files:")
-                     output_lines.extend(f"  {f}" for f in deleted_files)
-
-                 PrettyOutput.print("\n".join(output_lines), output_type=OutputType.INFO)
-
-                 # If force is True, continue directly
-                 if not force:
-                     if not user_confirm("Rebuild the index?", False):
-                         return
-
-                 # Clean deleted files
-                 for file_path in files_to_delete:
-                     del self.vector_cache[file_path]
-                 if files_to_delete:
-                     PrettyOutput.print(f"Cleaned the cache of {len(files_to_delete)} files",
-                                        output_type=OutputType.INFO)
-
-                 # Process new and modified files
-                 files_to_process = new_files + modified_files
-                 processed_files = []
-
-                 with yaspin(text="Processing files...", color="cyan") as spinner:
-                     # Use a thread pool to process files
-                     with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
-                         # Submit all tasks
-                         future_to_file = {
-                             executor.submit(self.process_file, file): file
-                             for file in files_to_process
-                         }
-
-                         # Process completed tasks
-                         for future in concurrent.futures.as_completed(future_to_file):
-                             file = future_to_file[future]
-                             try:
-                                 result = future.result()
-                                 if result:
-                                     processed_files.append(result)
-                                     spinner.write(f"✅ Processed file {file}")
-                             except Exception as e:
-                                 spinner.write(f"❌ Failed to process file {file}: {str(e)}")
-
-                     spinner.text = "Processing complete"
-                     spinner.ok("✅")
-
-                 if processed_files:
-                     with yaspin(text="Rebuilding the vector database...", color="cyan") as spinner:
-                         self.gen_vector_db_from_cache()
-                         spinner.text = f"Successfully generated the index for {len(processed_files)} files"
-                         spinner.ok("✅")
-             else:
-                 PrettyOutput.print("No file changes detected; the index does not need to be rebuilt", output_type=OutputType.INFO)
-
-         except Exception as e:
-             # Try to save the cache when an exception occurs
-             try:
-                 self._load_all_cache()
-             except Exception as save_error:
-                 PrettyOutput.print(f"Failed to save the cache: {str(save_error)}",
-                                    output_type=OutputType.ERROR)
-             raise e  # Re-raise the original exception
-
-
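The processing section above uses the standard `concurrent.futures` fan-out/fan-in pattern: submit one task per file, then consume results as they finish. A distilled sketch, with `process_one` standing in for `CodeBase.process_file`:

```python
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

def process_one(path: str) -> str:
    return path.upper()  # placeholder for the real per-file work

files = ["a.py", "b.py", "c.py"]
done = []
with ThreadPoolExecutor(max_workers=4) as executor:
    # Map each future back to its input so failures can be attributed.
    future_to_file = {executor.submit(process_one, f): f for f in files}
    for future in concurrent.futures.as_completed(future_to_file):
        src = future_to_file[future]
        try:
            done.append(future.result())   # re-raises worker exceptions here
        except Exception as exc:
            print(f"{src} failed: {exc}")
print(done)
```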
-     def _text_search_score(self, content: str, keywords: List[str]) -> float:
-         """Calculate the matching score between the text content and the keywords
-
-         Args:
-             content: Text content
-             keywords: List of keywords
-
-         Returns:
-             float: Matching score (0-1)
-         """
-         if not keywords:
-             return 0.0
-
-         content = content.lower()
-         matched_keywords = set()
-
-         for keyword in keywords:
-             keyword = keyword.lower()
-             if keyword in content:
-                 matched_keywords.add(keyword)
-
-         # Calculate the matching score
-         score = len(matched_keywords) / len(keywords)
-         return score
-
-     def pick_results(self, query: List[str], initial_results: List[str]) -> List[Dict[str, str]]:
-         """Use a large model to pick the search results
-
-         Args:
-             query: Search query
-             initial_results: Initial results list of file paths
-
-         Returns:
-             List[str]: The picked results list, each item is a file path
-         """
-         if not initial_results:
-             return []
-         with yaspin(text="Filtering results...", color="cyan") as spinner:
-             try:
-                 # Maximum content length per batch
-                 max_batch_length = self.max_token_count - 1000  # Reserve space for prompt
-                 max_file_length = max_batch_length // 3  # Limit individual file size
-
-                 # Process files in batches
-                 all_selected_files = []
-                 current_batch = []
-                 current_token_count = 0
-
-                 for path in initial_results:
-                     try:
-                         content = open(path, "r", encoding="utf-8", errors="ignore").read()
-                         # Truncate large files
-                         if get_context_token_count(content) > max_file_length:
-                             spinner.write(f"❌ Truncating large file: {path}")
-                             content = content[:max_file_length] + "\n... (content truncated)"
-
-                         file_info = f"File: {path}\nContent: {content}\n\n"
-                         tokens_count = get_context_token_count(file_info)
-
-                         # If adding this file would exceed batch limit
-                         if current_token_count + tokens_count > max_batch_length:
-                             # Process current batch
-                             if current_batch:
-                                 selected = self._process_batch('\n'.join(query), current_batch)
-                                 all_selected_files.extend(selected)
-                             # Start new batch
-                             current_batch = [file_info]
-                             current_token_count = tokens_count
-                         else:
-                             current_batch.append(file_info)
-                             current_token_count += tokens_count
-
-                     except Exception as e:
-                         spinner.write(f"❌ Failed to read {path}: {str(e)}")
-                         continue
-
-                 # Process final batch
-                 if current_batch:
-                     selected = self._process_batch('\n'.join(query), current_batch)
-                     all_selected_files.extend(selected)
-
-                 spinner.write("✅ Result filtering complete")
-                 # Convert set to list and maintain original order
-                 return all_selected_files
-
-             except Exception as e:
-                 spinner.text = f"Selection failed: {str(e)}"
-                 spinner.fail("❌")
-                 return [{"file": f, "reason": ""} for f in initial_results]
-
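The loop above greedily packs file snippets into batches under a token budget, flushing a batch whenever the next item would overflow it. A distilled sketch of the same strategy, with `count_tokens` standing in for `get_context_token_count`:

```python
from typing import Callable, List

def pack_batches(items: List[str], budget: int,
                 count_tokens: Callable[[str], int]) -> List[List[str]]:
    """Sketch: greedy token-budget batching, mirroring pick_results."""
    batches: List[List[str]] = []
    current: List[str] = []
    used = 0
    for item in items:
        cost = count_tokens(item)
        if current and used + cost > budget:
            batches.append(current)      # flush the full batch
            current, used = [], 0
        current.append(item)
        used += cost
    if current:
        batches.append(current)          # flush the final partial batch
    return batches

print(pack_batches(["aa", "bbbb", "cc", "d"], budget=5, count_tokens=len))
# -> [['aa'], ['bbbb'], ['cc', 'd']]
```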
-     def _process_batch(self, query: str, files_info: List[str]) -> List[Dict[str, str]]:
-         """Process a batch of files"""
-         prompt = f"""As a code analysis expert, use chain-of-thought reasoning to help identify the files most relevant to the given query.
-
- Query: {query}
-
- Available files:
- {''.join(files_info)}
-
- Think through the following steps:
- 1. First, analyze the query to identify key requirements and technical concepts
- 2. For each file:
-    - Examine its path and content
-    - Evaluate its relationship to the query requirements
-    - Consider both direct and indirect relationships
-    - Assess its relevance (high/medium/low)
- 3. Select only files that are clearly relevant to the query
- 4. Sort by relevance, with the most relevant files first
-
- Output your selection in YAML format:
- <FILES>
- - file: path/to/most/relevant.py
-   reason: xxxxxxxxxx
- - file: path/to/next/relevant.py
-   reason: yyyyyyyyyy
- </FILES>
-
- Important notes:
- - Include only files that are genuinely relevant
- - Exclude files whose connection is unclear or weak
- - Focus on implementation files rather than test files
- - Consider both file paths and content
- - Output only file paths, without any other text
- """
-
-         # Use a large model to evaluate
-         model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
-         response = model.chat_until_success(prompt)
-
-         # Parse the response
-         import yaml
-         files_match = re.search(r'<FILES>\n(.*?)</FILES>', response, re.DOTALL)
-         if not files_match:
-             return []
-
-         try:
-             selected_files = yaml.safe_load(files_match.group(1))
-             return selected_files if selected_files else []
-         except Exception as e:
-             PrettyOutput.print(f"Failed to parse the response: {str(e)}", OutputType.ERROR)
-             return []
-
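The parsing step above extracts a YAML body from a `<FILES>...</FILES>` block in the model response. A self-contained sketch of that extraction; the response text here is fabricated for illustration:

```python
import re
import yaml  # PyYAML, as used by the module above

response = """Here is my selection:
<FILES>
- file: src/auth.py
  reason: implements credential checks
- file: src/session.py
  reason: manages login sessions
</FILES>"""

# Non-greedy DOTALL match pulls out only the tag body.
match = re.search(r"<FILES>\n(.*?)</FILES>", response, re.DOTALL)
selected = yaml.safe_load(match.group(1)) if match else []
for entry in selected or []:
    print(entry["file"], "->", entry["reason"])
```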
-     def _generate_query_variants(self, query: str) -> List[str]:
-         """Generate different expressions of the query optimized for vector search
-
-         Args:
-             query: Original query
-
-         Returns:
-             List[str]: The query variants list
-         """
-         model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
-         prompt = f"""Based on the following query, generate 10 different expressions optimized for vector search. Each expression should:
- 1. Focus on key technical concepts and terminology
- 2. Use clear, unambiguous language
- 3. Include important contextual terms
- 4. Avoid generic or vague wording
- 5. Preserve semantic similarity to the original query
- 6. Be suitable for embedding-based search
-
- Original query:
- {query}
-
- Example transformation:
- Query: "How is user login handled?"
- Output format:
- <QUESTION>
- - user authentication implementation and flow
- - login system architecture and components
- - credential validation and session management
- - ...
- </QUESTION>
-
- Provide 10 search-optimized expressions in the specified format.
- """
-         response = model.chat_until_success(prompt)
-
-         # Parse the response using YAML format
-         import yaml
-         variants = []
-         question_match = re.search(r'<QUESTION>\n(.*?)</QUESTION>', response, re.DOTALL)
-         if question_match:
-             try:
-                 variants = yaml.safe_load(question_match.group(1))
-                 if not isinstance(variants, list):
-                     variants = [str(variants)]
-             except Exception as e:
-                 PrettyOutput.print(f"Failed to parse the variants: {str(e)}", OutputType.ERROR)
-
-         # Add original query
-         variants.append(query)
-         return variants if variants else [query]
-
-     def _vector_search(self, query_variants: List[str], top_k: int) -> Dict[str, Tuple[str, float, str]]:
-         """Use vector search to find related files
-
-         Args:
-             query_variants: The query variants list
-             top_k: The number of results to return
-
-         Returns:
-             Dict[str, Tuple[str, float, str]]: The mapping from file path to (file path, score, description)
-         """
-         results = {}
-         for query in query_variants:
-             query_vector = get_embedding(self.embedding_model, query)
-             query_vector = query_vector.reshape(1, -1)
-
-             distances, indices = self.index.search(query_vector, top_k)  # type: ignore
-
-             for i, distance in zip(indices[0], distances[0]):
-                 if i == -1:
-                     continue
-
-                 similarity = 1.0 / (1.0 + float(distance))
-                 file_path = self.file_paths[i]
-                 # Use the highest similarity score
-                 if file_path not in results:
-                     if similarity > 0.5:
-                         data = self.vector_cache[file_path]
-                         results[file_path] = (file_path, similarity, data["description"])
-
-         return results
-
-
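Both `_vector_search` above and `search_similar` below convert a FAISS distance `d` into a score with `similarity = 1 / (1 + d)`, which maps `d = 0` to 1.0 and decays toward 0. A tiny sketch showing where the two cutoffs land (0.5 keeps only `d < 1.0`; 0.3 keeps `d < 2.33` for better recall):

```python
def to_similarity(distance: float) -> float:
    """Map a non-negative distance to a (0, 1] score."""
    return 1.0 / (1.0 + distance)

for d in (0.0, 0.5, 1.0, 3.0):
    print(f"distance={d:.1f} -> similarity={to_similarity(d):.3f}")
# distance=0.0 -> similarity=1.000
# distance=0.5 -> similarity=0.667
# distance=1.0 -> similarity=0.500
# distance=3.0 -> similarity=0.250
```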
-     def search_similar(self, query: str, top_k: int = 30) -> List[Dict[str, str]]:
-         """Search related files with optimized retrieval"""
-         with yaspin(text="Searching for related files...", color="cyan") as spinner:
-             try:
-                 with spinner.hidden():
-                     self.generate_codebase()
-                 if self.index is None:
-                     spinner.text = "No valid cache files found"
-                     spinner.ok("✅")
-                     return []
-
-                 # Generate query variants for better coverage
-                 spinner.text = "Generating query variants..."
-                 query_variants = self._generate_query_variants(query)
-                 spinner.write("✅ Query variants generated")
-
-                 # Collect results from all variants
-                 spinner.text = "Collecting results..."
-                 all_results = []
-                 seen_files = set()
-
-                 for variant in query_variants:
-                     # Get vector for each variant
-                     query_vector = get_embedding(self.embedding_model, variant)
-                     query_vector = query_vector.reshape(1, -1)
-
-                     # Search with current variant
-                     initial_k = min(top_k * 2, len(self.file_paths))
-                     distances, indices = self.index.search(query_vector, initial_k)  # type: ignore
-
-                     # Process results
-                     for idx, dist in zip(indices[0], distances[0]):
-                         if idx != -1:
-                             file_path = self.file_paths[idx]
-                             if file_path not in seen_files:
-                                 similarity = 1.0 / (1.0 + float(dist))
-                                 if similarity > 0.3:  # Lower threshold for better recall
-                                     seen_files.add(file_path)
-                                     all_results.append((file_path, similarity, self.vector_cache[file_path]["description"]))
-                 spinner.write("✅ Result collection complete")
-                 if not all_results:
-                     spinner.text = "No related files found"
-                     spinner.ok("✅")
-                     return []
-
-                 spinner.text = "Sorting..."
-                 # Sort by similarity and take top_k
-                 all_results.sort(key=lambda x: x[1], reverse=True)
-                 results = all_results[:top_k]
-                 spinner.write("✅ Sorting complete")
-
-                 with spinner.hidden():
-                     results = self.pick_results(query_variants, [path for path, _, _ in results])
-
-                 output = "Found related files:\n"
-                 for file in results:
-                     output += f'''- {file['file']} ({file['reason']})\n'''
-
-                 spinner.text = "Result output complete"
-                 spinner.ok("✅")
-                 return results
-
-             except Exception as e:
-                 spinner.text = f"Search failed: {str(e)}"
-                 spinner.fail("❌")
-                 return []
-
-     def ask_codebase(self, query: str, top_k: int=20) -> Tuple[List[Dict[str, str]], str]:
-         """Query the codebase with enhanced context building"""
-         files_from_codebase = self.search_similar(query, top_k)
-
-         if not files_from_codebase:
-             PrettyOutput.print("No related files found", output_type=OutputType.WARNING)
-             return [], ""
-
-         prompt = f"""
- # 🤖 Role Definition
- You are a code analysis expert who provides comprehensive and accurate answers about the codebase.
-
- # 🎯 Core Responsibilities
- - Analyze code files in depth
- - Explain technical concepts clearly
- - Provide relevant code examples
- - Identify missing information
- - Answer in the user's language
-
- # 📋 Answer Requirements
- ## Content Quality
- - Focus on implementation details
- - Maintain technical accuracy
- - Include relevant code snippets
- - Point out any missing information
- - Use professional terminology
-
- ## Answer Format
- - question: [restate the question]
-   answer: |
-     [detailed technical answer, including:
-      - implementation details
-      - code examples (if relevant)
-      - missing information (if any)
-      - related technical concepts]
-
- - question: [a follow-up question, if needed]
-   answer: |
-     [additional technical details]
-
- # 🔍 Analysis Context
- Question: {query}
-
- Related code files (sorted by relevance):
- """
-
-         with yaspin(text="Generating answer...", color="cyan") as spinner:
-             # Add context, controlling the length
-             spinner.text = "Adding context..."
-             available_count = self.max_token_count - get_context_token_count(prompt) - 1000  # Reserve space for the answer
-             current_count = 0
-
-             for path in files_from_codebase:
-                 try:
-                     content = open(path["file"], "r", encoding="utf-8", errors="ignore").read()
-                     file_content = f"""
- ## File: {path["file"]}
- ```
- {content}
- ```
- ---
- """
-                     if current_count + get_context_token_count(file_content) > available_count:
-                         spinner.write("⚠️ Some files were omitted due to the context length limit")
-                         break
-
-                     prompt += file_content
-                     current_count += get_context_token_count(file_content)
-
-                 except Exception as e:
-                     spinner.write(f"❌ Failed to read {path}: {str(e)}")
-                     continue
-
-             prompt += """
- # ❗ Important Rules
- 1. Always base the answer on the provided code
- 2. Maintain technical accuracy
- 3. Include code examples when relevant
- 4. Point out any missing information
- 5. Keep the language professional
- 6. Answer in the user's language
- """
-
-             model = PlatformRegistry.get_global_platform_registry().get_thinking_platform()
-             spinner.text = "Generating answer..."
-             ret = files_from_codebase, model.chat_until_success(prompt)
-             spinner.text = "Answer generated"
-             spinner.ok("✅")
-             return ret
-
-     def is_index_generated(self) -> bool:
-         """Check if the index has been generated"""
-         try:
-             # 1. Check the basic conditions
-             if not self.vector_cache or not self.file_paths:
-                 return False
-
-             if not hasattr(self, 'index') or self.index is None:
-                 return False
-
-             # 2. Check whether the index is usable
-             # Create a test vector
-             test_vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
-             try:
-                 self.index.search(test_vector, 1)  # type: ignore
-             except Exception:
-                 return False
-
-             # 3. Verify consistency between the vector cache and the file paths
-             if len(self.vector_cache) != len(self.file_paths):
-                 return False
-
-             # 4. Verify all cache files
-             for file_path in self.file_paths:
-                 if file_path not in self.vector_cache:
-                     return False
-
-                 cache_path = self._get_cache_path(file_path)
-                 if not os.path.exists(cache_path):
-                     return False
-
-                 cache_data = self.vector_cache[file_path]
-                 if not isinstance(cache_data.get("vector"), np.ndarray):
-                     return False
-
-             return True
-
-         except Exception as e:
-             PrettyOutput.print(f"Failed to check the index status: {str(e)}",
-                                output_type=OutputType.ERROR)
-             return False
-
-
-
-
-
- def main():
-
-     parser = argparse.ArgumentParser(description='Codebase management and search tool')
-     subparsers = parser.add_subparsers(dest='command', help='Available commands')
-
-     # Generate command
-     generate_parser = subparsers.add_parser('generate', help='Generate codebase index')
-     generate_parser.add_argument('--force', action='store_true', help='Force rebuild index')
-
-     # Search command
-     search_parser = subparsers.add_parser('search', help='Search similar code files')
-     search_parser.add_argument('query', type=str, help='Search query')
-     search_parser.add_argument('--top-k', type=int, default=20, help='Number of results to return (default: 20)')
-
-     # Ask command
-     ask_parser = subparsers.add_parser('ask', help='Ask a question about the codebase')
-     ask_parser.add_argument('question', type=str, help='Question to ask')
-     ask_parser.add_argument('--top-k', type=int, default=20, help='Number of results to use (default: 20)')
-
-     export_parser = subparsers.add_parser('export', help='Export current index data')
-     args = parser.parse_args()
-
-     current_dir = find_git_root()
-     codebase = CodeBase(current_dir)
-
-     if args.command == 'export':
-         codebase.export()
-         return
-
-     # If the index has not been generated and this is not the generate command, prompt the user to generate it first
-     if not codebase.is_index_generated() and args.command != 'generate':
-         PrettyOutput.print("The index has not been generated yet; run the 'generate' command first", output_type=OutputType.WARNING)
-         return
-
-     if args.command == 'generate':
-         try:
-             codebase.generate_codebase(force=args.force)
-             PrettyOutput.print("Codebase generation complete", output_type=OutputType.SUCCESS)
-         except Exception as e:
-             PrettyOutput.print(f"Codebase generation failed: {str(e)}", output_type=OutputType.ERROR)
-
-     elif args.command == 'search':
-         results = codebase.search_similar(args.query, args.top_k)
-         if not results:
-             PrettyOutput.print("No similar files found", output_type=OutputType.WARNING)
-             return
-
-         output = "Search results:\n"
-         for path in results:
-             output += f"""- {path}\n"""
-         PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")
-
-     elif args.command == 'ask':
-         files, answer = codebase.ask_codebase(args.question, args.top_k)
-         output = "# Related files:\n"
-         for file in files:
-             output += f"""- {file['file']} ({file['reason']})\n"""
-         output += f"# Answer:\n{answer}"
-         PrettyOutput.print(output, output_type=OutputType.SYSTEM, lang="markdown")
-
-     else:
-         parser.print_help()
-
-
- if __name__ == "__main__":
-     exit(main())