jarvis-ai-assistant 0.1.101__py3-none-any.whl → 0.1.102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +3 -24
- jarvis/jarvis_code_agent/main.py +1 -3
- jarvis/jarvis_github/main.py +232 -0
- jarvis/models/ai8.py +2 -3
- jarvis/models/oyi.py +1 -3
- jarvis/tools/registry.py +48 -40
- jarvis/utils.py +9 -124
- {jarvis_ai_assistant-0.1.101.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/METADATA +1 -47
- {jarvis_ai_assistant-0.1.101.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/RECORD +15 -20
- {jarvis_ai_assistant-0.1.101.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/entry_points.txt +0 -3
- jarvis/jarvis_codebase/main.py +0 -875
- jarvis/jarvis_coder/main.py +0 -241
- jarvis/jarvis_coder/plan_generator.py +0 -145
- jarvis/jarvis_rag/__init__.py +0 -0
- jarvis/jarvis_rag/main.py +0 -822
- jarvis/tools/rag.py +0 -138
- /jarvis/{jarvis_codebase → jarvis_github}/__init__.py +0 -0
- {jarvis_ai_assistant-0.1.101.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.101.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.101.dist-info → jarvis_ai_assistant-0.1.102.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py
DELETED
@@ -1,822 +0,0 @@
import os
import numpy as np
import faiss
from typing import List, Tuple, Optional, Dict
import pickle
from jarvis.utils import OutputType, PrettyOutput, get_file_md5, get_max_context_length, load_embedding_model, load_rerank_model
from jarvis.utils import init_env
from dataclasses import dataclass
from tqdm import tqdm
import fitz  # PyMuPDF for PDF files
from docx import Document as DocxDocument  # python-docx for DOCX files
from pathlib import Path
from jarvis.models.registry import PlatformRegistry
import shutil
from datetime import datetime
import lzma  # lzma for cache compression
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

@dataclass
class Document:
    """Document class, for storing document content and metadata"""
    content: str  # Document content
    metadata: Dict  # Metadata (file path, position, etc.)
    md5: str = ""  # File MD5 value, for incremental update detection

class FileProcessor:
    """Base class for file processors"""
    @staticmethod
    def can_handle(file_path: str) -> bool:
        """Determine whether the file can be processed"""
        raise NotImplementedError

    @staticmethod
    def extract_text(file_path: str) -> str:
        """Extract the file's text content"""
        raise NotImplementedError

class TextFileProcessor(FileProcessor):
    """Text file processor"""
    ENCODINGS = ['utf-8', 'gbk', 'gb2312', 'latin1']
    SAMPLE_SIZE = 8192  # Read the first 8KB to detect encoding

    @staticmethod
    def can_handle(file_path: str) -> bool:
        """Determine whether the file is a text file by trying to decode it"""
        try:
            # Read the first part of the file to detect encoding
            with open(file_path, 'rb') as f:
                sample = f.read(TextFileProcessor.SAMPLE_SIZE)

            # Check for null bytes (usually indicates a binary file)
            if b'\x00' in sample:
                return False

            # Check for too many non-printable characters (usually indicates a binary file)
            non_printable = sum(1 for byte in sample if byte < 32 and byte not in (9, 10, 13))  # tab, newline, carriage return
            if non_printable / len(sample) > 0.3:  # More than 30% non-printable characters: treat as binary
                return False

            # Try to decode with different encodings
            for encoding in TextFileProcessor.ENCODINGS:
                try:
                    sample.decode(encoding)
                    return True
                except UnicodeDecodeError:
                    continue

            return False

        except Exception:
            return False

    @staticmethod
    def extract_text(file_path: str) -> str:
        """Extract text content using the detected encoding"""
        detected_encoding = None
        try:
            # First try to detect the encoding
            with open(file_path, 'rb') as f:
                raw_data = f.read()

            # Try different encodings
            for encoding in TextFileProcessor.ENCODINGS:
                try:
                    raw_data.decode(encoding)
                    detected_encoding = encoding
                    break
                except UnicodeDecodeError:
                    continue

            if not detected_encoding:
                raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}")  # type: ignore

            # Read the file with the detected encoding
            with open(file_path, 'r', encoding=detected_encoding, errors='replace') as f:
                content = f.read()

            # Normalize Unicode characters
            import unicodedata
            content = unicodedata.normalize('NFKC', content)

            return content

        except Exception as e:
            raise Exception(f"Failed to read file: {str(e)}")

class PDFProcessor(FileProcessor):
    """PDF file processor"""
    @staticmethod
    def can_handle(file_path: str) -> bool:
        return Path(file_path).suffix.lower() == '.pdf'

    @staticmethod
    def extract_text(file_path: str) -> str:
        text_parts = []
        with fitz.open(file_path) as doc:  # type: ignore
            for page in doc:
                text_parts.append(page.get_text())
        return "\n".join(text_parts)

class DocxProcessor(FileProcessor):
    """DOCX file processor"""
    @staticmethod
    def can_handle(file_path: str) -> bool:
        return Path(file_path).suffix.lower() == '.docx'

    @staticmethod
    def extract_text(file_path: str) -> str:
        doc = DocxDocument(file_path)
        return "\n".join([paragraph.text for paragraph in doc.paragraphs])

class RAGTool:
    def __init__(self, root_dir: str):
        """Initialize RAG tool

        Args:
            root_dir: Project root directory
        """
        init_env()
        self.root_dir = root_dir
        os.chdir(self.root_dir)

        # Initialize configuration
        self.min_paragraph_length = int(os.environ.get("JARVIS_MIN_PARAGRAPH_LENGTH", "50"))  # Minimum paragraph length
        self.max_paragraph_length = int(os.environ.get("JARVIS_MAX_PARAGRAPH_LENGTH", "1000"))  # Maximum paragraph length
        self.context_window = int(os.environ.get("JARVIS_CONTEXT_WINDOW", "5"))  # Context window size, default 5 fragments before and after
        self.max_context_length = int(get_max_context_length() * 0.8)

        # Initialize data directory
        self.data_dir = os.path.join(self.root_dir, ".jarvis-rag")
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        # Initialize embedding model
        try:
            self.embedding_model = load_embedding_model()
            self.vector_dim = self.embedding_model.get_sentence_embedding_dimension()
            PrettyOutput.print("Model loaded", output_type=OutputType.SUCCESS)
        except Exception as e:
            PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
            raise

        # Initialize cache and index
        self.cache_path = os.path.join(self.data_dir, "cache.pkl")
        self.documents: List[Document] = []
        self.index = None  # IVF index for search
        self.flat_index = None  # Stores original vectors
        self.file_md5_cache = {}  # Stores file MD5 values

        # Load cache
        self._load_cache()

        # Register file processors
        self.file_processors = [
            TextFileProcessor(),
            PDFProcessor(),
            DocxProcessor()
        ]

        # Thread-related configuration
        self.thread_count = int(os.environ.get("JARVIS_THREAD_COUNT", os.cpu_count() or 4))
        self.vector_lock = Lock()  # Protects the vector list against concurrent access

    def _load_cache(self):
        """Load cached data"""
        if os.path.exists(self.cache_path):
            try:
                with lzma.open(self.cache_path, 'rb') as f:
                    cache_data = pickle.load(f)
                self.documents = cache_data["documents"]
                vectors = cache_data["vectors"]
                self.file_md5_cache = cache_data.get("file_md5_cache", {})  # Load MD5 cache

                # Rebuild index
                if vectors is not None:
                    self._build_index(vectors)
                PrettyOutput.print(f"Loaded {len(self.documents)} document fragments",
                                   output_type=OutputType.INFO)
            except Exception as e:
                PrettyOutput.print(f"Failed to load cache: {str(e)}",
                                   output_type=OutputType.WARNING)
                self.documents = []
                self.index = None
                self.flat_index = None
                self.file_md5_cache = {}

    def _save_cache(self, vectors: np.ndarray):
        """Optimized cache saving"""
        try:
            cache_data = {
                "version": "1.0",
                "timestamp": datetime.now().isoformat(),
                "documents": self.documents,
                "vectors": vectors.copy() if vectors is not None else None,  # Create a copy of the array
                "file_md5_cache": dict(self.file_md5_cache),  # Create a copy of the dictionary
                "metadata": {
                    "vector_dim": self.vector_dim,
                    "total_docs": len(self.documents),
                    "model_name": self.embedding_model.__class__.__name__
                }
            }

            # First serialize the data to a byte stream
            data = pickle.dumps(cache_data, protocol=pickle.HIGHEST_PROTOCOL)

            # Then compress the byte stream with LZMA
            with lzma.open(self.cache_path, 'wb') as f:
                f.write(data)

            # Create a backup
            backup_path = f"{self.cache_path}.backup"
            shutil.copy2(self.cache_path, backup_path)

            PrettyOutput.print(f"Cache saved: {len(self.documents)} document fragments",
                               output_type=OutputType.INFO)
        except Exception as e:
            PrettyOutput.print(f"Failed to save cache: {str(e)}",
                               output_type=OutputType.ERROR)
            raise

    def _build_index(self, vectors: np.ndarray):
        """Build FAISS index"""
        if vectors.shape[0] == 0:
            self.index = None
            self.flat_index = None
            return

        # Create a flat index to store original vectors, for reconstruction
        self.flat_index = faiss.IndexFlatIP(self.vector_dim)
        self.flat_index.add(vectors)  # type: ignore

        # Create an IVF index for fast search
        nlist = max(4, int(vectors.shape[0] / 1000))  # One cluster centroid per 1000 vectors
        quantizer = faiss.IndexFlatIP(self.vector_dim)
        self.index = faiss.IndexIVFFlat(quantizer, self.vector_dim, nlist, faiss.METRIC_INNER_PRODUCT)

        # Train and add vectors
        self.index.train(vectors)  # type: ignore
        self.index.add(vectors)  # type: ignore
        # Set the number of clusters to probe during search
        self.index.nprobe = min(nlist, 10)

    def _split_text(self, text: str) -> List[str]:
        """Use a more intelligent splitting strategy"""
        # Add overlapping blocks to maintain context consistency
        overlap_size = min(200, self.max_paragraph_length // 4)

        paragraphs = []
        current_chunk = []
        current_length = 0

        # First split by sentence
        sentences = []
        current_sentence = []
        sentence_ends = {'。', '!', '?', '…', '.', '!', '?'}

        for char in text:
            current_sentence.append(char)
            if char in sentence_ends:
                sentence = ''.join(current_sentence)
                if sentence.strip():
                    sentences.append(sentence)
                current_sentence = []

        if current_sentence:
            sentence = ''.join(current_sentence)
            if sentence.strip():
                sentences.append(sentence)

        # Build overlapping blocks based on sentences
        for sentence in sentences:
            if current_length + len(sentence) > self.max_paragraph_length:
                if current_chunk:
                    chunk_text = ' '.join(current_chunk)
                    if len(chunk_text) >= self.min_paragraph_length:
                        paragraphs.append(chunk_text)

                    # Keep some content as overlap
                    overlap_text = ' '.join(current_chunk[-2:])  # Keep the last two sentences
                    current_chunk = []
                    if overlap_text:
                        current_chunk.append(overlap_text)
                        current_length = len(overlap_text)
                    else:
                        current_length = 0

            current_chunk.append(sentence)
            current_length += len(sentence)

        # Process the last chunk
        if current_chunk:
            chunk_text = ' '.join(current_chunk)
            if len(chunk_text) >= self.min_paragraph_length:
                paragraphs.append(chunk_text)

        return paragraphs

    def _get_embedding(self, text: str) -> np.ndarray:
        """Get the vector representation of the text"""
        embedding = self.embedding_model.encode(text,
                                                normalize_embeddings=True,
                                                show_progress_bar=False)
        return np.array(embedding, dtype=np.float32)

    def _get_embedding_batch(self, texts: List[str]) -> np.ndarray:
        """Get vector representations for a batch of texts

        Args:
            texts: Text list

        Returns:
            np.ndarray: Vector representation array
        """
        try:
            embeddings = self.embedding_model.encode(texts,
                                                     normalize_embeddings=True,
                                                     show_progress_bar=False,
                                                     batch_size=32)  # Use batch processing to improve efficiency
            return np.array(embeddings, dtype=np.float32)
        except Exception as e:
            PrettyOutput.print(f"Failed to get vector representation: {str(e)}",
                               output_type=OutputType.ERROR)
            return np.zeros((len(texts), self.vector_dim), dtype=np.float32)  # type: ignore

    def _process_document_batch(self, documents: List[Document]) -> List[np.ndarray]:
        """Vectorize a batch of documents

        Args:
            documents: Document list

        Returns:
            List[np.ndarray]: Vector list
        """
        texts = []
        for doc in documents:
            # Combine document information
            combined_text = f"""
File: {doc.metadata['file_path']}
Content: {doc.content}
"""
            texts.append(combined_text)

        return self._get_embedding_batch(texts)  # type: ignore

    def _process_file(self, file_path: str) -> List[Document]:
        """Process a single file"""
        try:
            # Calculate the file's MD5
            current_md5 = get_file_md5(file_path)
            if not current_md5:
                return []

            # Check whether the file needs to be reprocessed
            if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
                return []

            # Find the appropriate processor
            processor = None
            for p in self.file_processors:
                if p.can_handle(file_path):
                    processor = p
                    break

            if not processor:
                # If no appropriate processor is found, return an empty document list
                return []

            # Extract text content
            content = processor.extract_text(file_path)
            if not content.strip():
                return []

            # Split text
            chunks = self._split_text(content)

            # Create document objects
            documents = []
            for i, chunk in enumerate(chunks):
                doc = Document(
                    content=chunk,
                    metadata={
                        "file_path": file_path,
                        "file_type": Path(file_path).suffix.lower(),
                        "chunk_index": i,
                        "total_chunks": len(chunks)
                    },
                    md5=current_md5
                )
                documents.append(doc)

            # Update MD5 cache
            self.file_md5_cache[file_path] = current_md5
            return documents

        except Exception as e:
            PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
                               output_type=OutputType.ERROR)
            return []

    def build_index(self, dir: str):
        """Build document index"""
        # Get all files
        all_files = []
        for root, _, files in os.walk(dir):
            if any(ignored in root for ignored in ['.git', '__pycache__', 'node_modules']) or \
               any(part.startswith('.jarvis-') for part in root.split(os.sep)):
                continue
            for file in files:
                if file.startswith('.jarvis-'):
                    continue

                file_path = os.path.join(root, file)
                if os.path.getsize(file_path) > 100 * 1024 * 1024:  # 100MB
                    PrettyOutput.print(f"Skip large file: {file_path}",
                                       output_type=OutputType.WARNING)
                    continue
                all_files.append(file_path)

        # Clean up cache entries for deleted files
        deleted_files = set(self.file_md5_cache.keys()) - set(all_files)
        for file_path in deleted_files:
            del self.file_md5_cache[file_path]
            # Remove related documents
            self.documents = [doc for doc in self.documents if doc.metadata['file_path'] != file_path]

        # Check file changes
        files_to_process = []
        unchanged_files = []

        with tqdm(total=len(all_files), desc="Check file status") as pbar:
            for file_path in all_files:
                current_md5 = get_file_md5(file_path)
                if current_md5:  # Only process files whose MD5 can be calculated
                    if file_path in self.file_md5_cache and self.file_md5_cache[file_path] == current_md5:
                        # File unchanged: record it but do not reprocess
                        unchanged_files.append(file_path)
                    else:
                        # New file or modified file
                        files_to_process.append(file_path)
                pbar.update(1)

        # Keep documents for unchanged files
        unchanged_documents = [doc for doc in self.documents
                               if doc.metadata['file_path'] in unchanged_files]

        # Process new and modified files
        new_documents = []
        if files_to_process:
            with tqdm(total=len(files_to_process), desc="Process files") as pbar:
                for file_path in files_to_process:
                    try:
                        docs = self._process_file(file_path)
                        if len(docs) > 0:
                            new_documents.extend(docs)
                    except Exception as e:
                        PrettyOutput.print(f"Failed to process file {file_path}: {str(e)}",
                                           output_type=OutputType.ERROR)
                    pbar.update(1)

        # Update document list
        self.documents = unchanged_documents + new_documents

        if not self.documents:
            PrettyOutput.print("No documents to process", output_type=OutputType.WARNING)
            return

        # Only vectorize new documents
        if new_documents:
            PrettyOutput.print(f"Start processing {len(new_documents)} new documents",
                               output_type=OutputType.INFO)

            # Use a thread pool for vectorization
            batch_size = 32
            new_vectors = []

            with tqdm(total=len(new_documents), desc="Generating vectors") as pbar:
                with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
                    for i in range(0, len(new_documents), batch_size):
                        batch = new_documents[i:i + batch_size]
                        future = executor.submit(self._process_document_batch, batch)
                        batch_vectors = future.result()

                        with self.vector_lock:
                            new_vectors.extend(batch_vectors)

                        pbar.update(len(batch))

            # Merge new and old vectors
            if self.flat_index is not None:
                # Get vectors for unchanged documents
                unchanged_vectors = []
                for doc in unchanged_documents:
                    # Get vectors from the existing index
                    doc_idx = next((i for i, d in enumerate(self.documents)
                                    if d.metadata['file_path'] == doc.metadata['file_path']), None)
                    if doc_idx is not None:
                        # Reconstruct vectors from the flat index
                        vector = np.zeros((1, self.vector_dim), dtype=np.float32)  # type: ignore
                        self.flat_index.reconstruct(doc_idx, vector.ravel())
                        unchanged_vectors.append(vector)

                if unchanged_vectors:
                    unchanged_vectors = np.vstack(unchanged_vectors)
                    vectors = np.vstack([unchanged_vectors, np.vstack(new_vectors)])
                else:
                    vectors = np.vstack(new_vectors)
            else:
                vectors = np.vstack(new_vectors)

            # Build index
            self._build_index(vectors)
            # Save cache
            self._save_cache(vectors)

        PrettyOutput.print(f"Successfully indexed {len(self.documents)} document fragments (Added/Modified: {len(new_documents)}, Unchanged: {len(unchanged_documents)})",
                           output_type=OutputType.SUCCESS)

    def search(self, query: str, top_k: int = 30) -> List[Tuple[Document, float]]:
        """Optimized search strategy"""
        if not self.index:
            PrettyOutput.print("Index not built, building...", output_type=OutputType.INFO)
            self.build_index(self.root_dir)

        # Implement MMR (Maximal Marginal Relevance) to increase result diversity
        def mmr(query_vec, doc_vecs, doc_ids, lambda_param=0.5, n_docs=top_k):
            selected = []
            selected_ids = []

            while len(selected) < n_docs and len(doc_ids) > 0:
                best_score = -1
                best_idx = -1

                for i, (doc_vec, doc_id) in enumerate(zip(doc_vecs, doc_ids)):
                    # Calculate similarity with the query
                    query_sim = float(np.dot(query_vec, doc_vec))

                    # Calculate maximum similarity with already selected documents
                    if selected:
                        doc_sims = [float(np.dot(doc_vec, selected_doc)) for selected_doc in selected]
                        max_doc_sim = max(doc_sims)
                    else:
                        max_doc_sim = 0

                    # MMR score
                    score = lambda_param * query_sim - (1 - lambda_param) * max_doc_sim

                    if score > best_score:
                        best_score = score
                        best_idx = i

                if best_idx == -1:
                    break

                selected.append(doc_vecs[best_idx])
                selected_ids.append(doc_ids[best_idx])
                doc_vecs = np.delete(doc_vecs, best_idx, axis=0)
                doc_ids = np.delete(doc_ids, best_idx)

            return selected_ids

        # Get the query vector
        query_vector = self._get_embedding(query)
        query_vector = query_vector.reshape(1, -1)

        # Initially retrieve more results, for MMR
        initial_k = min(top_k * 2, len(self.documents))
        distances, indices = self.index.search(query_vector, initial_k)  # type: ignore

        # Get valid results
        valid_indices = indices[0][indices[0] != -1]
        valid_vectors = np.vstack([self._get_embedding(self.documents[idx].content) for idx in valid_indices])

        # Apply MMR
        final_indices = mmr(query_vector[0], valid_vectors, valid_indices, n_docs=top_k)

        # Build results
        results = []
        for idx in final_indices:
            doc = self.documents[idx]
            similarity = 1.0 / (1.0 + float(distances[0][np.where(indices[0] == idx)[0][0]]))
            results.append((doc, similarity))

        return results

    def _rerank_results(self, query: str, initial_results: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        """Use a rerank model to rerank search results"""
        try:
            import torch
            model, tokenizer = load_rerank_model()

            # Prepare data
            pairs = []
            for doc, _ in initial_results:
                # Combine document information
                doc_content = f"""
File: {doc.metadata['file_path']}
Content: {doc.content}
"""
                pairs.append([query, doc_content])

            # Score each document pair
            scores = []
            batch_size = 8

            with torch.no_grad():
                for i in range(0, len(pairs), batch_size):
                    batch_pairs = pairs[i:i + batch_size]
                    encoded = tokenizer(
                        batch_pairs,
                        padding=True,
                        truncation=True,
                        max_length=512,
                        return_tensors='pt'
                    )

                    if torch.cuda.is_available():
                        encoded = {k: v.cuda() for k, v in encoded.items()}

                    outputs = model(**encoded)
                    batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
                    scores.extend(batch_scores.tolist())

            # Normalize scores to the 0-1 range
            if scores:
                min_score = min(scores)
                max_score = max(scores)
                if max_score > min_score:
                    scores = [(s - min_score) / (max_score - min_score) for s in scores]

            # Combine scores with documents and sort
            scored_results = []
            for (doc, _), score in zip(initial_results, scores):
                if score >= 0.5:  # Only keep results with a score of at least 0.5
                    scored_results.append((doc, float(score)))

            # Sort by score in descending order
            scored_results.sort(key=lambda x: x[1], reverse=True)

            return scored_results

        except Exception as e:
            PrettyOutput.print(f"Failed to rerank, using original sorting: {str(e)}", output_type=OutputType.WARNING)
            return initial_results

    def is_index_built(self):
        """Check whether the index is built"""
        return self.index is not None

    def query(self, query: str) -> List[Document]:
        """Query related documents

        Args:
            query: Query text

        Returns:
            List[Document]: Related documents, including context
        """
        results = self.search(query)
        return [doc for doc, _ in results]

    def ask(self, question: str) -> Optional[str]:
        """Ask about documents

        Args:
            question: User question

        Returns:
            Model answer, or None on failure
        """
        try:
            # Search related document fragments
            results = self.query(question)
            if not results:
                return None

            # Display found document fragments
            for doc in results:
                PrettyOutput.print(f"File: {doc.metadata['file_path']}", output_type=OutputType.INFO)
                PrettyOutput.print(f"Fragment {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']}",
                                   output_type=OutputType.INFO)
                PrettyOutput.print("\nContent:", output_type=OutputType.INFO)
                content = doc.content.encode('utf-8', errors='replace').decode('utf-8')
                PrettyOutput.print(content, output_type=OutputType.INFO)

            # Build the base prompt
            base_prompt = f"""Please answer the user's question based on the following document fragments. If the document content is not sufficient to answer the question completely, please clearly indicate this.

User question: {question}

Related document fragments:
"""
            end_prompt = "\nPlease provide an accurate and concise answer. If the document content is not sufficient to answer the question completely, please clearly indicate this."

            # Calculate the maximum length available for document content,
            # leaving some space for the model's answer
            available_length = self.max_context_length - len(base_prompt) - len(end_prompt) - 500

            # Build the context while controlling the total length
            context = []
            current_length = 0

            for doc in results:
                # Calculate the length of this document fragment's content
                doc_content = f"""
Source file: {doc.metadata['file_path']}
Content:
{doc.content}
---
"""
                content_length = len(doc_content)

                # If adding this fragment would exceed the limit, stop adding
                if current_length + content_length > available_length:
                    PrettyOutput.print("Due to the context length limit, some related document fragments were omitted",
                                       output_type=OutputType.WARNING)
                    break

                context.append(doc_content)
                current_length += content_length

            # Build the complete prompt
            prompt = base_prompt + ''.join(context) + end_prompt

            # Get a model instance and generate the answer
            model = PlatformRegistry.get_global_platform_registry().get_normal_platform()
            response = model.chat_until_success(prompt)

            return response

        except Exception as e:
            PrettyOutput.print(f"Failed to answer: {str(e)}", output_type=OutputType.ERROR)
            return None

def main():
    """Main function"""
    import argparse
    import sys

    # Set standard output encoding to UTF-8
    if sys.stdout.encoding != 'utf-8':
        import codecs
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
        sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')

    parser = argparse.ArgumentParser(description='Document retrieval and analysis tool')
    parser.add_argument('--dir', type=str, help='Directory to process')
    parser.add_argument('--build', action='store_true', help='Build document index')
    parser.add_argument('--search', type=str, help='Search document content')
    parser.add_argument('--ask', type=str, help='Ask about documents')
    args = parser.parse_args()

    try:
        current_dir = os.getcwd()
        rag = RAGTool(current_dir)

        if not args.dir:
            args.dir = current_dir

        if args.dir and args.build:
            PrettyOutput.print(f"Processing directory: {args.dir}", output_type=OutputType.INFO)
            rag.build_index(args.dir)
            return 0

        if args.search or args.ask:

            if args.search:
                results = rag.query(args.search)
                if not results:
                    PrettyOutput.print("No related content found", output_type=OutputType.WARNING)
                    return 1

                for doc in results:
                    PrettyOutput.print(f"\nFile: {doc.metadata['file_path']}", output_type=OutputType.INFO)
                    PrettyOutput.print(f"Fragment {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']}",
                                       output_type=OutputType.INFO)
                    PrettyOutput.print("\nContent:", output_type=OutputType.INFO)
                    content = doc.content.encode('utf-8', errors='replace').decode('utf-8')
                    PrettyOutput.print(content, output_type=OutputType.INFO)
                return 0

            if args.ask:
                # Call the ask method
                response = rag.ask(args.ask)
                if not response:
                    PrettyOutput.print("Failed to get answer", output_type=OutputType.WARNING)
                    return 1

                # Display the answer
                PrettyOutput.print("\nAnswer:", output_type=OutputType.INFO)
                PrettyOutput.print(response, output_type=OutputType.INFO)
                return 0

        PrettyOutput.print("Please specify operation parameters. Use -h to view help.", output_type=OutputType.WARNING)
        return 1

    except Exception as e:
        PrettyOutput.print(f"Failed to execute: {str(e)}", output_type=OutputType.ERROR)
        return 1

if __name__ == "__main__":
    main()
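For reference, a minimal sketch of how the removed module was driven programmatically, mirroring the main() entry point above; this assumes jarvis-ai-assistant 0.1.101, where jarvis.jarvis_rag still existed, and the example query strings are placeholders:

    import os
    from jarvis.jarvis_rag.main import RAGTool

    # RAGTool loads the LZMA-compressed cache from .jarvis-rag/cache.pkl if present
    rag = RAGTool(os.getcwd())
    # Incremental: files whose MD5 matches the cache are not re-embedded
    rag.build_index(os.getcwd())
    # query() returns matching Document fragments (IVF search + MMR diversification)
    for doc in rag.query("how is the cache compressed?"):
        print(doc.metadata["file_path"], doc.metadata["chunk_index"])
    # ask() builds a prompt from the fragments and answers via PlatformRegistry
    print(rag.ask("How is the cache compressed?"))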