jarvis-ai-assistant 0.1.102__py3-none-any.whl → 0.1.104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of jarvis-ai-assistant might be problematic.

Files changed (55)
  1. jarvis/__init__.py +1 -1
  2. jarvis/agent.py +138 -117
  3. jarvis/jarvis_code_agent/code_agent.py +234 -0
  4. jarvis/{jarvis_coder → jarvis_code_agent}/file_select.py +19 -22
  5. jarvis/jarvis_code_agent/patch.py +120 -0
  6. jarvis/jarvis_code_agent/relevant_files.py +97 -0
  7. jarvis/jarvis_codebase/main.py +871 -0
  8. jarvis/jarvis_platform/main.py +5 -3
  9. jarvis/jarvis_rag/main.py +818 -0
  10. jarvis/jarvis_smart_shell/main.py +2 -2
  11. jarvis/models/ai8.py +3 -1
  12. jarvis/models/kimi.py +36 -30
  13. jarvis/models/ollama.py +17 -11
  14. jarvis/models/openai.py +15 -12
  15. jarvis/models/oyi.py +24 -7
  16. jarvis/models/registry.py +1 -25
  17. jarvis/tools/__init__.py +0 -6
  18. jarvis/tools/ask_codebase.py +96 -0
  19. jarvis/tools/ask_user.py +1 -9
  20. jarvis/tools/chdir.py +2 -37
  21. jarvis/tools/code_review.py +210 -0
  22. jarvis/tools/create_code_test_agent.py +115 -0
  23. jarvis/tools/create_ctags_agent.py +164 -0
  24. jarvis/tools/create_sub_agent.py +2 -2
  25. jarvis/tools/execute_shell.py +2 -2
  26. jarvis/tools/file_operation.py +2 -2
  27. jarvis/tools/find_in_codebase.py +78 -0
  28. jarvis/tools/git_commiter.py +68 -0
  29. jarvis/tools/methodology.py +3 -3
  30. jarvis/tools/rag.py +141 -0
  31. jarvis/tools/read_code.py +116 -0
  32. jarvis/tools/read_webpage.py +1 -1
  33. jarvis/tools/registry.py +47 -31
  34. jarvis/tools/search.py +8 -6
  35. jarvis/tools/select_code_files.py +4 -4
  36. jarvis/utils.py +375 -85
  37. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/METADATA +107 -32
  38. jarvis_ai_assistant-0.1.104.dist-info/RECORD +50 -0
  39. jarvis_ai_assistant-0.1.104.dist-info/entry_points.txt +11 -0
  40. jarvis/jarvis_code_agent/main.py +0 -200
  41. jarvis/jarvis_coder/git_utils.py +0 -123
  42. jarvis/jarvis_coder/patch_handler.py +0 -340
  43. jarvis/jarvis_github/main.py +0 -232
  44. jarvis/tools/create_code_sub_agent.py +0 -56
  45. jarvis/tools/execute_code_modification.py +0 -70
  46. jarvis/tools/find_files.py +0 -119
  47. jarvis/tools/generate_tool.py +0 -174
  48. jarvis/tools/thinker.py +0 -151
  49. jarvis_ai_assistant-0.1.102.dist-info/RECORD +0 -46
  50. jarvis_ai_assistant-0.1.102.dist-info/entry_points.txt +0 -6
  51. /jarvis/{jarvis_coder → jarvis_codebase}/__init__.py +0 -0
  52. /jarvis/{jarvis_github → jarvis_rag}/__init__.py +0 -0
  53. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/LICENSE +0 -0
  54. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/WHEEL +0 -0
  55. {jarvis_ai_assistant-0.1.102.dist-info → jarvis_ai_assistant-0.1.104.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py (new file)
@@ -0,0 +1,818 @@
+ import os
+ import numpy as np
+ import faiss
+ from typing import List, Tuple, Optional, Dict
+ import pickle
+ from jarvis.utils import OutputType, PrettyOutput, get_file_md5, get_max_context_length, load_embedding_model, load_rerank_model
+ from jarvis.utils import init_env
+ from dataclasses import dataclass
+ from tqdm import tqdm
+ import fitz  # PyMuPDF for PDF files
+ from docx import Document as DocxDocument  # python-docx for DOCX files
+ from pathlib import Path
+ from jarvis.models.registry import PlatformRegistry
+ import shutil
+ from datetime import datetime
+ import lzma  # lzma for compressing the cache file
+ from concurrent.futures import ThreadPoolExecutor
+ from threading import Lock
+
+ @dataclass
+ class Document:
+     """Document class, for storing document content and metadata"""
+     content: str  # Document content
+     metadata: Dict  # Metadata (file path, position, etc.)
+     md5: str = ""  # File MD5 value, for incremental update detection
+
+ class FileProcessor:
+     """Base class for file processors"""
+     @staticmethod
+     def can_handle(file_path: str) -> bool:
+         """Determine whether this processor can handle the file"""
+         raise NotImplementedError
+
+     @staticmethod
+     def extract_text(file_path: str) -> str:
+         """Extract the file's text content"""
+         raise NotImplementedError
+
+ class TextFileProcessor(FileProcessor):
+     """Text file processor"""
+     ENCODINGS = ['utf-8', 'gbk', 'gb2312', 'latin1']
+     SAMPLE_SIZE = 8192  # Read the first 8KB to detect encoding
+
+     @staticmethod
+     def can_handle(file_path: str) -> bool:
+         """Determine whether the file is a text file by trying to decode it"""
+         try:
+             # Read the first part of the file to detect encoding
+             with open(file_path, 'rb') as f:
+                 sample = f.read(TextFileProcessor.SAMPLE_SIZE)
+
+             # Guard: an empty sample would divide by zero below; treat empty files as text
+             if not sample:
+                 return True
+
+             # Null bytes usually indicate a binary file
+             if b'\x00' in sample:
+                 return False
+
+             # Too many non-printable characters usually indicate a binary file
+             non_printable = sum(1 for byte in sample if byte < 32 and byte not in (9, 10, 13))  # tab, newline, carriage return
+             if non_printable / len(sample) > 0.3:  # Treat as binary if more than 30% of the sample is non-printable
+                 return False
+
+             # Try to decode with the supported encodings
+             for encoding in TextFileProcessor.ENCODINGS:
+                 try:
+                     sample.decode(encoding)
+                     return True
+                 except UnicodeDecodeError:
+                     continue
+
+             return False
+
+         except Exception:
+             return False
+
+     @staticmethod
+     def extract_text(file_path: str) -> str:
+         """Extract text content using the detected encoding"""
+         detected_encoding = None
+         try:
+             # First read the raw bytes to detect the encoding
+             with open(file_path, 'rb') as f:
+                 raw_data = f.read()
+
+             # Try the supported encodings in order
+             for encoding in TextFileProcessor.ENCODINGS:
+                 try:
+                     raw_data.decode(encoding)
+                     detected_encoding = encoding
+                     break
+                 except UnicodeDecodeError:
+                     continue
+
+             if not detected_encoding:
+                 # UnicodeError accepts a plain message, unlike UnicodeDecodeError's 5-argument constructor
+                 raise UnicodeError(f"Failed to decode file with supported encodings: {file_path}")
+
+             # Read the file with the detected encoding
+             with open(file_path, 'r', encoding=detected_encoding, errors='replace') as f:
+                 content = f.read()
+
+             # Normalize Unicode characters
+             import unicodedata
+             content = unicodedata.normalize('NFKC', content)
+
+             return content
+
+         except Exception as e:
+             raise Exception(f"Failed to read file: {str(e)}")
+
+ class PDFProcessor(FileProcessor):
+     """PDF file processor"""
+     @staticmethod
+     def can_handle(file_path: str) -> bool:
+         return Path(file_path).suffix.lower() == '.pdf'
+
+     @staticmethod
+     def extract_text(file_path: str) -> str:
+         text_parts = []
+         with fitz.open(file_path) as doc:  # type: ignore
+             for page in doc:
+                 text_parts.append(page.get_text())  # type: ignore
+         return "\n".join(text_parts)
+
+ class DocxProcessor(FileProcessor):
+     """DOCX file processor"""
+     @staticmethod
+     def can_handle(file_path: str) -> bool:
+         return Path(file_path).suffix.lower() == '.docx'
+
+     @staticmethod
+     def extract_text(file_path: str) -> str:
+         doc = DocxDocument(file_path)
+         return "\n".join([paragraph.text for paragraph in doc.paragraphs])
+
+ class RAGTool:
+     def __init__(self, root_dir: str):
+         """Initialize the RAG tool
+
+         Args:
+             root_dir: Project root directory
+         """
+         init_env()
+         self.root_dir = root_dir
+         os.chdir(self.root_dir)
+
+         # Initialize configuration
+         self.min_paragraph_length = int(os.environ.get("JARVIS_MIN_PARAGRAPH_LENGTH", "50"))  # Minimum paragraph length
+         self.max_paragraph_length = int(os.environ.get("JARVIS_MAX_PARAGRAPH_LENGTH", "1000"))  # Maximum paragraph length
+         self.context_window = int(os.environ.get("JARVIS_CONTEXT_WINDOW", "5"))  # Context window size: by default 5 fragments before and after
+         self.max_context_length = int(get_max_context_length() * 0.8)
+
+         # Initialize the data directory
+         self.data_dir = os.path.join(self.root_dir, ".jarvis-rag")
+         if not os.path.exists(self.data_dir):
+             os.makedirs(self.data_dir)
+
+         # Initialize the embedding model
+         try:
+             self.embedding_model = load_embedding_model()
+             self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
+             PrettyOutput.print("Model loaded", output_type=OutputType.SUCCESS)
+         except Exception as e:
+             PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
+             raise
+
+         # Initialize cache and index
+         self.cache_path = os.path.join(self.data_dir, "cache.pkl")
+         self.documents: List[Document] = []
+         self.index = None  # IVF index used for search
+         self.flat_index = None  # Flat index that stores the original vectors
+         self.file_md5_cache = {}  # File MD5 values for incremental updates
+
+         # Load cache
+         self._load_cache()
+
+         # Register file processors
+         self.file_processors = [
+             TextFileProcessor(),
+             PDFProcessor(),
+             DocxProcessor()
+         ]
+
+         # Threading configuration
+         self.thread_count = int(os.environ.get("JARVIS_THREAD_COUNT", os.cpu_count() or 4))
+         self.vector_lock = Lock()  # Protects concurrent access to the vector list
+
+     def _load_cache(self):
+         """Load cached data"""
+         if os.path.exists(self.cache_path):
+             try:
+                 with lzma.open(self.cache_path, 'rb') as f:
+                     cache_data = pickle.load(f)
+                 self.documents = cache_data["documents"]
+                 vectors = cache_data["vectors"]
+                 self.file_md5_cache = cache_data.get("file_md5_cache", {})  # Load the MD5 cache
+
+                 # Rebuild the index
+                 if vectors is not None:
+                     self._build_index(vectors)
+                 PrettyOutput.print(f"Loaded {len(self.documents)} document fragments",
+                                    output_type=OutputType.INFO)
+             except Exception as e:
+                 PrettyOutput.print(f"Failed to load cache: {str(e)}",
+                                    output_type=OutputType.WARNING)
+                 self.documents = []
+                 self.index = None
+                 self.flat_index = None
+                 self.file_md5_cache = {}
+
+     def _save_cache(self, vectors: np.ndarray):
+         """Save cache data (pickled, then LZMA-compressed)"""
+         try:
+             cache_data = {
+                 "version": "1.0",
+                 "timestamp": datetime.now().isoformat(),
+                 "documents": self.documents,
+                 "vectors": vectors.copy() if vectors is not None else None,  # Create a copy of the array
+                 "file_md5_cache": dict(self.file_md5_cache),  # Create a copy of the dictionary
+                 "metadata": {
+                     "vector_dim": self.vector_dim,
+                     "total_docs": len(self.documents),
+                     "model_name": self.embedding_model.__class__.__name__
+                 }
+             }
+
+             # First serialize the data to a byte stream
+             data = pickle.dumps(cache_data, protocol=pickle.HIGHEST_PROTOCOL)
+
+             # Then compress the byte stream with LZMA
+             with lzma.open(self.cache_path, 'wb') as f:
+                 f.write(data)
+
+             # Create a backup
+             backup_path = f"{self.cache_path}.backup"
+             shutil.copy2(self.cache_path, backup_path)
+
+             PrettyOutput.print(f"Cache saved: {len(self.documents)} document fragments",
+                                output_type=OutputType.INFO)
+         except Exception as e:
+             PrettyOutput.print(f"Failed to save cache: {str(e)}",
+                                output_type=OutputType.ERROR)
+             raise
+
+     def _build_index(self, vectors: np.ndarray):
+         """Build the FAISS indices"""
+         if vectors.shape[0] == 0:
+             self.index = None
+             self.flat_index = None
+             return
+
+         # Create a flat index that stores the original vectors, for reconstruction
+         self.flat_index = faiss.IndexFlatIP(self.vector_dim)
+         self.flat_index.add(vectors)  # type: ignore
+
+         # Create an IVF index for fast search
+         nlist = max(4, int(vectors.shape[0] / 1000))  # One cluster centroid per 1000 vectors
+         quantizer = faiss.IndexFlatIP(self.vector_dim)
+         self.index = faiss.IndexIVFFlat(quantizer, self.vector_dim, nlist, faiss.METRIC_INNER_PRODUCT)
+
+         # Train the index, then add the vectors
+         self.index.train(vectors)  # type: ignore
+         self.index.add(vectors)  # type: ignore
+         # Set the number of clusters to probe during search
+         self.index.nprobe = min(nlist, 10)
+
+     def _split_text(self, text: str) -> List[str]:
+         """Split text into sentence-aware chunks with overlap"""
+         # Overlapping chunks help maintain context consistency across boundaries
+         overlap_size = min(200, self.max_paragraph_length // 4)
+
+         paragraphs = []
+         current_chunk = []
+         current_length = 0
+
+         # First split into sentences
+         sentences = []
+         current_sentence = []
+         sentence_ends = {'。', '!', '?', '…', '.', '!', '?'}
+
+         for char in text:
+             current_sentence.append(char)
+             if char in sentence_ends:
+                 sentence = ''.join(current_sentence)
+                 if sentence.strip():
+                     sentences.append(sentence)
+                 current_sentence = []
+
+         if current_sentence:
+             sentence = ''.join(current_sentence)
+             if sentence.strip():
+                 sentences.append(sentence)
+
+         # Build overlapping chunks from the sentences
+         for sentence in sentences:
+             if current_length + len(sentence) > self.max_paragraph_length:
+                 if current_chunk:
+                     chunk_text = ' '.join(current_chunk)
+                     if len(chunk_text) >= self.min_paragraph_length:
+                         paragraphs.append(chunk_text)
+
+                     # Keep some content as overlap
+                     overlap_text = ' '.join(current_chunk[-2:])  # Keep the last two sentences
+                     current_chunk = []
+                     if overlap_text:
+                         current_chunk.append(overlap_text)
+                         current_length = len(overlap_text)
+                     else:
+                         current_length = 0
+
+             current_chunk.append(sentence)
+             current_length += len(sentence)
+
+         # Process the last chunk
+         if current_chunk:
+             chunk_text = ' '.join(current_chunk)
+             if len(chunk_text) >= self.min_paragraph_length:
+                 paragraphs.append(chunk_text)
+
+         return paragraphs
+
+     def _get_embedding(self, text: str) -> np.ndarray:
+         """Get the vector representation of the text"""
+         embedding = self.embedding_model.encode(text,
+                                                 normalize_embeddings=True,
+                                                 show_progress_bar=False)
+         return np.array(embedding, dtype=np.float32)
+
+     def _get_embedding_batch(self, texts: List[str]) -> np.ndarray:
+         """Get vector representations for a batch of texts
+
+         Args:
+             texts: List of texts
+
+         Returns:
+             np.ndarray: Array of vector representations
+         """
+         try:
+             embeddings = self.embedding_model.encode(texts,
+                                                      normalize_embeddings=True,
+                                                      show_progress_bar=False,
+                                                      batch_size=32)  # Batch processing improves throughput
+             return np.array(embeddings, dtype=np.float32)
+         except Exception as e:
+             PrettyOutput.print(f"Failed to get vector representation: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return np.zeros((len(texts), self.vector_dim), dtype=np.float32)  # type: ignore
+
+     def _process_document_batch(self, documents: List[Document]) -> List[np.ndarray]:
+         """Vectorize a batch of documents
+
+         Args:
+             documents: List of documents
+
+         Returns:
+             List[np.ndarray]: List of vectors
+         """
+         texts = []
+         for doc in documents:
+             # Combine document information
+             combined_text = f"""
+ File: {doc.metadata['file_path']}
+ Content: {doc.content}
+ """
+             texts.append(combined_text)
+
+         return self._get_embedding_batch(texts)  # type: ignore
+
+     def _process_file(self, file_path: str) -> List[Document]:
+         """Process a single file"""
+         try:
+             # Calculate the file's MD5
+             current_md5 = get_file_md5(file_path)
+             if not current_md5:
+                 return []
+
+             # Skip files whose MD5 has not changed
+             if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
+                 return []
+
+             # Find an appropriate processor
+             processor = None
+             for p in self.file_processors:
+                 if p.can_handle(file_path):
+                     processor = p
+                     break
+
+             if not processor:
+                 # No appropriate processor found; return no documents
+                 return []
+
+             # Extract the text content
+             content = processor.extract_text(file_path)
+             if not content.strip():
+                 return []
+
+             # Split the text
+             chunks = self._split_text(content)
+
+             # Create document objects
+             documents = []
+             for i, chunk in enumerate(chunks):
+                 doc = Document(
+                     content=chunk,
+                     metadata={
+                         "file_path": file_path,
+                         "file_type": Path(file_path).suffix.lower(),
+                         "chunk_index": i,
+                         "total_chunks": len(chunks)
+                     },
+                     md5=current_md5
+                 )
+                 documents.append(doc)
+
+             # Update the MD5 cache
+             self.file_md5_cache[file_path] = current_md5
+             return documents
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
+                                output_type=OutputType.ERROR)
+             return []
+
+     def build_index(self, dir: str):
+         """Build the document index"""
+         # Collect all files
+         all_files = []
+         for root, _, files in os.walk(dir):
+             if any(ignored in root for ignored in ['.git', '__pycache__', 'node_modules']) or \
+                any(part.startswith('.jarvis-') for part in root.split(os.sep)):
+                 continue
+             for file in files:
+                 if file.startswith('.jarvis-'):
+                     continue
+
+                 file_path = os.path.join(root, file)
+                 if os.path.getsize(file_path) > 100 * 1024 * 1024:  # 100MB
+                     PrettyOutput.print(f"Skip large file: {file_path}",
+                                        output_type=OutputType.WARNING)
+                     continue
+                 all_files.append(file_path)
+
+         # Clean up cache entries for deleted files
+         deleted_files = set(self.file_md5_cache.keys()) - set(all_files)
+         for file_path in deleted_files:
+             del self.file_md5_cache[file_path]
+             # Remove the related documents
+             self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]
+
+         # Check for file changes
+         files_to_process = []
+         unchanged_files = []
+
+         with tqdm(total=len(all_files), desc="Check file status") as pbar:
+             for file_path in all_files:
+                 current_md5 = get_file_md5(file_path)
+                 if current_md5:  # Only consider files whose MD5 can be calculated
+                     if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
+                         # File unchanged: record it, but do not reprocess
+                         unchanged_files.append(file_path)
+                     else:
+                         # New or modified file
+                         files_to_process.append(file_path)
+                 pbar.update(1)
+
+         # Keep the documents of unchanged files
+         unchanged_documents = [doc for doc in self.documents
+                                if doc.metadata['file_path'] in unchanged_files]
+
+         # Process new and modified files
+         new_documents = []
+         if files_to_process:
+             with tqdm(total=len(files_to_process), desc="Process files") as pbar:
+                 for file_path in files_to_process:
+                     try:
+                         docs = self._process_file(file_path)
+                         if len(docs) > 0:
+                             new_documents.extend(docs)
+                     except Exception as e:
+                         PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
+                                            output_type=OutputType.ERROR)
+                     pbar.update(1)
+
+         # Update the document list
+         self.documents = unchanged_documents + new_documents
+
+         if not self.documents:
+             PrettyOutput.print("No documents to process", output_type=OutputType.WARNING)
+             return
+
+         # Only vectorize the new documents
+         if new_documents:
+             PrettyOutput.print(f"Start processing {len(new_documents)} new documents",
+                                output_type=OutputType.INFO)
+
+             # Vectorize with a thread pool
+             batch_size = 32
+             new_vectors = []
+
+             with tqdm(total=len(new_documents), desc="Generating vectors") as pbar:
+                 with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
+                     for i in range(0, len(new_documents), batch_size):
+                         batch = new_documents[i:i + batch_size]
+                         # Note: result() is awaited immediately, so batches run one after another
+                         future = executor.submit(self._process_document_batch, batch)
+                         batch_vectors = future.result()
+
+                         with self.vector_lock:
+                             new_vectors.extend(batch_vectors)
+
+                         pbar.update(len(batch))
+
+             # Merge old and new vectors
+             if self.flat_index is not None:
+                 # Get the vectors of unchanged documents
+                 unchanged_vectors = []
+                 for doc in unchanged_documents:
+                     # Look up the document's position in the existing index
+                     doc_idx = next((i for i, d in enumerate(self.documents)
+                                     if d.metadata['file_path'] == doc.metadata['file_path']), None)
+                     if doc_idx is not None:
+                         # Reconstruct the vector from the flat index
+                         vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
+                         self.flat_index.reconstruct(doc_idx, vector.ravel())
+                         unchanged_vectors.append(vector)
+
+                 if unchanged_vectors:
+                     unchanged_vectors = np.vstack(unchanged_vectors)
+                     vectors = np.vstack([unchanged_vectors, np.vstack(new_vectors)])
+                 else:
+                     vectors = np.vstack(new_vectors)
+             else:
+                 vectors = np.vstack(new_vectors)
+
+             # Build the index
+             self._build_index(vectors)
+             # Save the cache
+             self._save_cache(vectors)
+
+             PrettyOutput.print(f"Successfully indexed {len(self.documents)} document fragments (Added/Modified: {len(new_documents)}, Unchanged: {len(unchanged_documents)})",
+                                output_type=OutputType.SUCCESS)
+
+     def search(self, query: str, top_k: int = 30) -> List[Tuple[Document, float]]:
+         """Search the index, re-selecting results with MMR for diversity"""
+         if not self.index:
+             PrettyOutput.print("Index not built, building...", output_type=OutputType.INFO)
+             self.build_index(self.root_dir)
+         if not self.index:
+             # Guard: building can still yield no index (e.g. no documents found)
+             return []
+
+         # MMR (Maximal Marginal Relevance) trades off relevance to the query
+         # against similarity to already-selected documents, increasing diversity
+         def mmr(query_vec, doc_vecs, doc_ids, lambda_param=0.5, n_docs=top_k):
+             selected = []
+             selected_ids = []
+
+             while len(selected) < n_docs and len(doc_ids) > 0:
+                 best_score = -1
+                 best_idx = -1
+
+                 for i, (doc_vec, doc_id) in enumerate(zip(doc_vecs, doc_ids)):
+                     # Similarity to the query
+                     query_sim = float(np.dot(query_vec, doc_vec))
+
+                     # Maximum similarity to the already selected documents
+                     if selected:
+                         doc_sims = [float(np.dot(doc_vec, selected_doc)) for selected_doc in selected]
+                         max_doc_sim = max(doc_sims)
+                     else:
+                         max_doc_sim = 0
+
+                     # MMR score
+                     score = lambda_param * query_sim - (1 - lambda_param) * max_doc_sim
+
+                     if score > best_score:
+                         best_score = score
+                         best_idx = i
+
+                 if best_idx == -1:
+                     break
+
+                 selected.append(doc_vecs[best_idx])
+                 selected_ids.append(doc_ids[best_idx])
+                 doc_vecs = np.delete(doc_vecs, best_idx, axis=0)
+                 doc_ids = np.delete(doc_ids, best_idx)
+
+             return selected_ids
+
+         # Get the query vector
+         query_vector = self._get_embedding(query)
+         query_vector = query_vector.reshape(1, -1)
+
+         # Initially retrieve more results than requested, so MMR has candidates to choose from
+         initial_k = min(top_k * 2, len(self.documents))
+         distances, indices = self.index.search(query_vector, initial_k)  # type: ignore
+
+         # Collect the valid results
+         valid_indices = indices[0][indices[0] != -1]
+         if len(valid_indices) == 0:
+             return []
+         valid_vectors = np.vstack([self._get_embedding(self.documents[idx].content) for idx in valid_indices])
+
+         # Apply MMR
+         final_indices = mmr(query_vector[0], valid_vectors, valid_indices, n_docs=top_k)
+
+         # Build the results
+         results = []
+         for idx in final_indices:
+             doc = self.documents[idx]
+             similarity = 1.0 / (1.0 + float(distances[0][np.where(indices[0] == idx)[0][0]]))
+             results.append((doc, similarity))
+
+         return results
+
+     def _rerank_results(self, query: str, initial_results: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
+         """Rerank search results with the rerank model"""
+         try:
+             import torch
+             model, tokenizer = load_rerank_model()
+
+             # Prepare the data
+             pairs = []
+             for doc, _ in initial_results:
+                 # Combine document information
+                 doc_content = f"""
+ File: {doc.metadata['file_path']}
+ Content: {doc.content}
+ """
+                 pairs.append([query, doc_content])
+
+             # Score each query-document pair
+             scores = []
+             batch_size = 8
+
+             with torch.no_grad():
+                 for i in range(0, len(pairs), batch_size):
+                     batch_pairs = pairs[i:i + batch_size]
+                     encoded = tokenizer(
+                         batch_pairs,
+                         padding=True,
+                         truncation=True,
+                         max_length=512,
+                         return_tensors='pt'
+                     )
+
+                     if torch.cuda.is_available():
+                         encoded = {k: v.cuda() for k, v in encoded.items()}
+
+                     outputs = model(**encoded)
+                     batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
+                     scores.extend(batch_scores.tolist())
+
+             # Normalize the scores to the 0-1 range
+             if scores:
+                 min_score = min(scores)
+                 max_score = max(scores)
+                 if max_score > min_score:
+                     scores = [(s - min_score) / (max_score - min_score) for s in scores]
+
+             # Pair scores with documents and filter
+             scored_results = []
+             for (doc, _), score in zip(initial_results, scores):
+                 if score >= 0.5:  # Only keep results with a score of at least 0.5
+                     scored_results.append((doc, float(score)))
+
+             # Sort by score in descending order
+             scored_results.sort(key=lambda x: x[1], reverse=True)
+
+             return scored_results
+
+         except Exception as e:
+             PrettyOutput.print(f"Rerank failed, falling back to the original order: {str(e)}", output_type=OutputType.WARNING)
+             return initial_results
+
+     def is_index_built(self):
+         """Check whether the index has been built"""
+         return self.index is not None
+
+     def query(self, query: str) -> List[Document]:
+         """Query for related documents
+
+         Args:
+             query: Query text
+
+         Returns:
+             List[Document]: Related documents, including context
+         """
+         results = self.search(query)
+         return [doc for doc, _ in results]
+
+     def ask(self, question: str) -> Optional[str]:
+         """Ask a question about the documents
+
+         Args:
+             question: User question
+
+         Returns:
+             The model's answer, or None on failure
+         """
+         try:
+             # Search for related document fragments
+             results = self.query(question)
+             if not results:
+                 return None
+
+             # Display the found document fragments
+             for doc in results:
+                 output = f"""File: {doc.metadata['file_path']}\n"""
+                 output += f"""Fragment {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']}\n"""
+                 output += f"""Content:\n{doc.content}\n"""
+                 PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")
+
+             # Build the base prompt
+             base_prompt = f"""Please answer the user's question based on the following document fragments. If the document content is not sufficient to answer the question completely, please clearly indicate that.
+
+ User question: {question}
+
+ Related document fragments:
+ """
+             end_prompt = "\nPlease provide an accurate and concise answer. If the document content is not sufficient to answer the question completely, please clearly indicate that."
+
+             # Calculate the maximum length available for document content,
+             # leaving some space for the model's answer
+             available_length = self.max_context_length - len(base_prompt) - len(end_prompt) - 500
+
+             # Build the context while keeping the total length under control
+             context = []
+             current_length = 0
+
+             for doc in results:
+                 # Calculate the length of this document fragment's contribution
+                 doc_content = f"""
+ Source file: {doc.metadata['file_path']}
+ Content:
+ {doc.content}
+ ---
+ """
+                 content_length = len(doc_content)
+
+                 # Stop once adding this fragment would exceed the limit
+                 if current_length + content_length > available_length:
+                     PrettyOutput.print("Due to the context length limit, some related document fragments were omitted",
+                                        output_type=OutputType.WARNING)
+                     break
+
+                 context.append(doc_content)
+                 current_length += content_length
+
+             # Build the complete prompt
+             prompt = base_prompt + ''.join(context) + end_prompt
+
+             # Get a model instance and generate the answer
+             model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
+             response = model.chat_until_success(prompt)
+
+             return response
+
+         except Exception as e:
+             PrettyOutput.print(f"Failed to answer: {str(e)}", output_type=OutputType.ERROR)
+             return None
+
+ def main():
+     """Command-line entry point"""
+     import argparse
+     import sys
+
+     # Make sure standard output uses UTF-8
+     if sys.stdout.encoding is None or sys.stdout.encoding.lower() != 'utf-8':
+         import codecs
+         sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
+         sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')
+
+     parser = argparse.ArgumentParser(description='Document retrieval and analysis tool')
+     parser.add_argument('--dir', type=str, help='Directory to process')
+     parser.add_argument('--build', action='store_true', help='Build the document index')
+     parser.add_argument('--search', type=str, help='Search the document content')
+     parser.add_argument('--ask', type=str, help='Ask a question about the documents')
+     args = parser.parse_args()
+
+     try:
+         current_dir = os.getcwd()
+         rag = RAGTool(current_dir)
+
+         if not args.dir:
+             args.dir = current_dir
+
+         if args.build:
+             PrettyOutput.print(f"Processing directory: {args.dir}", output_type=OutputType.INFO)
+             rag.build_index(args.dir)
+             return 0
+
+         if args.search:
+             results = rag.query(args.search)
+             if not results:
+                 PrettyOutput.print("No related content found", output_type=OutputType.WARNING)
+                 return 1
+
+             for doc in results:
+                 output = f"""File: {doc.metadata['file_path']}\n"""
+                 output += f"""Fragment {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']}\n"""
+                 output += f"""Content:\n{doc.content}\n"""
+                 PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")
+             return 0
+
+         if args.ask:
+             response = rag.ask(args.ask)
+             if not response:
+                 PrettyOutput.print("Failed to get answer", output_type=OutputType.WARNING)
+                 return 1
+
+             # Display the answer
+             output = f"""Answer:\n{response}"""
+             PrettyOutput.print(output, output_type=OutputType.INFO)
+             return 0
+
+         PrettyOutput.print("Please specify operation parameters. Use -h to view help.", output_type=OutputType.WARNING)
+         return 1
+
+     except Exception as e:
+         PrettyOutput.print(f"Failed to execute: {str(e)}", output_type=OutputType.ERROR)
+         return 1
+
+ if __name__ == "__main__":
+     main()
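
The main() entry point above wires RAGTool to four flags (--dir, --build, --search, --ask). As a minimal sketch of how the class added in this release could also be driven programmatically — assuming the 0.1.104 wheel is installed, and with a placeholder project path:

    from jarvis.jarvis_rag.main import RAGTool  # module path as added in this diff

    rag = RAGTool("/path/to/project")  # placeholder path; the constructor chdirs into it
    rag.build_index(rag.root_dir)      # like --build; unchanged files are skipped via the MD5 cache

    # Like --search: FAISS retrieval plus MMR re-selection, returning Document objects
    for doc in rag.query("how is the cache compressed?"):
        print(doc.metadata["file_path"], doc.metadata["chunk_index"])

    # Like --ask: retrieved fragments are packed into a prompt for the configured platform
    print(rag.ask("How is the index cache stored on disk?"))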