jarvis-ai-assistant 0.1.108__py3-none-any.whl → 0.1.109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of jarvis-ai-assistant might be problematic.

jarvis/jarvis_rag/main.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
  import faiss
  from typing import List, Tuple, Optional, Dict
  import pickle
- from jarvis.utils import OutputType, PrettyOutput, get_context_window, get_file_md5, get_max_context_length, get_max_paragraph_length, get_min_paragraph_length, get_thread_count, load_embedding_model, load_rerank_model
+ from jarvis.utils import OutputType, PrettyOutput, get_context_token_count, get_embedding, get_embedding_batch, get_file_md5, get_max_context_length, get_max_paragraph_length, get_min_paragraph_length, get_thread_count, init_gpu_config, load_embedding_model
  from jarvis.utils import init_env
  from dataclasses import dataclass
  from tqdm import tqdm
@@ -11,13 +11,9 @@ import fitz # PyMuPDF for PDF files
  from docx import Document as DocxDocument # python-docx for DOCX files
  from pathlib import Path
  from jarvis.jarvis_platform.registry import PlatformRegistry
- import shutil
- from datetime import datetime
  import lzma # add lzma import
- from concurrent.futures import ThreadPoolExecutor
  from threading import Lock
- import concurrent.futures
- import re
+ import hashlib

  @dataclass
  class Document:
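
For reference, the Document dataclass that the rest of this diff manipulates carries the chunk text, a metadata dict, and the source file's MD5. A minimal sketch consistent with the fields used below (the field types are assumptions, not copied from the release):

    from dataclasses import dataclass
    from typing import Dict

    @dataclass
    class Document:
        content: str    # chunk text
        metadata: Dict  # e.g. file_path, chunk_index, total_chunks, original_length
        md5: str        # MD5 of the source file, used for change detection
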
@@ -146,7 +142,7 @@ class RAGTool:
  # Initialize configuration
  self.min_paragraph_length = get_min_paragraph_length() # Minimum paragraph length
  self.max_paragraph_length = get_max_paragraph_length() # Maximum paragraph length
- self.context_window = get_context_window() # Context window size, defaults to 5 fragments before and after
+ self.context_window = 5 # Fixed context window size
  self.max_context_length = int(get_max_context_length() * 0.8)

  # Initialize data directory
@@ -163,15 +159,18 @@ class RAGTool:
  PrettyOutput.print(f"Failed to load model: {str(e)}", output_type=OutputType.ERROR)
  raise

- # Initialize cache and index
- self.cache_path = os.path.join(self.data_dir, "cache.pkl")
+ # Reworked cache-related initialization
+ self.cache_dir = os.path.join(self.data_dir, "cache")
+ if not os.path.exists(self.cache_dir):
+ os.makedirs(self.cache_dir)
+
  self.documents: List[Document] = []
- self.index = None # IVF index for search
- self.flat_index = None # Store original vectors
- self.file_md5_cache = {} # Store file MD5 values
+ self.index = None
+ self.flat_index = None
+ self.file_md5_cache = {}

- # Load cache
- self._load_cache()
+ # Load the cache index
+ self._load_cache_index()

  # Register file processors
  self.file_processors = [
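
The rework replaces the single cache.pkl with a cache/ directory holding one compressed pickle per source file, plus a small index.pkl that only stores the file-to-MD5 map; the _get_cache_path and _save_cache hunks below spell out the naming scheme. A standalone sketch of reading that layout back, assuming the same lzma/pickle/hashlib combination (the data directory path here is illustrative, not the package's actual default):

    import hashlib, lzma, os, pickle

    data_dir = os.path.expanduser("~/.jarvis/rag")  # illustrative location
    cache_dir = os.path.join(data_dir, "cache")

    def cache_path_for(file_path: str) -> str:
        # same scheme as _get_cache_path below: md5 of the path, one ".cache" file per source file
        return os.path.join(cache_dir, hashlib.md5(file_path.encode()).hexdigest() + ".cache")

    with lzma.open(os.path.join(data_dir, "index.pkl"), "rb") as f:
        index = pickle.load(f)               # {"file_md5_cache": {file_path: md5}}

    for file_path in index["file_md5_cache"]:
        with lzma.open(cache_path_for(file_path), "rb") as f:
            file_cache = pickle.load(f)      # {"documents": [...], "vectors": np.ndarray}
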
@@ -185,107 +184,99 @@ class RAGTool:
  self.vector_lock = Lock() # Protect vector list concurrency

  # Initialize GPU memory configuration
- self.gpu_config = self._init_gpu_config()
+ self.gpu_config = init_gpu_config()
+

- def _init_gpu_config(self) -> Dict:
- """Initialize GPU configuration based on available hardware
+ def _get_cache_path(self, file_path: str) -> str:
+ """Get cache file path for a document

+ Args:
+ file_path: Original file path
+
  Returns:
- Dict: GPU configuration including memory sizes and availability
+ str: Cache file path
  """
- config = {
- "has_gpu": False,
- "shared_memory": 0,
- "device_memory": 0,
- "memory_fraction": 0.8 # use 80% of available memory by default
- }
-
- try:
- import torch
- if torch.cuda.is_available():
- # Get GPU information
- gpu_mem = torch.cuda.get_device_properties(0).total_memory
- config["has_gpu"] = True
- config["device_memory"] = gpu_mem
-
- # Estimate shared memory (usually a portion of system memory)
- import psutil
- system_memory = psutil.virtual_memory().total
- config["shared_memory"] = min(system_memory * 0.5, gpu_mem * 2) # the smaller of 50% of system memory or 2x GPU memory
-
- # Configure CUDA memory allocation
- torch.cuda.set_per_process_memory_fraction(config["memory_fraction"])
- torch.cuda.empty_cache()
-
- PrettyOutput.print(
- f"GPU initialized: {torch.cuda.get_device_name(0)}\n"
- f"Device Memory: {gpu_mem / 1024**3:.1f}GB\n"
- f"Shared Memory: {config['shared_memory'] / 1024**3:.1f}GB",
- output_type=OutputType.SUCCESS
- )
- else:
- PrettyOutput.print("No GPU available, using CPU mode", output_type=OutputType.WARNING)
- except Exception as e:
- PrettyOutput.print(f"GPU initialization failed: {str(e)}", output_type=OutputType.WARNING)
-
- return config
-
- def _load_cache(self):
- """Load cache data"""
- if os.path.exists(self.cache_path):
+ # Use the hash of the file path as the cache file name
+ file_hash = hashlib.md5(file_path.encode()).hexdigest()
+ return os.path.join(self.cache_dir, f"{file_hash}.cache")
+
+ def _load_cache_index(self):
+ """Load cache index"""
+ index_path = os.path.join(self.data_dir, "index.pkl")
+ if os.path.exists(index_path):
  try:
- with lzma.open(self.cache_path, 'rb') as f:
+ with lzma.open(index_path, 'rb') as f:
  cache_data = pickle.load(f)
- self.documents = cache_data["documents"]
- vectors = cache_data["vectors"]
- self.file_md5_cache = cache_data.get("file_md5_cache", {}) # Load the MD5 cache
+ self.file_md5_cache = cache_data.get("file_md5_cache", {})

- # Rebuild the index
- if vectors is not None:
- self._build_index(vectors)
+ # Load documents from the individual cache files
+ for file_path in self.file_md5_cache:
+ cache_path = self._get_cache_path(file_path)
+ if os.path.exists(cache_path):
+ try:
+ with lzma.open(cache_path, 'rb') as f:
+ file_cache = pickle.load(f)
+ self.documents.extend(file_cache["documents"])
+ except Exception as e:
+ PrettyOutput.print(f"Failed to load cache for {file_path}: {str(e)}",
+ output_type=OutputType.WARNING)
+
+ # Rebuild the vector index
+ if self.documents:
+ vectors = []
+ for doc in self.documents:
+ cache_path = self._get_cache_path(doc.metadata['file_path'])
+ if os.path.exists(cache_path):
+ with lzma.open(cache_path, 'rb') as f:
+ file_cache = pickle.load(f)
+ doc_idx = next((i for i, d in enumerate(file_cache["documents"])
+ if d.metadata['chunk_index'] == doc.metadata['chunk_index']), None)
+ if doc_idx is not None:
+ vectors.append(file_cache["vectors"][doc_idx])
+
+ if vectors:
+ vectors = np.vstack(vectors)
+ self._build_index(vectors)
+
  PrettyOutput.print(f"Loaded {len(self.documents)} document fragments",
  output_type=OutputType.INFO)
+
  except Exception as e:
- PrettyOutput.print(f"Failed to load cache: {str(e)}",
+ PrettyOutput.print(f"Failed to load cache index: {str(e)}",
  output_type=OutputType.WARNING)
  self.documents = []
  self.index = None
  self.flat_index = None
  self.file_md5_cache = {}

- def _save_cache(self, vectors: np.ndarray):
- """Optimize cache saving"""
+ def _save_cache(self, file_path: str, documents: List[Document], vectors: np.ndarray):
+ """Save cache for a single file
+
+ Args:
+ file_path: File path
+ documents: List of documents
+ vectors: Document vectors
+ """
  try:
+ # Save the per-file cache
+ cache_path = self._get_cache_path(file_path)
  cache_data = {
- "version": "1.0",
- "timestamp": datetime.now().isoformat(),
- "documents": self.documents,
- "vectors": vectors.copy() if vectors is not None else None, # Create a copy of the array
- "file_md5_cache": dict(self.file_md5_cache), # Create a copy of the dictionary
- "metadata": {
- "vector_dim": self.vector_dim,
- "total_docs": len(self.documents),
- "model_name": self.embedding_model.__class__.__name__
- }
+ "documents": documents,
+ "vectors": vectors
  }
-
- # First serialize the data to a byte stream
- data = pickle.dumps(cache_data, protocol=pickle.HIGHEST_PROTOCOL)
-
- # Then use LZMA to compress the byte stream
- with lzma.open(self.cache_path, 'wb') as f:
- f.write(data)
-
- # Create a backup
- backup_path = f"{self.cache_path}.backup"
- shutil.copy2(self.cache_path, backup_path)
-
- PrettyOutput.print(f"Cache saved: {len(self.documents)} document fragments",
- output_type=OutputType.INFO)
+ with lzma.open(cache_path, 'wb') as f:
+ pickle.dump(cache_data, f)
+
+ # Update and save the index
+ index_path = os.path.join(self.data_dir, "index.pkl")
+ index_data = {
+ "file_md5_cache": self.file_md5_cache
+ }
+ with lzma.open(index_path, 'wb') as f:
+ pickle.dump(index_data, f)
+
  except Exception as e:
- PrettyOutput.print(f"Failed to save cache: {str(e)}",
- output_type=OutputType.ERROR)
- raise
+ PrettyOutput.print(f"Failed to save cache: {str(e)}", output_type=OutputType.ERROR)

  def _build_index(self, vectors: np.ndarray):
  """Build FAISS index"""
@@ -364,106 +355,32 @@ class RAGTool:

  return paragraphs

- def _get_embedding(self, text: str) -> np.ndarray:
- """Get the vector representation of the text"""
- embedding = self.embedding_model.encode(text,
- normalize_embeddings=True,
- show_progress_bar=False)
- return np.array(embedding, dtype=np.float32)
-
- def _get_embedding_batch(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
- """Get embeddings for a batch of texts efficiently"""
- try:
- if self.gpu_config["has_gpu"]:
- import torch
- torch.cuda.empty_cache()
-
- # Use a smaller batch size
- optimal_batch_size = min(16, len(texts))
- all_embeddings = []
-
- with tqdm(total=len(texts), desc="Vectorizing") as pbar:
- for i in range(0, len(texts), optimal_batch_size):
- try:
- batch = texts[i:i + optimal_batch_size]
- embeddings = self.embedding_model.encode(
- batch,
- normalize_embeddings=True,
- show_progress_bar=False,
- batch_size=4, # reduce the internal batch size
- convert_to_tensor=True
- )
- # Move to CPU immediately
- embeddings = embeddings.cpu().numpy()
- all_embeddings.append(embeddings)
- pbar.update(len(batch))
-
- # Clear the GPU cache
- torch.cuda.empty_cache()
-
- except RuntimeError as e:
- if "out of memory" in str(e):
- # If out of memory, reduce the batch size and retry
- if optimal_batch_size > 4:
- optimal_batch_size //= 2
- PrettyOutput.print(
- f"CUDA out of memory, reducing batch size to {optimal_batch_size}",
- OutputType.WARNING
- )
- i -= optimal_batch_size # retry the current batch
- continue
- raise
-
- return np.vstack(all_embeddings)
- else:
- # CPU mode
- return self.embedding_model.encode(
- texts,
- normalize_embeddings=True,
- show_progress_bar=True,
- batch_size=8,
- convert_to_tensor=False
- )
-
- except Exception as e:
- PrettyOutput.print(f"Batch embedding failed: {str(e)}", OutputType.ERROR)
- return np.zeros((len(texts), self.vector_dim), dtype=np.float32) # type: ignore

  def _process_document_batch(self, documents: List[Document]) -> np.ndarray:
- """Process a batch of documents using shared memory
-
- Args:
- documents: List of documents to process
-
- Returns:
- np.ndarray: Document vectors
- """
+ """Process a batch of documents using shared memory"""
  try:
- import torch
-
- # Estimate memory requirements
- total_content_size = sum(len(doc.content) for doc in documents)
- est_memory_needed = total_content_size * 4 # rough estimate
-
- # If the estimated memory exceeds the shared memory limit, process in batches
- if est_memory_needed > self.gpu_config["shared_memory"] * 0.7:
- batch_size = max(1, int(len(documents) * (self.gpu_config["shared_memory"] * 0.7 / est_memory_needed)))
-
- all_vectors = []
- for i in range(0, len(documents), batch_size):
- batch = documents[i:i + batch_size]
- vectors = self._process_document_batch(batch)
- all_vectors.append(vectors)
- return np.vstack(all_vectors)
-
- # Process a single batch normally
  texts = []
+ self.documents = [] # Reset documents to store chunks
+
  for doc in documents:
- combined_text = f"File:{doc.metadata['file_path']} Content:{doc.content}"
- texts.append(combined_text)
-
- return self._get_embedding_batch(texts)
+ # Split original document into chunks
+ chunks = self._split_text(doc.content)
+ for chunk_idx, chunk in enumerate(chunks):
+ # Create new Document for each chunk
+ new_metadata = doc.metadata.copy()
+ new_metadata.update({
+ 'chunk_index': chunk_idx,
+ 'total_chunks': len(chunks),
+ 'original_length': len(doc.content)
+ })
+ self.documents.append(Document(
+ content=chunk,
+ metadata=new_metadata,
+ md5=doc.md5
+ ))
+ texts.append(f"File:{doc.metadata['file_path']} Chunk:{chunk_idx} Content:{chunk}")

+ return get_embedding_batch(self.embedding_model, texts)
  except Exception as e:
  PrettyOutput.print(f"Batch processing failed: {str(e)}", OutputType.ERROR)
  return np.zeros((0, self.vector_dim), dtype=np.float32) # type: ignore
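
Embedding now goes through module-level helpers in jarvis.utils that take the model as their first argument, replacing the removed _get_embedding/_get_embedding_batch methods. A hedged usage sketch based only on the call sites visible in this diff (load_embedding_model's exact signature and the return shapes are assumptions):

    from jarvis.utils import load_embedding_model, get_embedding, get_embedding_batch

    model = load_embedding_model()  # assumed no-arg call; only the import is shown above

    # single query, as used in search(); presumably a 1-D vector later reshaped to (1, dim)
    query_vec = get_embedding(model, "how is the cache organized?")

    # batch of chunk texts, as used in _process_document_batch(); presumably an (n, dim) array
    doc_vecs = get_embedding_batch(model, [
        "File:docs/readme.md Chunk:0 Content:...",
        "File:docs/readme.md Chunk:1 Content:...",
    ])
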
@@ -572,74 +489,64 @@ class RAGTool:
  unchanged_documents = [doc for doc in self.documents
  if doc.metadata['file_path'] in unchanged_files]

- # Process files in parallel with optimized vectorization
+ # Process files one by one with optimized vectorization
  if files_to_process:
  PrettyOutput.print(f"Processing {len(files_to_process)} files...", OutputType.INFO)

- # Step 1: extract text content in parallel
- documents_to_process = []
- with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
- futures = {
- executor.submit(self._process_file, file_path): file_path
- for file_path in files_to_process
- }
-
- with tqdm(total=len(files_to_process), desc="Extracting text") as pbar:
- for future in concurrent.futures.as_completed(futures):
- try:
- docs = future.result()
- if docs:
- documents_to_process.extend(docs)
- pbar.update(1)
- except Exception as e:
- PrettyOutput.print(f"File processing failed: {str(e)}", OutputType.ERROR)
- pbar.update(1)
-
- # Step 2: optimized batch vectorization
- if documents_to_process:
- PrettyOutput.print(f"Vectorizing {len(documents_to_process)} documents...", OutputType.INFO)
-
- # Prepare the texts to vectorize
- texts_to_vectorize = []
- for doc in documents_to_process:
- # Optimize the combined text to reduce memory usage
- combined_text = f"File:{doc.metadata['file_path']} Content:{doc.content}"
- texts_to_vectorize.append(combined_text)
-
- # Use a smaller initial batch size
- initial_batch_size = min(
- 32, # maximum batch size
- max(4, len(texts_to_vectorize) // 8), # batch size based on document count
- len(texts_to_vectorize) # no more than the total number of documents
- )
-
- # Vectorize in batches
- vectors = self._get_embedding_batch(texts_to_vectorize, initial_batch_size)
+ new_documents = []
+ new_vectors = []
+
+ with tqdm(total=len(files_to_process), desc="Processing files") as pbar:
+ for file_path in files_to_process:
+ try:
+ # Process single file
+ file_docs = self._process_file(file_path)
+ if file_docs:
+ # Vectorize documents from this file
+ texts_to_vectorize = [
+ f"File:{doc.metadata['file_path']} Content:{doc.content}"
+ for doc in file_docs
+ ]
+ file_vectors = get_embedding_batch(self.embedding_model, texts_to_vectorize)
+
+ # Save cache for this file
+ self._save_cache(file_path, file_docs, file_vectors)
+
+ # Accumulate documents and vectors
+ new_documents.extend(file_docs)
+ new_vectors.append(file_vectors)
+
+ except Exception as e:
+ PrettyOutput.print(f"Failed to process {file_path}: {str(e)}", OutputType.ERROR)
+
+ pbar.update(1)

- # Update the documents and the index
- self.documents.extend(documents_to_process)
+ # Update documents list
+ self.documents.extend(new_documents)

- # Build the final index
+ # Build final index
+ if new_vectors:
+ all_new_vectors = np.vstack(new_vectors)
+
  if self.flat_index is not None:
- # Get the vectors of unchanged documents
+ # Get vectors for unchanged documents
  unchanged_vectors = self._get_unchanged_vectors(unchanged_documents)
  if unchanged_vectors is not None:
- final_vectors = np.vstack([unchanged_vectors, vectors])
+ final_vectors = np.vstack([unchanged_vectors, all_new_vectors])
  else:
- final_vectors = vectors
+ final_vectors = all_new_vectors
  else:
- final_vectors = vectors
+ final_vectors = all_new_vectors

- # Build the index and save the cache
+ # Build index
  self._build_index(final_vectors)
- self._save_cache(final_vectors)

- PrettyOutput.print(
- f"Indexed {len(self.documents)} documents "
- f"(New/Modified: {len(documents_to_process)}, "
- f"Unchanged: {len(unchanged_documents)})",
- OutputType.SUCCESS
- )
+ PrettyOutput.print(
+ f"Indexed {len(self.documents)} documents "
+ f"(New/Modified: {len(new_documents)}, "
+ f"Unchanged: {len(unchanged_documents)})",
+ OutputType.SUCCESS
+ )

  def _get_unchanged_vectors(self, unchanged_documents: List[Document]) -> Optional[np.ndarray]:
  """Get vectors for unchanged documents from existing index"""
@@ -663,40 +570,62 @@ class RAGTool:
  return None

  def search(self, query: str, top_k: int = 30) -> List[Tuple[Document, float]]:
- """Search documents using vector similarity
-
- Args:
- query: Search query
- top_k: Number of results to return
- """
+ """Search documents with context window"""
  if not self.index:
  PrettyOutput.print("Index not built, building...", output_type=OutputType.INFO)
  self.build_index(self.root_dir)

  # Get query vector
- query_vector = self._get_embedding(query)
+ query_vector = get_embedding(self.embedding_model, query)
  query_vector = query_vector.reshape(1, -1)

  # Search with more candidates
  initial_k = min(top_k * 4, len(self.documents))
  distances, indices = self.index.search(query_vector, initial_k) # type: ignore

- # Process results
+ # Process results with context window
  results = []
  seen_files = set()
+
  for idx, dist in zip(indices[0], distances[0]):
  if idx != -1:
  doc = self.documents[idx]
  similarity = 1.0 / (1.0 + float(dist))
- if similarity > 0.3: # lower the filter threshold to get more results
+ if similarity > 0.3:
  file_path = doc.metadata['file_path']
  if file_path not in seen_files:
  seen_files.add(file_path)
- results.append((doc, similarity))
- if len(results) >= top_k:
+
+ # Get full context from original document
+ original_doc = next((d for d in self.documents
+ if d.metadata['file_path'] == file_path), None)
+ if original_doc:
+ window_docs = [] # Add this line to initialize the list
+ full_content = original_doc.content
+ # Find all chunks from this file
+ file_chunks = [d for d in self.documents
+ if d.metadata['file_path'] == file_path]
+ # Add all related chunks
+ for chunk_doc in file_chunks:
+ window_docs.append((chunk_doc, similarity * 0.9))
+
+ results.extend(window_docs)
+ if len(results) >= top_k * (2 * self.context_window + 1):
  break

- return results
+ # Sort by similarity and deduplicate
+ results.sort(key=lambda x: x[1], reverse=True)
+ seen = set()
+ final_results = []
+ for doc, score in results:
+ key = (doc.metadata['file_path'], doc.metadata['chunk_index'])
+ if key not in seen:
+ seen.add(key)
+ final_results.append((doc, score))
+ if len(final_results) >= top_k:
+ break
+
+ return final_results

  def query(self, query: str) -> List[Document]:
  """Query related documents
@@ -718,13 +647,6 @@ class RAGTool:
  if not results:
  return None

- # Display the documents that were found
- for doc, score in results:
- output = f"""File: {doc.metadata['file_path']} (Score: {score:.3f})\n"""
- output += f"""Fragment {doc.metadata['chunk_index'] + 1}/{doc.metadata['total_chunks']}\n"""
- output += f"""Content:\n{doc.content}\n"""
- PrettyOutput.print(output, output_type=OutputType.INFO, lang="markdown")
-
  # Build the prompt
  prompt = f"""Based on the following document fragments, please answer the user's question accurately and comprehensively.

@@ -733,8 +655,8 @@ Question: {question}
  Relevant documents (ordered by relevance):
  """
  # Add context while limiting the total length
- available_length = self.max_context_length - len(prompt) - 1000
- current_length = 0
+ available_count = self.max_context_length - get_context_token_count(prompt) - 1000
+ current_count = 0

  for doc, score in results:
  doc_content = f"""
@@ -742,7 +664,11 @@ Relevant documents (ordered by relevance):
  {doc.content}
  ---
  """
- if current_length + len(doc_content) > available_length:
+ prompt += "Answer Format:\n"
+ prompt += "1. Answer the question accurately and comprehensively.\n"
+ prompt += "2. If the documents don't fully answer the question, please indicate what information is missing.\n"
+ prompt += "3. Reference the documents in the answer.\n"
+ if current_count + get_context_token_count(doc_content) > available_count:
  PrettyOutput.print(
  "Due to context length limit, some fragments were omitted",
  output_type=OutputType.WARNING
@@ -750,7 +676,7 @@ Relevant documents (ordered by relevance):
  break

  prompt += doc_content
- current_length += len(doc_content)
+ current_count += get_context_token_count(doc_content)

  prompt += "\nIf the documents don't fully answer the question, please indicate what information is missing."

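Prompt assembly now budgets by tokens via get_context_token_count instead of raw character counts: reserve some headroom, then append fragments until the running token count would exceed the budget. A minimal sketch of that pattern with a stand-in counter (the real counter is the helper imported from jarvis.utils):

    def assemble_prompt(base_prompt, fragments, max_context_length, count_tokens):
        # count_tokens stands in for get_context_token_count
        available = max_context_length - count_tokens(base_prompt) - 1000  # keep headroom, as above
        used = 0
        prompt = base_prompt
        for fragment in fragments:
            cost = count_tokens(fragment)
            if used + cost > available:
                break  # stop before exceeding the context budget
            prompt += fragment
            used += cost
        return prompt
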
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional

  from jarvis.jarvis_platform.registry import PlatformRegistry
  from jarvis.jarvis_tools.base import Tool
- from jarvis.utils import OutputType, PrettyOutput, get_max_context_length
+ from jarvis.utils import OutputType, PrettyOutput, get_context_token_count, get_max_context_length


  tool_call_help = """## Tool Usage Format
@@ -137,7 +137,8 @@ class ToolRegistry:
  hasattr(item, 'name') and
  hasattr(item, 'description') and
  hasattr(item, 'parameters') and
- hasattr(item, 'execute')):
+ hasattr(item, 'execute') and
+ item.name == module_name):

  if hasattr(item, "check"):
  if not item.check():
@@ -247,16 +248,16 @@ arguments:
  PrettyOutput.section("Execution successful", OutputType.SUCCESS)

  # If the output exceeds 4k characters, use a large model to summarize
- if len(output) > self.max_context_length:
+ if get_context_token_count(output) > self.max_context_length:
  try:
  PrettyOutput.print("Output is too long, summarizing...", OutputType.PROGRESS)
  model = PlatformRegistry.get_global_platform_registry().get_normal_platform()

  # If the output exceeds the maximum context length, only take the last part
- max_len = self.max_context_length
- if len(output) > max_len:
- output_to_summarize = output[-max_len:]
- truncation_notice = f"\n(Note: Due to the length of the output, only the last 49,560 characters are summarized)"
+ max_count = self.max_context_length
+ if get_context_token_count(output) > max_count:
+ output_to_summarize = output[-max_count:]
+ truncation_notice = f"\n(Note: Due to the length of the output, only the last {max_count} characters are summarized)"
  else:
  output_to_summarize = output
  truncation_notice = ""
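
The extra item.name == module_name check means a tool module is only registered when it defines a class whose name attribute matches the module's file name, so helper classes living in the same file are skipped. A hypothetical tool module illustrating the convention (file name, class, parameters schema, and return value are illustrative, based only on the attributes checked above):

    # jarvis/jarvis_tools/read_file.py  -- hypothetical example
    class ReadFileTool:
        name = "read_file"  # must match the module name or the registry skips it
        description = "Read a text file and return its contents"
        parameters = {"type": "object", "properties": {"path": {"type": "string"}}}

        def check(self) -> bool:
            return True  # optional hook; the registry calls it when present

        def execute(self, arguments):
            with open(arguments["path"], "r", encoding="utf-8") as f:
                return f.read()  # the real return shape is not shown in this diff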