jarvis-ai-assistant 0.1.58__py3-none-any.whl → 0.1.74__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

@@ -0,0 +1,636 @@
1
+ import hashlib
2
+ import os
3
+ import numpy as np
4
+ import faiss
5
+ from typing import List, Tuple, Optional
6
+ from jarvis.models.registry import PlatformRegistry
7
+ import concurrent.futures
8
+ from threading import Lock
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from jarvis.utils import OutputType, PrettyOutput, find_git_root
11
+ from jarvis.utils import load_env_from_file
12
+ import argparse
13
+ from sentence_transformers import SentenceTransformer
14
+ import pickle
15
+
16
+ class CodeBase:
17
def __init__(self, root_dir: str):
    """Load configuration, the embedding model and any cached vectors.

    Args:
        root_dir: Repository root. The process working directory is changed
            to this path, because the file names returned by
            ``git ls-files`` are opened as relative paths later on.

    Raises:
        Exception: Any error raised while loading or warming up the
            embedding model is logged and re-raised.
    """
    load_env_from_file()
    self.root_dir = root_dir
    os.chdir(self.root_dir)

    # Every setting below ends in a non-empty literal fallback, so none of
    # them can be empty and no "setting is missing" validation is required.
    env = os.environ.get
    self.thread_count = int(env("JARVIS_THREAD_COUNT") or 10)
    self.normal_platform = env("JARVIS_PLATFORM") or "kimi"
    self.normal_model = env("JARVIS_MODEL") or "kimi"
    self.cheap_platform = env("JARVIS_CHEAP_PLATFORM") or self.normal_platform
    self.cheap_model = env("JARVIS_CHEAP_MODEL") or self.normal_model
    self.codegen_platform = env("JARVIS_CODEGEN_PLATFORM") or self.normal_platform
    self.codegen_model = env("JARVIS_CODEGEN_MODEL") or self.normal_model
    self.embedding_model_name = env("JARVIS_EMBEDDING_MODEL") or "BAAI/bge-large-zh-v1.5"

    PrettyOutput.print(f"廉价模型使用平台: {self.cheap_platform} 模型: {self.cheap_model}", output_type=OutputType.INFO)
    PrettyOutput.print(f"代码生成模型使用平台: {self.codegen_platform} 模型: {self.codegen_model}", output_type=OutputType.INFO)
    PrettyOutput.print(f"分析模型使用平台: {self.normal_platform} 模型: {self.normal_model}", output_type=OutputType.INFO)
    PrettyOutput.print(f"嵌入模型: {self.embedding_model_name}", output_type=OutputType.INFO)
    PrettyOutput.print(f"索引建立线程数: {self.thread_count}", output_type=OutputType.INFO)
    PrettyOutput.print("检索算法:分层导航小世界算法", output_type=OutputType.INFO)

    # Data directory that holds the pickled cache.
    self.data_dir = os.path.join(self.root_dir, ".jarvis-codebase")
    os.makedirs(self.data_dir, exist_ok=True)

    # Load the embedding model (default cache location of the library).
    try:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        PrettyOutput.print("正在加载/下载模型,请稍候...", output_type=OutputType.INFO)
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        # Encode one sample so that every model component is loaded before
        # the worker threads start using the model.
        test_text = """
这是一段测试文本,用于确保模型完全加载。
包含多行内容,以模拟实际使用场景。
"""
        self.embedding_model.encode([test_text],
                                    convert_to_tensor=True,
                                    normalize_embeddings=True)
        PrettyOutput.print("模型加载完成", output_type=OutputType.SUCCESS)
    except Exception as e:
        PrettyOutput.print(f"加载模型失败: {str(e)}", output_type=OutputType.ERROR)
        raise

    self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()

    self.git_file_list = self.get_git_file_list()
    self.platform_registry = PlatformRegistry().get_global_platform_registry()

    # Cache and index state. The index must exist as an attribute even when
    # no cache file is present, otherwise a later search fails with
    # AttributeError instead of reporting a missing index.
    self.cache_path = os.path.join(self.data_dir, "cache.pkl")
    self.vector_cache = {}
    self.file_paths = []
    self.index = None

    if os.path.exists(self.cache_path):
        try:
            # NOTE(review): pickle.load can execute arbitrary code; the cache
            # lives inside the repository directory, so it should only be
            # loaded from checkouts that are trusted.
            with open(self.cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            self.vector_cache = cache_data["vectors"]
            self.file_paths = cache_data["file_paths"]
            PrettyOutput.print(f"加载了 {len(self.vector_cache)} 个向量缓存",
                               output_type=OutputType.INFO)
            # Rebuild the search index from the cached vectors.
            self.build_index()
        except Exception as e:
            PrettyOutput.print(f"加载缓存失败: {str(e)}",
                               output_type=OutputType.WARNING)
            self.vector_cache = {}
            self.file_paths = []
            self.index = None
91
+
92
def get_git_file_list(self):
    """Return the files tracked by git, excluding ``.jarvis-codebase/``.

    ``git ls-files -z`` is used so that names are NUL-separated and are not
    C-quoted; with the default output, paths containing non-ASCII characters
    are escaped and no longer match the files on disk.

    Returns:
        A list of paths relative to the current working directory. The list
        is empty when git cannot be started or prints nothing (for example
        outside a repository).
    """
    import subprocess  # local import keeps this block self-contained

    try:
        completed = subprocess.run(
            ["git", "ls-files", "-z"],
            stdout=subprocess.PIPE,
            check=False,
        )
    except OSError:
        # git is not installed or cannot be executed.
        return []

    names = [os.fsdecode(raw) for raw in completed.stdout.split(b"\0") if raw]
    # Files below the cache directory must never be indexed.
    return [name for name in names if not name.startswith(".jarvis-codebase/")]
97
+
98
def is_text_file(self, file_path: str):
    """Tell whether the whole file can be decoded as UTF-8.

    Args:
        file_path: Path of the file to probe.

    Returns:
        True when reading the file as UTF-8 text succeeds, False when a
        UnicodeDecodeError occurs. Errors from opening the file propagate.
    """
    handle = open(file_path, "r", encoding="utf-8")
    try:
        handle.read()
    except UnicodeDecodeError:
        return False
    else:
        return True
    finally:
        handle.close()
105
+
106
def make_description(self, file_path: str) -> str:
    """Ask the cheap model for a retrieval-oriented description of a file.

    Args:
        file_path: Path of a UTF-8 text file.

    Returns:
        Whatever the model's ``chat`` call returns for the prompt.

    Raises:
        OSError, UnicodeDecodeError: If the file cannot be read as UTF-8.
    """
    # Read the file first (and close it) so no model is created when the
    # file turns out to be unreadable.
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()

    model = self.platform_registry.create_platform(self.cheap_platform)
    model.set_model_name(self.cheap_model)
    model.set_suppress_output(True)

    prompt = f"""请分析以下代码文件,并生成一个详细的描述。描述应该包含以下要点:

1. 主要功能和用途
2. 关键类和方法的作用
3. 重要的依赖和技术特征(如使用了什么框架、算法、设计模式等)
4. 代码处理的主要数据类型和数据结构
5. 关键业务逻辑和处理流程
6. 特殊功能点和亮点特性

请用简洁专业的语言描述,突出代码的技术特征和功能特点,以便后续进行相似代码检索。

文件路径:{file_path}
代码内容:
{content}
"""
    try:
        return model.chat(prompt)
    finally:
        # Same cleanup as the other model call sites in this class.
        model.delete_chat()
128
+
129
def save_cache(self):
    """Write the vector cache and the file path list to ``self.cache_path``.

    The data is pickled as a dict with the keys ``"vectors"`` and
    ``"file_paths"``. Failures are reported through PrettyOutput and are
    not raised to the caller.
    """
    try:
        payload = dict(vectors=self.vector_cache, file_paths=self.file_paths)
        with open(self.cache_path, 'wb') as sink:
            pickle.dump(payload, sink)
        saved_count = len(self.vector_cache)
        PrettyOutput.print(f"保存了 {saved_count} 个向量缓存",
                           output_type=OutputType.INFO)
    except Exception as e:
        PrettyOutput.print(f"保存缓存失败: {str(e)}",
                           output_type=OutputType.ERROR)
143
+
144
def get_cached_vector(self, file_path: str, description: str) -> Optional[np.ndarray]:
    """Look up a still-valid cached vector for a file.

    Args:
        file_path: Key of the entry in ``self.vector_cache``; also the file
            whose current MD5 is compared with the cached one.
        description: Description the vector was expected to be built from.

    Returns:
        The cached vector when the file's MD5 and the description both match
        the cache entry; otherwise None (also when the file is not cached or
        its MD5 cannot be computed).
    """
    if file_path not in self.vector_cache:
        return None

    # A changed file invalidates the entry, so hash the current content.
    try:
        with open(file_path, "rb") as fh:
            digest = hashlib.md5(fh.read()).hexdigest()
    except Exception as e:
        PrettyOutput.print(f"计算文件MD5失败 {file_path}: {str(e)}",
                           output_type=OutputType.ERROR)
        return None

    entry = self.vector_cache[file_path]
    still_valid = entry["md5"] == digest and entry["description"] == description
    return entry["vector"] if still_valid else None
167
+
168
def cache_vector(self, file_path: str, vector: np.ndarray, description: str):
    """Store a file's vector in the in-memory cache and persist the cache.

    This method is reached from the worker threads started by
    ``generate_codebase``. The cache file is therefore written to a
    temporary file in the same directory and moved into place with
    ``os.replace``, so concurrent callers never truncate or interleave
    writes to ``self.cache_path``. A shallow copy of the cache dict is
    pickled so that entries added by other threads during the dump do not
    break the iteration.

    Args:
        file_path: File the vector belongs to; used as the cache key.
        vector: Embedding vector to store.
        description: Description the vector was built from.
    """
    import tempfile  # local import keeps this block self-contained

    try:
        with open(file_path, "rb") as f:
            file_md5 = hashlib.md5(f.read()).hexdigest()
    except Exception as e:
        PrettyOutput.print(f"计算文件MD5失败 {file_path}: {str(e)}",
                           output_type=OutputType.ERROR)
        # An empty digest never matches a real one, so the entry is treated
        # as stale on the next lookup.
        file_md5 = ""

    self.vector_cache[file_path] = {
        "path": file_path,
        "md5": file_md5,
        "description": description,
        "vector": vector,
    }

    tmp_path = None
    try:
        cache_data = {
            "vectors": dict(self.vector_cache),
            "file_paths": list(self.file_paths),
        }
        target_dir = os.path.dirname(self.cache_path) or "."
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".tmp")
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(cache_data, f)
        os.replace(tmp_path, self.cache_path)
        tmp_path = None  # moved into place, nothing left to clean up
    except Exception as e:
        PrettyOutput.print(f"保存向量缓存失败: {str(e)}",
                           output_type=OutputType.ERROR)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup of a partially written temp file.
                pass
196
+
197
def get_embedding(self, text: str) -> np.ndarray:
    """Embed a text with the sentence-transformer model.

    The text is cut to its first 512 whitespace-separated tokens, which are
    re-joined with single spaces before encoding.

    Args:
        text: Text to embed.

    Returns:
        The normalized embedding as a float32 numpy array.
    """
    max_length = 512  # upper bound on whitespace-separated tokens
    tokens = text.split()
    shortened = ' '.join(tokens[:max_length])

    raw = self.embedding_model.encode(
        shortened,
        normalize_embeddings=True,  # L2 normalization
        show_progress_bar=False,
    )
    return np.array(raw, dtype=np.float32)
209
+
210
def vectorize_file(self, file_path: str, description: str) -> np.ndarray:
    """Return the embedding for a file's path and description.

    A valid cached vector is reused; otherwise the path and description are
    embedded and the result is stored through ``cache_vector``.

    Args:
        file_path: Path of the file being indexed.
        description: Description of the file.

    Returns:
        The embedding vector, or a float32 zero vector of length
        ``self.vector_dim`` when any step raises.
    """
    try:
        result = self.get_cached_vector(file_path, description)
        if result is None:
            # Cache miss: embed the path together with the description.
            text = f"""
文件路径: {file_path}
文件描述: {description}
"""
            result = self.get_embedding(text)
            self.cache_vector(file_path, result, description)
        return result
    except Exception as e:
        PrettyOutput.print(f"Error vectorizing file {file_path}: {str(e)}",
                           output_type=OutputType.ERROR)
        return np.zeros(self.vector_dim, dtype=np.float32)
232
+
233
def clean_cache(self) -> bool:
    """Drop cache entries for files that git no longer tracks.

    Returns:
        True when at least one entry was removed (the cache is then saved),
        False when nothing had to be removed.
    """
    # Build the lookup set once; testing against the list for every cached
    # file would rescan the whole list each time.
    tracked = set(self.git_file_list)
    stale_paths = [path for path in self.vector_cache if path not in tracked]
    if not stale_paths:
        return False

    for path in stale_paths:
        del self.vector_cache[path]

    self.save_cache()
    PrettyOutput.print(f"清理了 {len(stale_paths)} 个文件的缓存",
                       output_type=OutputType.INFO)
    return True
247
+
248
def process_file(self, file_path: str):
    """Describe and vectorize one file and record it in the vector cache.

    Args:
        file_path: Path of the file to index.

    Returns:
        ``file_path`` when the file was (re)indexed; None when the file does
        not exist, is not UTF-8 text, is unchanged since it was cached, or
        when any step raised (the error is reported, not propagated).
    """
    try:
        if not os.path.exists(file_path):
            return None

        if not self.is_text_file(file_path):
            return None

        with open(file_path, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()

        # Skip files whose content matches the cached digest.
        cached = self.vector_cache.get(file_path)
        if cached is not None and cached.get("md5") == md5:
            return None

        description = self.make_description(file_path)
        vector = self.vectorize_file(file_path, description)

        # Same entry layout as cache_vector, keyed by the file path.
        self.vector_cache[file_path] = {
            "path": file_path,
            "vector": vector,
            "description": description,
            "md5": md5,
        }
        return file_path

    except Exception as e:
        PrettyOutput.print(f"处理文件失败 {file_path}: {str(e)}",
                           output_type=OutputType.ERROR,
                           traceback=True)
        return None
282
+
283
def build_index(self):
    """Build the faiss index from the vectors in ``self.vector_cache``.

    The position of a file in ``self.file_paths`` is the id stored in the
    index. With an empty cache the index is set to None and the path list is
    emptied. ``self.file_paths`` and ``self.index`` are only replaced after
    the new index has been filled, so a failure leaves both untouched.
    """
    file_paths = list(self.vector_cache)
    if not file_paths:
        self.file_paths = []
        self.index = None
        return

    vectors = np.vstack([
        np.asarray(self.vector_cache[path]["vector"], dtype=np.float32).reshape(1, -1)
        for path in file_paths
    ])
    # faiss ids are 64-bit; numpy's default integer is narrower on some
    # platforms, so the dtype is stated explicitly.
    ids = np.arange(len(file_paths), dtype=np.int64)

    # HNSW graph index wrapped in IndexIDMap so ids can be supplied.
    hnsw_index = faiss.IndexHNSWFlat(self.vector_dim, 16)
    hnsw_index.hnsw.efConstruction = 40
    hnsw_index.hnsw.efSearch = 16
    index = faiss.IndexIDMap(hnsw_index)
    index.add_with_ids(vectors, ids)

    self.file_paths = file_paths
    self.index = index
307
+
308
def gen_vector_db_from_cache(self):
    """Rebuild the search index from the cache, then persist the cache.

    The index is built first because ``build_index`` also refreshes
    ``self.file_paths``, which ``save_cache`` writes out.
    """
    self.build_index()
    self.save_cache()
312
+
313
+
314
def generate_codebase(self, force: bool = False):
    """Bring the vector cache and index in line with the git file list.

    New, modified and deleted files are detected by comparing MD5 digests
    with the cache. When changes exist they are listed, the user is asked
    for confirmation (unless ``force`` is set), stale entries are dropped,
    changed files are processed in a thread pool and the index is rebuilt.

    Args:
        force: Rebuild without asking the user for confirmation.
    """
    def digest_of(path):
        # Hash the file content; the handle is closed before returning.
        with open(path, "rb") as fh:
            return hashlib.md5(fh.read()).hexdigest()

    # Refresh the list of tracked files.
    self.git_file_list = self.get_git_file_list()
    tracked = set(self.git_file_list)

    # Cached files that git no longer tracks.
    deleted_files = [path for path in self.vector_cache if path not in tracked]

    new_files = []
    modified_files = []
    for file_path in self.git_file_list:
        try:
            # Both probes stay inside the try block: a tracked entry that is
            # a directory (e.g. a submodule) makes open() raise, and that
            # must skip the entry rather than abort the whole run.
            if not os.path.exists(file_path) or not self.is_text_file(file_path):
                continue
            current_md5 = digest_of(file_path)
        except Exception as e:
            PrettyOutput.print(f"检查文件失败 {file_path}: {str(e)}",
                               output_type=OutputType.ERROR)
            continue

        cached = self.vector_cache.get(file_path)
        if cached is None:
            new_files.append(file_path)
        elif cached.get("md5") != current_md5:
            modified_files.append(file_path)

    if not (new_files or modified_files or deleted_files):
        PrettyOutput.print("没有检测到文件变更,无需重建索引", output_type=OutputType.INFO)
        return

    # Show what changed.
    PrettyOutput.print("\n检测到以下变化:", output_type=OutputType.WARNING)
    if new_files:
        PrettyOutput.print("\n新增文件:", output_type=OutputType.INFO)
        for f in new_files:
            PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
    if modified_files:
        PrettyOutput.print("\n修改的文件:", output_type=OutputType.INFO)
        for f in modified_files:
            PrettyOutput.print(f" {f}", output_type=OutputType.INFO)
    if deleted_files:
        PrettyOutput.print("\n删除的文件:", output_type=OutputType.INFO)
        for f in deleted_files:
            PrettyOutput.print(f" {f}", output_type=OutputType.INFO)

    # Ask for confirmation unless the caller forces the rebuild.
    if not force:
        while True:
            response = input("\n是否重建索引?[y/N] ").lower().strip()
            if response in ['y', 'yes']:
                break
            elif response in ['', 'n', 'no']:
                PrettyOutput.print("取消重建索引", output_type=OutputType.INFO)
                return
            else:
                PrettyOutput.print("请输入 y 或 n", output_type=OutputType.WARNING)

    # Drop cache entries of deleted files.
    for file_path in deleted_files:
        del self.vector_cache[file_path]
    if deleted_files:
        PrettyOutput.print(f"清理了 {len(deleted_files)} 个文件的缓存",
                           output_type=OutputType.INFO)

    # Describe and vectorize new and modified files in parallel.
    processed_files = []
    files_to_process = new_files + modified_files
    with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
        futures = [executor.submit(self.process_file, file) for file in files_to_process]
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            if result:
                processed_files.append(result)
                PrettyOutput.print(f"索引文件: {result}", output_type=OutputType.INFO)

    PrettyOutput.print("重新生成向量数据库", output_type=OutputType.INFO)
    self.gen_vector_db_from_cache()
    PrettyOutput.print(f"成功为 {len(processed_files)} 个文件生成索引", output_type=OutputType.INFO)
410
+
411
def rerank_results(self, query: str, initial_results: List[Tuple[str, float, str]]) -> List[Tuple[str, float, str]]:
    """Re-rank search results with the normal model.

    The model is asked for a 0-100 relevance score per file. Scores are
    scaled to 0-1 and only candidates scoring at least 0.7 are kept.

    Args:
        query: The user's search query.
        initial_results: ``(file_path, similarity, description)`` tuples.

    Returns:
        ``(file_path, score, description)`` tuples sorted by descending
        score. Lines of the model response that cannot be parsed, and paths
        that are not among ``initial_results``, are ignored.
    """
    if not initial_results:
        return []

    model = self.platform_registry.create_platform(self.normal_platform)
    model.set_model_name(self.normal_model)
    model.set_suppress_output(True)

    try:
        parts = [f"""请根据用户的查询,对以下代码文件进行相关性排序。对每个文件给出0-100的相关性分数,分数越高表示越相关。
只需要输出每个文件的分数,格式为:
<RERANK_START>
文件路径: 分数
文件路径: 分数
<RERANK_END>

用户查询: {query}

待评估文件:
"""]
        for path, _, desc in initial_results:
            parts.append(f"""
文件: {path}
描述: {desc}
---
""")
        prompt = "".join(parts)

        response = model.chat(prompt)

        # Keep only the text between the markers when both are present.
        start_tag = "<RERANK_START>"
        end_tag = "<RERANK_END>"
        start = response.find(start_tag)
        if start != -1:
            start += len(start_tag)
            end = response.find(end_tag, start)
            if end != -1:
                response = response[start:end]

        # First description wins when a path occurs more than once.
        descriptions = {}
        for path, _, desc in initial_results:
            descriptions.setdefault(path, desc)

        scored_results = []
        for line in response.split('\n'):
            if ':' not in line:
                continue
            # The score is the last field, so split from the right; the
            # path itself may contain ':'.
            file_path, score_str = line.rsplit(':', 1)
            file_path = file_path.strip()
            try:
                score = float(score_str.strip()) / 100.0  # scale to 0-1
            except ValueError:
                continue
            # Keep sufficiently relevant files that were actual candidates;
            # callers open the returned paths, so invented ones are dropped.
            if score >= 0.7 and file_path in descriptions:
                scored_results.append((file_path, score, descriptions[file_path]))

        return sorted(scored_results, key=lambda item: item[1], reverse=True)

    finally:
        model.delete_chat()
472
+
473
def search_similar(self, query: str, top_k: int = 20) -> List[Tuple[str, float, str]]:
    """Find indexed files that are similar to a query.

    The query is first rephrased by the normal model, embedded and looked up
    in the index. Hits are converted to a similarity of
    ``1 / (1 + distance)``; those of at least 0.5 are passed to
    ``rerank_results``.

    Args:
        query: Natural-language search query.
        top_k: Number of nearest neighbours requested from the index.

    Returns:
        Re-ranked ``(file_path, score, description)`` tuples, or an empty
        list when no index is available or nothing is similar enough.
    """
    # Check the index before spending a model call on the query.
    index = getattr(self, "index", None)
    if index is None:
        PrettyOutput.print("索引尚未生成,请先运行 --generate 生成索引", output_type=OutputType.WARNING)
        return []

    model = self.platform_registry.create_platform(self.normal_platform)
    model.set_model_name(self.normal_model)
    model.set_suppress_output(True)

    try:
        prompt = f"""请根据以下查询,生成意思完全相同的另一个表述。这个表述将用于代码搜索,所以要保持专业性和准确性。
原始查询: {query}

请直接输出新表述,不要有编号或其他标记。
"""
        # From here on the rephrased text is used as the query.
        query = model.chat(prompt)
    finally:
        model.delete_chat()

    PrettyOutput.print(f"查询: {query}", output_type=OutputType.INFO)

    q_vector = self.get_embedding(query).reshape(1, -1)
    distances, indices = index.search(q_vector, top_k)

    PrettyOutput.print(f"查询 {query} 的结果: ", output_type=OutputType.INFO)

    initial_results = []
    for i, distance in zip(indices[0], distances[0]):
        if i == -1:  # faiss reports missing neighbours as -1
            continue

        similarity = 1.0 / (1.0 + float(distance))
        # Only results with a similarity of at least 0.5 are kept.
        if similarity < 0.5:
            continue

        file_path = self.file_paths[i]
        PrettyOutput.print(f" {file_path} : 距离 {distance:.3f}, 相似度 {similarity:.3f}",
                           output_type=OutputType.INFO)
        data = self.vector_cache[file_path]
        initial_results.append((file_path, similarity, data["description"]))

    if not initial_results:
        PrettyOutput.print("没有找到相似度大于0.5的文件", output_type=OutputType.WARNING)
        return []

    # Let the model re-rank the candidates.
    PrettyOutput.print("使用大模型重新排序...", output_type=OutputType.INFO)
    return self.rerank_results(query, initial_results)
526
+
527
def ask_codebase(self, query: str, top_k: int=20) -> str:
    """Answer a question using the files most relevant to it.

    Relevant files are found with ``search_similar``; their content is
    appended to the prompt until it exceeds 30 KiB characters, after which
    remaining files are skipped with a warning.

    Args:
        query: The user's question.
        top_k: Number of candidates requested from the index.

    Returns:
        The answer returned by the code-generation model's ``chat`` call.
    """
    results = self.search_similar(query, top_k)
    PrettyOutput.print("找到的关联文件: ", output_type=OutputType.SUCCESS)
    for path, score, _ in results:
        PrettyOutput.print(f"文件: {path} 关联度: {score:.3f}",
                           output_type=OutputType.INFO)

    prompt = """你是一个代码专家,请根据以下文件信息回答用户的问题:
"""
    for path, _, _ in results:
        try:
            if len(prompt) > 30 * 1024:
                PrettyOutput.print(f"避免上下文超限,丢弃低相关度文件:{path}", OutputType.WARNING)
                continue
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
            prompt += f"""
文件路径: {path}
文件内容:
{content}
========================================
"""
        except Exception as e:
            PrettyOutput.print(f"读取文件失败 {path}: {str(e)}",
                               output_type=OutputType.ERROR)
            continue

    prompt += f"""
用户问题: {query}

请用专业的语言回答用户的问题,如果给出的文件内容不足以回答用户的问题,请告诉用户,绝对不要胡编乱造。
"""
    model = self.platform_registry.create_platform(self.codegen_platform)
    model.set_model_name(self.codegen_model)
    try:
        return model.chat(prompt)
    finally:
        model.delete_chat()
566
+
567
def is_index_generated(self) -> bool:
    """Report whether a usable index exists on disk and in memory.

    Returns:
        True only when the cache file exists, unpickles to a mapping with
        non-empty ``"vectors"`` and ``"file_paths"``, ``self.index`` is set,
        and both ``self.vector_cache`` and ``self.file_paths`` are
        non-empty. Any error while reading the cache file yields False.
    """
    if not os.path.exists(self.cache_path):
        return False

    # The file on disk must hold a non-empty cache.
    try:
        with open(self.cache_path, 'rb') as fh:
            stored = pickle.load(fh)
        on_disk_ok = bool(stored.get("vectors")) and bool(stored.get("file_paths"))
    except Exception:
        return False
    if not on_disk_ok:
        return False

    # The in-memory index must have been built.
    if getattr(self, 'index', None) is None:
        return False

    # The in-memory cache and path list must be populated as well.
    return bool(self.vector_cache) and bool(self.file_paths)
591
+
592
+
593
+
594
def main():
    """Command-line entry point for building and querying the index.

    Options: ``--generate`` builds the index, ``--search`` lists similar
    files, ``--ask`` answers a question about the codebase, and ``--top-k``
    sets the number of candidates for both queries.
    """
    parser = argparse.ArgumentParser(description='Codebase management and search tool')
    parser.add_argument('--search', type=str, help='Search query to find similar code files')
    parser.add_argument('--top-k', type=int, default=20, help='Number of results to return (default: 20)')
    parser.add_argument('--ask', type=str, help='Ask a question about the codebase')
    parser.add_argument('--generate', action='store_true', help='Generate codebase index')
    args = parser.parse_args()

    current_dir = find_git_root()
    codebase = CodeBase(current_dir)

    # Queries need an index; ask the user to generate one first.
    if not codebase.is_index_generated() and not args.generate:
        PrettyOutput.print("索引尚未生成,请先运行 --generate 生成索引", output_type=OutputType.WARNING)
        return

    if args.generate:
        try:
            codebase.generate_codebase(force=True)
            PrettyOutput.print("\nCodebase generation completed", output_type=OutputType.SUCCESS)
        except Exception as e:
            PrettyOutput.print(f"Error during codebase generation: {str(e)}", output_type=OutputType.ERROR)

    if args.search:
        results = codebase.search_similar(args.search, args.top_k)
        if not results:
            PrettyOutput.print("No similar files found", output_type=OutputType.WARNING)
            return

        PrettyOutput.print("\nSearch Results:", output_type=OutputType.INFO)
        for path, score, desc in results:
            PrettyOutput.print("\n" + "="*50, output_type=OutputType.INFO)
            PrettyOutput.print(f"File: {path}", output_type=OutputType.INFO)
            PrettyOutput.print(f"Similarity: {score:.3f}", output_type=OutputType.INFO)
            # Show a preview: the first 100 characters of the description.
            PrettyOutput.print(f"Description: {desc[:100]}", output_type=OutputType.INFO)

    if args.ask:
        codebase.ask_codebase(args.ask, args.top_k)
633
+
634
+
635
if __name__ == "__main__":
    # SystemExit is always available; the exit() helper is installed by the
    # site module and is meant for interactive sessions.
    raise SystemExit(main())
File without changes