jarvis-ai-assistant 0.1.107__py3-none-any.whl → 0.1.109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +3 -3
- jarvis/jarvis_code_agent/code_agent.py +69 -217
- jarvis/jarvis_code_agent/file_select.py +11 -10
- jarvis/jarvis_code_agent/patch.py +19 -9
- jarvis/jarvis_code_agent/relevant_files.py +1 -162
- jarvis/jarvis_codebase/main.py +52 -57
- jarvis/jarvis_rag/main.py +193 -267
- jarvis/jarvis_tools/registry.py +8 -7
- jarvis/utils.py +151 -12
- {jarvis_ai_assistant-0.1.107.dist-info → jarvis_ai_assistant-0.1.109.dist-info}/METADATA +12 -3
- {jarvis_ai_assistant-0.1.107.dist-info → jarvis_ai_assistant-0.1.109.dist-info}/RECORD +16 -16
- {jarvis_ai_assistant-0.1.107.dist-info → jarvis_ai_assistant-0.1.109.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.107.dist-info → jarvis_ai_assistant-0.1.109.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.107.dist-info → jarvis_ai_assistant-0.1.109.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.107.dist-info → jarvis_ai_assistant-0.1.109.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py
CHANGED
@@ -3,7 +3,7 @@ import numpy as np
 import faiss
 from typing import List, Tuple, Optional, Dict
 import pickle
-from jarvis.utils import OutputType, PrettyOutput,
+from jarvis.utils import OutputType, PrettyOutput, get_context_token_count, get_embedding, get_embedding_batch, get_file_md5, get_max_context_length, get_max_paragraph_length, get_min_paragraph_length, get_thread_count, init_gpu_config, load_embedding_model
 from jarvis.utils import init_env
 from dataclasses import dataclass
 from tqdm import tqdm
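
The removed in-file helpers (see the `_get_embedding`/`_get_embedding_batch` deletions further down) are replaced by shared utilities imported from `jarvis.utils`. The diff only shows the call sites, e.g. `get_embedding(self.embedding_model, query)` and `get_embedding_batch(self.embedding_model, texts)`, not the helper bodies. Below is a minimal sketch of what helpers with those signatures could look like, assuming a sentence-transformers style model object; it is illustrative, not the package's actual implementation.

```python
from typing import List

import numpy as np

def get_embedding(model, text: str) -> np.ndarray:
    # Encode a single text into a normalized float32 vector.
    vec = model.encode(text, normalize_embeddings=True, show_progress_bar=False)
    return np.asarray(vec, dtype=np.float32)

def get_embedding_batch(model, texts: List[str], batch_size: int = 32) -> np.ndarray:
    # Encode a list of texts in batches and return a 2-D float32 array.
    if not texts:
        return np.zeros((0, model.get_sentence_embedding_dimension()), dtype=np.float32)
    vecs = model.encode(texts, normalize_embeddings=True,
                        show_progress_bar=False, batch_size=batch_size)
    return np.asarray(vecs, dtype=np.float32)
```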
@@ -11,13 +11,9 @@ import fitz  # PyMuPDF for PDF files
 from docx import Document as DocxDocument  # python-docx for DOCX files
 from pathlib import Path
 from jarvis.jarvis_platform.registry import PlatformRegistry
-import shutil
-from datetime import datetime
 import lzma  # add lzma import
-from concurrent.futures import ThreadPoolExecutor
 from threading import Lock
-import
-import re
+import hashlib

 @dataclass
 class Document:
@@ -146,7 +142,7 @@ class RAGTool:
         # Initialize configuration
         self.min_paragraph_length = get_min_paragraph_length()  # Minimum paragraph length
         self.max_paragraph_length = get_max_paragraph_length()  # Maximum paragraph length
-        self.context_window =
+        self.context_window = 5  # Fixed context window size
         self.max_context_length = int(get_max_context_length() * 0.8)

         # Initialize data directory
@@ -163,15 +159,18 @@ class RAGTool:
             PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
             raise

-        #
-        self.
+        # Reworked cache-related initialization
+        self.cache_dir = os.path.join(self.data_dir, "cache")
+        if not os.path.exists(self.cache_dir):
+            os.makedirs(self.cache_dir)
+
         self.documents: List[Document] = []
-        self.index = None
-        self.flat_index = None
-        self.file_md5_cache = {}
+        self.index = None
+        self.flat_index = None
+        self.file_md5_cache = {}

-        #
-        self.
+        # Load the cache index
+        self._load_cache_index()

         # Register file processors
         self.file_processors = [
@@ -185,107 +184,99 @@ class RAGTool:
         self.vector_lock = Lock()  # Protect vector list concurrency

         # Initialize GPU memory configuration
-        self.gpu_config =
+        self.gpu_config = init_gpu_config()
+

-    def
-        """
+    def _get_cache_path(self, file_path: str) -> str:
+        """Get cache file path for a document

+        Args:
+            file_path: Original file path
+
         Returns:
-
+            str: Cache file path
         """
-
-
-
-
-
-
-
-
-            import torch
-            if torch.cuda.is_available():
-                # Get GPU information
-                gpu_mem = torch.cuda.get_device_properties(0).total_memory
-                config["has_gpu"] = True
-                config["device_memory"] = gpu_mem
-
-                # Estimate shared memory (usually a portion of system memory)
-                import psutil
-                system_memory = psutil.virtual_memory().total
-                config["shared_memory"] = min(system_memory * 0.5, gpu_mem * 2)  # the smaller of 50% of system memory or 2x the GPU memory
-
-                # Configure CUDA memory allocation
-                torch.cuda.set_per_process_memory_fraction(config["memory_fraction"])
-                torch.cuda.empty_cache()
-
-                PrettyOutput.print(
-                    f"GPU initialized: {torch.cuda.get_device_name(0)}\n"
-                    f"Device Memory: {gpu_mem / 1024**3:.1f}GB\n"
-                    f"Shared Memory: {config['shared_memory'] / 1024**3:.1f}GB",
-                    output_type=OutputType.SUCCESS
-                )
-            else:
-                PrettyOutput.print("No GPU available, using CPU mode", output_type=OutputType.WARNING)
-        except Exception as e:
-            PrettyOutput.print(f"GPU initialization failed: {str(e)}", output_type=OutputType.WARNING)
-
-        return config
-
-    def _load_cache(self):
-        """Load cache data"""
-        if os.path.exists(self.cache_path):
+        # Use the hash of the file path as the cache file name
+        file_hash = hashlib.md5(file_path.encode()).hexdigest()
+        return os.path.join(self.cache_dir, f"{file_hash}.cache")
+
+    def _load_cache_index(self):
+        """Load cache index"""
+        index_path = os.path.join(self.data_dir, "index.pkl")
+        if os.path.exists(index_path):
             try:
-                with lzma.open(
+                with lzma.open(index_path, 'rb') as f:
                     cache_data = pickle.load(f)
-                    self.
-                    vectors = cache_data["vectors"]
-                    self.file_md5_cache = cache_data.get("file_md5_cache", {})  # load the MD5 cache
+                    self.file_md5_cache = cache_data.get("file_md5_cache", {})

-                #
-
-                self.
+                # Load documents from the per-file cache files
+                for file_path in self.file_md5_cache:
+                    cache_path = self._get_cache_path(file_path)
+                    if os.path.exists(cache_path):
+                        try:
+                            with lzma.open(cache_path, 'rb') as f:
+                                file_cache = pickle.load(f)
+                                self.documents.extend(file_cache["documents"])
+                        except Exception as e:
+                            PrettyOutput.print(f"Failed to load cache for {file_path}: {str(e)}",
+                                            output_type=OutputType.WARNING)
+
+                # Rebuild the vector index
+                if self.documents:
+                    vectors = []
+                    for doc in self.documents:
+                        cache_path = self._get_cache_path(doc.metadata['file_path'])
+                        if os.path.exists(cache_path):
+                            with lzma.open(cache_path, 'rb') as f:
+                                file_cache = pickle.load(f)
+                                doc_idx = next((i for i, d in enumerate(file_cache["documents"])
+                                            if d.metadata['chunk_index'] == doc.metadata['chunk_index']), None)
+                                if doc_idx is not None:
+                                    vectors.append(file_cache["vectors"][doc_idx])
+
+                    if vectors:
+                        vectors = np.vstack(vectors)
+                        self._build_index(vectors)
+
                 PrettyOutput.print(f"Loaded {len(self.documents)} document fragments",
                                 output_type=OutputType.INFO)
+
             except Exception as e:
-                PrettyOutput.print(f"Failed to load cache: {str(e)}",
+                PrettyOutput.print(f"Failed to load cache index: {str(e)}",
                                 output_type=OutputType.WARNING)
                 self.documents = []
                 self.index = None
                 self.flat_index = None
                 self.file_md5_cache = {}

-    def _save_cache(self, vectors: np.ndarray):
-        """
+    def _save_cache(self, file_path: str, documents: List[Document], vectors: np.ndarray):
+        """Save cache for a single file
+
+        Args:
+            file_path: File path
+            documents: List of documents
+            vectors: Document vectors
+        """
         try:
+            # Save the per-file cache
+            cache_path = self._get_cache_path(file_path)
             cache_data = {
-                "
-                "
-                "documents": self.documents,
-                "vectors": vectors.copy() if vectors is not None else None,  # Create a copy of the array
-                "file_md5_cache": dict(self.file_md5_cache),  # Create a copy of the dictionary
-                "metadata": {
-                    "vector_dim": self.vector_dim,
-                    "total_docs": len(self.documents),
-                    "model_name": self.embedding_model.__class__.__name__
-                }
+                "documents": documents,
+                "vectors": vectors
             }
-
-
-
-
-
-
-
-
-
-
-
-
-            PrettyOutput.print(f"Cache saved: {len(self.documents)} document fragments",
-                            output_type=OutputType.INFO)
+            with lzma.open(cache_path, 'wb') as f:
+                pickle.dump(cache_data, f)
+
+            # Update and save the index
+            index_path = os.path.join(self.data_dir, "index.pkl")
+            index_data = {
+                "file_md5_cache": self.file_md5_cache
+            }
+            with lzma.open(index_path, 'wb') as f:
+                pickle.dump(index_data, f)
+
         except Exception as e:
-            PrettyOutput.print(f"Failed to save cache: {str(e)}",
-                            output_type=OutputType.ERROR)
-            raise
+            PrettyOutput.print(f"Failed to save cache: {str(e)}", output_type=OutputType.ERROR)

     def _build_index(self, vectors: np.ndarray):
         """Build FAISS index"""
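
The hunk above replaces the single monolithic cache with one compressed pickle per source file plus a small `index.pkl` that only stores the path-to-MD5 map. A standalone sketch of that layout is shown below; the helper names (`cache_path_for`, `write_file_cache`, `read_file_cache`) are illustrative and not part of the package.

```python
import hashlib
import lzma
import os
import pickle

import numpy as np

def cache_path_for(cache_dir: str, file_path: str) -> str:
    # Same naming scheme as _get_cache_path: MD5 of the source path -> "<hash>.cache".
    file_hash = hashlib.md5(file_path.encode()).hexdigest()
    return os.path.join(cache_dir, f"{file_hash}.cache")

def write_file_cache(cache_dir: str, file_path: str, documents: list, vectors: np.ndarray) -> None:
    # One lzma-compressed pickle per source file, mirroring _save_cache.
    os.makedirs(cache_dir, exist_ok=True)
    with lzma.open(cache_path_for(cache_dir, file_path), "wb") as f:
        pickle.dump({"documents": documents, "vectors": vectors}, f)

def read_file_cache(cache_dir: str, file_path: str):
    # Read the per-file cache back, returning (documents, vectors) or None if missing.
    path = cache_path_for(cache_dir, file_path)
    if not os.path.exists(path):
        return None
    with lzma.open(path, "rb") as f:
        data = pickle.load(f)
    return data["documents"], data["vectors"]
```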
@@ -364,106 +355,32 @@ class RAGTool:

         return paragraphs

-    def _get_embedding(self, text: str) -> np.ndarray:
-        """Get the vector representation of the text"""
-        embedding = self.embedding_model.encode(text,
-                                            normalize_embeddings=True,
-                                            show_progress_bar=False)
-        return np.array(embedding, dtype=np.float32)
-
-    def _get_embedding_batch(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
-        """Get embeddings for a batch of texts efficiently"""
-        try:
-            if self.gpu_config["has_gpu"]:
-                import torch
-                torch.cuda.empty_cache()
-
-                # Use a smaller batch size
-                optimal_batch_size = min(16, len(texts))
-                all_embeddings = []
-
-                with tqdm(total=len(texts), desc="Vectorizing") as pbar:
-                    for i in range(0, len(texts), optimal_batch_size):
-                        try:
-                            batch = texts[i:i + optimal_batch_size]
-                            embeddings = self.embedding_model.encode(
-                                batch,
-                                normalize_embeddings=True,
-                                show_progress_bar=False,
-                                batch_size=4,  # reduce the internal batch size
-                                convert_to_tensor=True
-                            )
-                            # Move to the CPU immediately
-                            embeddings = embeddings.cpu().numpy()
-                            all_embeddings.append(embeddings)
-                            pbar.update(len(batch))
-
-                            # Clear the GPU cache
-                            torch.cuda.empty_cache()
-
-                        except RuntimeError as e:
-                            if "out of memory" in str(e):
-                                # If out of memory, retry with a smaller batch size
-                                if optimal_batch_size > 4:
-                                    optimal_batch_size //= 2
-                                    PrettyOutput.print(
-                                        f"CUDA out of memory, reducing batch size to {optimal_batch_size}",
-                                        OutputType.WARNING
-                                    )
-                                    i -= optimal_batch_size  # retry the current batch
-                                    continue
-                            raise
-
-                return np.vstack(all_embeddings)
-            else:
-                # CPU mode
-                return self.embedding_model.encode(
-                    texts,
-                    normalize_embeddings=True,
-                    show_progress_bar=True,
-                    batch_size=8,
-                    convert_to_tensor=False
-                )
-
-        except Exception as e:
-            PrettyOutput.print(f"Batch embedding failed: {str(e)}", OutputType.ERROR)
-            return np.zeros((len(texts), self.vector_dim), dtype=np.float32)  # type: ignore

     def _process_document_batch(self, documents: List[Document]) -> np.ndarray:
-        """Process a batch of documents using shared memory
-
-        Args:
-            documents: List of documents to process
-
-        Returns:
-            np.ndarray: Document vectors
-        """
+        """Process a batch of documents using shared memory"""
         try:
-            import torch
-
-            # Estimate memory requirements
-            total_content_size = sum(len(doc.content) for doc in documents)
-            est_memory_needed = total_content_size * 4  # rough estimate
-
-            # If the estimated memory exceeds the shared memory limit, process in batches
-            if est_memory_needed > self.gpu_config["shared_memory"] * 0.7:
-                batch_size = max(1, int(len(documents) * (self.gpu_config["shared_memory"] * 0.7 / est_memory_needed)))
-
-                all_vectors = []
-                for i in range(0, len(documents), batch_size):
-                    batch = documents[i:i + batch_size]
-                    vectors = self._process_document_batch(batch)
-                    all_vectors.append(vectors)
-                return np.vstack(all_vectors)
-
-            # Process a single batch normally
             texts = []
+            self.documents = []  # Reset documents to store chunks
+
             for doc in documents:
-
-
-
-
+                # Split original document into chunks
+                chunks = self._split_text(doc.content)
+                for chunk_idx, chunk in enumerate(chunks):
+                    # Create new Document for each chunk
+                    new_metadata = doc.metadata.copy()
+                    new_metadata.update({
+                        'chunk_index': chunk_idx,
+                        'total_chunks': len(chunks),
+                        'original_length': len(doc.content)
+                    })
+                    self.documents.append(Document(
+                        content=chunk,
+                        metadata=new_metadata,
+                        md5=doc.md5
+                    ))
+                    texts.append(f"File:{doc.metadata['file_path']} Chunk:{chunk_idx} Content:{chunk}")

+            return get_embedding_batch(self.embedding_model, texts)
         except Exception as e:
             PrettyOutput.print(f"Batch processing failed: {str(e)}", OutputType.ERROR)
             return np.zeros((0, self.vector_dim), dtype=np.float32)  # type: ignore
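
`_process_document_batch` now splits each source document into chunks and records `chunk_index`, `total_chunks`, and `original_length` on every chunk, which is what later lets `search()` pull sibling chunks from the same file and deduplicate by `(file_path, chunk_index)`. A self-contained sketch of that bookkeeping follows; `Chunk`, `split_text`, and `chunk_document` are illustrative stand-ins (the real `_split_text` is not shown in this diff).

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Chunk:
    content: str
    metadata: Dict

def split_text(text: str, max_len: int = 512) -> List[str]:
    # Stand-in for RAGTool._split_text: naive fixed-size splitting.
    return [text[i:i + max_len] for i in range(0, len(text), max_len)] or [""]

def chunk_document(file_path: str, text: str) -> List[Chunk]:
    # Attach the same per-chunk metadata the new code records on each Document.
    chunks = split_text(text)
    return [
        Chunk(content=chunk, metadata={
            "file_path": file_path,
            "chunk_index": chunk_idx,
            "total_chunks": len(chunks),
            "original_length": len(text),
        })
        for chunk_idx, chunk in enumerate(chunks)
    ]
```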
@@ -572,74 +489,64 @@ class RAGTool:
         unchanged_documents = [doc for doc in self.documents
                              if doc.metadata['file_path'] in unchanged_files]

-        # Process files
+        # Process files one by one with optimized vectorization
         if files_to_process:
             PrettyOutput.print(f"Processing {len(files_to_process)} files...", OutputType.INFO)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                combined_text = f"File:{doc.metadata['file_path']} Content:{doc.content}"
-                texts_to_vectorize.append(combined_text)
-
-            # Use a smaller initial batch size
-            initial_batch_size = min(
-                32,  # maximum batch size
-                max(4, len(texts_to_vectorize) // 8),  # batch size based on the number of documents
-                len(texts_to_vectorize)  # no more than the total number of documents
-            )
-
-            # Vectorize in batches
-            vectors = self._get_embedding_batch(texts_to_vectorize, initial_batch_size)
+            new_documents = []
+            new_vectors = []
+
+            with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
+                for file_path in files_to_process:
+                    try:
+                        # Process single file
+                        file_docs = self._process_file(file_path)
+                        if file_docs:
+                            # Vectorize documents from this file
+                            texts_to_vectorize = [
+                                f"File:{doc.metadata['file_path']} Content:{doc.content}"
+                                for doc in file_docs
+                            ]
+                            file_vectors = get_embedding_batch(self.embedding_model, texts_to_vectorize)
+
+                            # Save cache for this file
+                            self._save_cache(file_path, file_docs, file_vectors)
+
+                            # Accumulate documents and vectors
+                            new_documents.extend(file_docs)
+                            new_vectors.append(file_vectors)
+
+                    except Exception as e:
+                        PrettyOutput.print(f"Failed to process {file_path}: {str(e)}", OutputType.ERROR)
+
+                    pbar.update(1)

-
-
+            # Update documents list
+            self.documents.extend(new_documents)

-
+            # Build final index
+            if new_vectors:
+                all_new_vectors = np.vstack(new_vectors)
+
                 if self.flat_index is not None:
-                    #
+                    # Get vectors for unchanged documents
                     unchanged_vectors = self._get_unchanged_vectors(unchanged_documents)
                     if unchanged_vectors is not None:
-                        final_vectors = np.vstack([unchanged_vectors,
+                        final_vectors = np.vstack([unchanged_vectors, all_new_vectors])
                     else:
-                        final_vectors =
+                        final_vectors = all_new_vectors
                 else:
-                    final_vectors =
+                    final_vectors = all_new_vectors

-                #
+                # Build index
                 self._build_index(final_vectors)
-                self._save_cache(final_vectors)

-
-
-
-
-
-
+                PrettyOutput.print(
+                    f"Indexed {len(self.documents)} documents "
+                    f"(New/Modified: {len(new_documents)}, "
+                    f"Unchanged: {len(unchanged_documents)})",
+                    OutputType.SUCCESS
+                )

     def _get_unchanged_vectors(self, unchanged_documents: List[Document]) -> Optional[np.ndarray]:
         """Get vectors for unchanged documents from existing index"""
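
The new `build_index` path processes changed files one at a time, caches each file's chunks and vectors via `_save_cache`, and stacks the per-file vectors with the vectors of unchanged documents. The hunks shown here do not include how `files_to_process` and `unchanged_files` are computed, but given the `get_file_md5` import and the `file_md5_cache` map, a plausible sketch of that split is shown below; `file_md5` and `partition_files` are hypothetical helper names.

```python
import hashlib
from typing import Dict, List, Tuple

def file_md5(path: str) -> str:
    # Hash the file contents in blocks (stand-in for jarvis.utils.get_file_md5).
    h = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(1 << 20), b""):
            h.update(block)
    return h.hexdigest()

def partition_files(paths: List[str], md5_cache: Dict[str, str]) -> Tuple[List[str], List[str]]:
    # Compare each file's current MD5 with the cached one to decide what to reprocess.
    unchanged, to_process = [], []
    for path in paths:
        if md5_cache.get(path) == file_md5(path):
            unchanged.append(path)
        else:
            to_process.append(path)
    return unchanged, to_process
```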
@@ -663,40 +570,62 @@ class RAGTool:
         return None

     def search(self, query: str, top_k: int = 30) -> List[Tuple[Document, float]]:
-        """Search documents
-
-        Args:
-            query: Search query
-            top_k: Number of results to return
-        """
+        """Search documents with context window"""
         if not self.index:
             PrettyOutput.print("Index not built, building...", output_type=OutputType.INFO)
             self.build_index(self.root_dir)

         # Get query vector
-        query_vector = self.
+        query_vector = get_embedding(self.embedding_model, query)
         query_vector = query_vector.reshape(1, -1)

         # Search with more candidates
         initial_k = min(top_k * 4, len(self.documents))
         distances, indices = self.index.search(query_vector, initial_k)  # type: ignore

-        # Process results
+        # Process results with context window
         results = []
         seen_files = set()
+
         for idx, dist in zip(indices[0], distances[0]):
             if idx != -1:
                 doc = self.documents[idx]
                 similarity = 1.0 / (1.0 + float(dist))
-                if similarity > 0.3:
+                if similarity > 0.3:
                     file_path = doc.metadata['file_path']
                     if file_path not in seen_files:
                         seen_files.add(file_path)
-
-
+
+                        # Get full context from original document
+                        original_doc = next((d for d in self.documents
+                                        if d.metadata['file_path'] == file_path), None)
+                        if original_doc:
+                            window_docs = []  # Add this line to initialize the list
+                            full_content = original_doc.content
+                            # Find all chunks from this file
+                            file_chunks = [d for d in self.documents
+                                        if d.metadata['file_path'] == file_path]
+                            # Add all related chunks
+                            for chunk_doc in file_chunks:
+                                window_docs.append((chunk_doc, similarity * 0.9))
+
+                            results.extend(window_docs)
+                        if len(results) >= top_k * (2 * self.context_window + 1):
                             break

-
+        # Sort by similarity and deduplicate
+        results.sort(key=lambda x: x[1], reverse=True)
+        seen = set()
+        final_results = []
+        for doc, score in results:
+            key = (doc.metadata['file_path'], doc.metadata['chunk_index'])
+            if key not in seen:
+                seen.add(key)
+                final_results.append((doc, score))
+                if len(final_results) >= top_k:
+                    break
+
+        return final_results

     def query(self, query: str) -> List[Document]:
         """Query related documents
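
The rewritten `search()` first over-collects candidates (every chunk of any file that produced a strong hit, slightly down-weighted by `similarity * 0.9`), then sorts by score and deduplicates on `(file_path, chunk_index)` until `top_k` results remain. The final loop is easy to isolate; here is a small sketch using plain metadata dicts instead of `Document` objects:

```python
from typing import Dict, List, Tuple

def dedup_results(results: List[Tuple[Dict, float]], top_k: int) -> List[Tuple[Dict, float]]:
    # Sort candidate (metadata, score) pairs by score, keep the best hit per chunk.
    ordered = sorted(results, key=lambda x: x[1], reverse=True)
    seen = set()
    final = []
    for meta, score in ordered:
        key = (meta["file_path"], meta["chunk_index"])
        if key in seen:
            continue
        seen.add(key)
        final.append((meta, score))
        if len(final) >= top_k:
            break
    return final
```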
@@ -718,13 +647,6 @@ class RAGTool:
         if not results:
             return None

-        # Display the found documents
-        for doc, score in results:
-            output = f"""File: {doc.metadata['file_path']} (Score: {score:.3f})\n"""
-            output += f"""Fragment {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']}\n"""
-            output += f"""Content:\n{doc.content}\n"""
-            PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")
-
         # Build the prompt
         prompt = f"""Based on the following document fragments, please answer the user's question accurately and comprehensively.

@@ -733,8 +655,8 @@ Question: {question}
 Relevant documents (ordered by relevance):
 """
         # Add context while controlling the total length
-
-
+        available_count = self.max_context_length - get_context_token_count(prompt) - 1000
+        current_count = 0

         for doc, score in results:
             doc_content = f"""
@@ -742,7 +664,11 @@ Relevant documents (ordered by relevance):
 {doc.content}
 ---
 """
-
+            prompt += "Answer Format:\n"
+            prompt += "1. Answer the question accurately and comprehensively.\n"
+            prompt += "2. If the documents don't fully answer the question, please indicate what information is missing.\n"
+            prompt += "3. Reference the documents in the answer.\n"
+            if current_count + get_context_token_count(doc_content) > available_count:
                 PrettyOutput.print(
                     "Due to context length limit, some fragments were omitted",
                     output_type=OutputType.WARNING
@@ -750,7 +676,7 @@ Relevant documents (ordered by relevance):
                 break

             prompt += doc_content
-
+            current_count += get_context_token_count(doc_content)

         prompt += "\nIf the documents don't fully answer the question, please indicate what information is missing."

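
The last three hunks replace character-based length checks with `get_context_token_count` and budget the prompt explicitly: the available budget is the maximum context length minus the header's token count minus a 1000-token reserve, and fragments are appended until that budget is spent. Below is a minimal sketch of the loop, with `count_tokens` standing in for `get_context_token_count` (whose implementation is not shown in this diff):

```python
from typing import Callable, List

def assemble_prompt(header: str, fragments: List[str], max_tokens: int,
                    count_tokens: Callable[[str], int], reserve: int = 1000) -> str:
    # Append fragments until the token budget (minus a reserve for the answer) is spent.
    prompt = header
    budget = max_tokens - count_tokens(header) - reserve
    used = 0
    for fragment in fragments:
        cost = count_tokens(fragment)
        if used + cost > budget:
            break  # remaining fragments are omitted, as in the RAGTool loop
        prompt += fragment
        used += cost
    return prompt
```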
jarvis/jarvis_tools/registry.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional

 from jarvis.jarvis_platform.registry import PlatformRegistry
 from jarvis.jarvis_tools.base import Tool
-from jarvis.utils import OutputType, PrettyOutput, get_max_context_length
+from jarvis.utils import OutputType, PrettyOutput, get_context_token_count, get_max_context_length


 tool_call_help = """## Tool Usage Format
@@ -137,7 +137,8 @@ class ToolRegistry:
                     hasattr(item, 'name') and
                     hasattr(item, 'description') and
                     hasattr(item, 'parameters') and
-                    hasattr(item, 'execute')
+                    hasattr(item, 'execute') and
+                    item.name == module_name):

                     if hasattr(item, "check"):
                         if not item.check():
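
This hunk tightens tool discovery: a class found in a module is only registered when, in addition to exposing `name`, `description`, `parameters`, and `execute`, its declared `name` matches the module it was loaded from. A small illustration of the check in isolation; `ReadFileTool` is a made-up example class, not one shipped by the package:

```python
def is_registrable_tool(item, module_name: str) -> bool:
    # The tightened guard from the diff: the class must also be named after its module.
    return (hasattr(item, "name") and hasattr(item, "description")
            and hasattr(item, "parameters") and hasattr(item, "execute")
            and item.name == module_name)

class ReadFileTool:
    name = "read_file"
    description = "Read a file from disk"
    parameters = {}

    def execute(self, args):
        return None

assert is_registrable_tool(ReadFileTool, "read_file")
assert not is_registrable_tool(ReadFileTool, "write_file")
```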
@@ -247,16 +248,16 @@ arguments:
             PrettyOutput.section("Execution successful", OutputType.SUCCESS)

             # If the output exceeds 4k characters, use a large model to summarize
-            if
+            if get_context_token_count(output) > self.max_context_length:
                 try:
                     PrettyOutput.print("Output is too long, summarizing...", OutputType.PROGRESS)
                     model = PlatformRegistry.get_global_platform_registry().get_normal_platform()

                     # If the output exceeds the maximum context length, only take the last part
-
-                    if
-                        output_to_summarize = output[-
-                        truncation_notice = f"\n(Note: Due to the length of the output, only the last {
+                    max_count = self.max_context_length
+                    if get_context_token_count(output) > max_count:
+                        output_to_summarize = output[-max_count:]
+                        truncation_notice = f"\n(Note: Due to the length of the output, only the last {max_count} characters are summarized)"
                     else:
                         output_to_summarize = output
                         truncation_notice = ""