quantalogic 0.59.3__py3-none-any.whl → 0.60.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/agent.py +268 -24
- quantalogic/create_custom_agent.py +26 -78
- quantalogic/prompts/chat_system_prompt.j2 +10 -7
- quantalogic/prompts/code_2_system_prompt.j2 +190 -0
- quantalogic/prompts/code_system_prompt.j2 +142 -0
- quantalogic/prompts/doc_system_prompt.j2 +178 -0
- quantalogic/prompts/legal_2_system_prompt.j2 +218 -0
- quantalogic/prompts/legal_system_prompt.j2 +140 -0
- quantalogic/prompts/system_prompt.j2 +6 -2
- quantalogic/prompts/tools_prompt.j2 +2 -4
- quantalogic/prompts.py +23 -4
- quantalogic/server/agent_server.py +1 -1
- quantalogic/tools/__init__.py +2 -0
- quantalogic/tools/duckduckgo_search_tool.py +1 -0
- quantalogic/tools/execute_bash_command_tool.py +114 -57
- quantalogic/tools/file_tracker_tool.py +49 -0
- quantalogic/tools/google_packages/google_news_tool.py +3 -0
- quantalogic/tools/image_generation/dalle_e.py +89 -137
- quantalogic/tools/rag_tool/__init__.py +2 -9
- quantalogic/tools/rag_tool/document_rag_sources_.py +728 -0
- quantalogic/tools/rag_tool/ocr_pdf_markdown.py +144 -0
- quantalogic/tools/replace_in_file_tool.py +1 -1
- quantalogic/tools/terminal_capture_tool.py +293 -0
- quantalogic/tools/tool.py +4 -0
- quantalogic/tools/utilities/__init__.py +2 -0
- quantalogic/tools/utilities/download_file_tool.py +3 -5
- quantalogic/tools/utilities/llm_tool.py +283 -0
- quantalogic/tools/utilities/selenium_tool.py +296 -0
- quantalogic/tools/utilities/vscode_tool.py +1 -1
- quantalogic/tools/web_navigation/__init__.py +5 -0
- quantalogic/tools/web_navigation/web_tool.py +145 -0
- quantalogic/tools/write_file_tool.py +72 -36
- {quantalogic-0.59.3.dist-info → quantalogic-0.60.0.dist-info}/METADATA +1 -1
- {quantalogic-0.59.3.dist-info → quantalogic-0.60.0.dist-info}/RECORD +37 -28
- quantalogic/tools/rag_tool/document_metadata.py +0 -15
- quantalogic/tools/rag_tool/query_response.py +0 -20
- quantalogic/tools/rag_tool/rag_tool.py +0 -566
- quantalogic/tools/rag_tool/rag_tool_beta.py +0 -264
- {quantalogic-0.59.3.dist-info → quantalogic-0.60.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.59.3.dist-info → quantalogic-0.60.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.59.3.dist-info → quantalogic-0.60.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,728 @@
"""Multilingual RAG Tool optimized for French and Arabic using HuggingFace models.

This tool provides enhanced RAG capabilities with:
- Multilingual support (French/Arabic) using specialized embedding models
- Improved query processing with source attribution
- Persistent ChromaDB storage
- Enhanced response formatting
"""

import os
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
import asyncio
import shutil
import json
from datetime import datetime

import chromadb
from sentence_transformers import SentenceTransformer
from llama_index.core import (
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    Response,
    QueryBundle,
    Settings,
    Document,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.readers.file.docs import PDFReader
from loguru import logger
from quantalogic.tools.tool import Tool, ToolArgument
from quantalogic.tools.rag_tool.ocr_pdf_markdown import PDFToMarkdownConverter
from rank_bm25 import BM25Okapi
from sklearn.preprocessing import MinMaxScaler
import numpy as np

# Configure tool-specific logging
logger.remove()
logger.add(
    sink=lambda msg: print(msg, end=""),
    level="INFO",
    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan> - <level>{message}</level>"
)

@dataclass
class LawSource:
    """Structured representation of a law source."""
    content: str
    file_name: str
    page_number: str
    reference_number: Optional[str] = None
    score: Optional[float] = None

@dataclass
class SearchResult:
    """Represents a single search result with combined scores."""
    content: str
    file_name: str
    page_number: str
    reference_number: Optional[str] = None
    bm25_score: float = 0.0
    embedding_score: float = 0.0
    combined_score: float = 0.0
    metadata: Dict[str, Any] = None

class RagToolHf_(Tool):
    """Enhanced RAG tool specialized for law source retrieval."""

    name: str = "rag_tool_hf"
    description: str = (
        "Specialized RAG tool for retrieving and analyzing sources "
        "from documents with detailed source attribution."
    )
    arguments: List[ToolArgument] = [
        ToolArgument(
            name="query",
            arg_type="string",
            description="Query to search for specific sources",
            required=True,
            example="Find articles related to environmental protection",
        ),
        ToolArgument(
            name="max_sources",
            arg_type="int",
            description="Maximum number of sources to return",
            required=False,
            example="5",
        ),
    ]

    def __init__(
        self,
        name: str = "rag_tool_hf",
        persist_dir: str = "./storage/multilingual_rag",
        document_paths: Optional[List[str]] = None,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        use_ocr_for_pdfs: bool = False,
        ocr_model: str = "openai/gpt-4o-mini",
        embed_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        force_reindex: bool = False
    ):
        """Initialize the multilingual RAG tool.

        Args:
            force_reindex: If True, forces reindexing even if embeddings exist
        """
        super().__init__()
        self.name = name
        self.persist_dir = os.path.abspath(persist_dir)
        self.use_ocr_for_pdfs = use_ocr_for_pdfs
        self.ocr_model = ocr_model

        # Check if we need to reindex
        chroma_persist_dir = os.path.join(self.persist_dir, "chroma")
        embedding_config_path = os.path.join(self.persist_dir, "embedding_config.json")
        needs_reindex = False

        if os.path.exists(embedding_config_path):
            try:
                with open(embedding_config_path, 'r') as f:
                    config = json.load(f)
                    if config.get('embed_model') != embed_model:
                        logger.info(f"Embedding model changed from {config.get('embed_model')} to {embed_model}")
                        needs_reindex = True
            except Exception as e:
                logger.warning(f"Failed to read embedding config: {e}")
                needs_reindex = True
        else:
            needs_reindex = True

        # Clean up only if needed
        if (needs_reindex or force_reindex) and os.path.exists(chroma_persist_dir):
            logger.info("Cleaning up existing index due to model change or forced reindex")
            shutil.rmtree(chroma_persist_dir)

        # Save new embedding configuration
        os.makedirs(os.path.dirname(embedding_config_path), exist_ok=True)
        with open(embedding_config_path, 'w') as f:
            json.dump({'embed_model': embed_model}, f)

        # Initialize embedding model
        self.embed_model = HuggingFaceEmbedding(
            model_name=embed_model,
            embed_batch_size=8
        )

        # Configure ChromaDB
        os.makedirs(chroma_persist_dir, exist_ok=True)
        chroma_client = chromadb.PersistentClient(path=chroma_persist_dir)
        collection = chroma_client.create_collection(
            name="multilingual_collection",
            get_or_create=True
        )

        self.vector_store = ChromaVectorStore(chroma_collection=collection)

        # Configure llama-index settings
        Settings.embed_model = self.embed_model
        Settings.chunk_size = chunk_size
        Settings.chunk_overlap = chunk_overlap
        Settings.num_output = 1024

        self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store)

        # Initialize text splitter
        self.text_splitter = SentenceSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            paragraph_separator="\n\n",
            tokenizer=lambda x: x.replace("\n", " ").split(" ")
        )

        # Initialize or load index
        self.index = self._initialize_index(document_paths)

    async def _process_pdf_with_ocr(self, path: str) -> List[Document]:
        """Process a PDF file using OCR and convert to Documents."""
        try:
            converter = PDFToMarkdownConverter(
                model=self.ocr_model,
                custom_system_prompt=(
                    "Convert the PDF page to clean, well-formatted text. "
                    "Preserve all content including tables, lists, and mathematical notation. "
                    "For images and charts, provide detailed descriptions. "
                    "Maintain the original document structure and hierarchy."
                )
            )

            markdown_content = await converter.convert_pdf(path)
            if not markdown_content:
                logger.warning(f"OCR produced no content for {path}")
                return []

            # Create a single document with the full content
            doc = Document(
                text=markdown_content,
                metadata={
                    "file_name": os.path.basename(path),
                    "file_path": path,
                    "processing_method": "ocr"
                }
            )
            return [doc]

        except Exception as e:
            logger.error(f"Error processing PDF with OCR {path}: {e}")
            return []

    def _load_documents(self, document_paths: List[str]) -> List[Document]:
        """Load documents with special handling for PDFs."""
        all_documents = []
        pdf_reader = PDFReader()

        for path in document_paths:
            if not os.path.exists(path):
                logger.warning(f"Document path does not exist: {path}")
                continue

            try:
                if path.lower().endswith('.pdf'):
                    if self.use_ocr_for_pdfs:
                        # Use asyncio to run the async OCR function
                        docs = asyncio.run(self._process_pdf_with_ocr(path))
                    else:
                        # Use standard PDF reader
                        docs = pdf_reader.load_data(
                            path,
                            extra_info={
                                "file_name": os.path.basename(path),
                                "file_path": path,
                                "processing_method": "standard"
                            }
                        )

                    if not self.use_ocr_for_pdfs:
                        # Process each page to improve text quality (only for standard PDF reader)
                        processed_docs = []
                        for doc in docs:
                            # Clean up text
                            text = doc.text
                            text = text.replace('\n\n', '[PAGE_BREAK]')
                            text = text.replace('\n', ' ')
                            text = text.replace('[PAGE_BREAK]', '\n\n')
                            text = ' '.join(text.split())

                            processed_doc = Document(
                                text=text,
                                metadata={
                                    **doc.metadata,
                                    "file_name": os.path.basename(path),
                                    "file_path": path,
                                    "page_number": doc.metadata.get("page_number", "unknown"),
                                    "processing_method": "standard"
                                }
                            )
                            processed_docs.append(processed_doc)
                        docs = processed_docs
                else:
                    docs = SimpleDirectoryReader(
                        input_files=[path],
                        filename_as_id=True,
                        file_metadata=lambda x: {"file_name": os.path.basename(x), "file_path": x}
                    ).load_data()

                all_documents.extend(docs)

                # Log document details
                for doc in docs:
                    logger.debug(f"Document content length: {len(doc.text)} characters")
                    logger.debug(f"Document metadata: {doc.metadata}")
                    preview = doc.text[:200].replace('\n', ' ').strip()
                    logger.debug(f"Content preview: {preview}...")

            except Exception as e:
                logger.error(f"Error loading document {path}: {str(e)}")
                continue

        return all_documents

    def _initialize_index(self, document_paths: Optional[List[str]]) -> Optional[VectorStoreIndex]:
        """Initialize or load the vector index."""
        logger.info("Initializing index...")

        if document_paths:
            return self._create_index(document_paths)

        # Try loading existing index
        index_path = os.path.join(self.persist_dir, "docstore.json")
        if os.path.exists(index_path):
            try:
                return load_index_from_storage(storage_context=self.storage_context)
            except Exception as e:
                logger.error(f"Failed to load existing index: {str(e)}")
        else:
            logger.warning("No existing index found and no documents provided")

        return None

    def _create_index(self, document_paths: List[str]) -> Optional[VectorStoreIndex]:
        """Create a new index from documents."""
        try:
            all_documents = self._load_documents(document_paths)

            if not all_documents:
                logger.warning("No valid documents found")
                return None

            total_chunks = 0
            for doc in all_documents:
                chunks = self.text_splitter.split_text(doc.text)
                total_chunks += len(chunks)
                logger.debug(f"Created {len(chunks)} chunks from document {doc.metadata.get('file_name', 'unknown')}")
                for i, chunk in enumerate(chunks[:2]):  # Log only first 2 chunks as preview
                    logger.debug(f"Chunk {i+1} preview ({len(chunk)} chars): {chunk[:100]}...")

            logger.info(f"Total chunks created: {total_chunks}")
            logger.info("Creating vector index...")

            index = VectorStoreIndex.from_documents(
                all_documents,
                storage_context=self.storage_context,
                transformations=[self.text_splitter],
                show_progress=True
            )

            self.storage_context.persist(persist_dir=self.persist_dir)
            logger.info(f"Created and persisted index with {len(all_documents)} documents")

            return index

        except Exception as e:
            logger.error(f"Error creating index: {str(e)}")
            return None

    def _extract_law_reference(self, text: str) -> Optional[str]:
        """Extract law reference numbers from text."""
        import re

        # Common patterns for law references
        patterns = [
            r'(?:loi|décret|arrêté)\s+n[°o]?\s*(\d+[-./]\d+)',  # French
            r'(?:قانون|مرسوم|قرار)\s+(?:رقم\s+)?(\d+[-./]\d+)',  # Arabic
            r'(?:law|decree)\s+(?:no\.\s+)?(\d+[-./]\d+)',  # English
        ]

        for pattern in patterns:
            match = re.search(pattern, text.lower())
            if match:
                return match.group(1)
        return None

    def execute(self, query: str, max_sources: int = 5) -> str:
        """
        Execute a search for sources and return a JSON string of law sources.

        Args:
            query: Search query for finding relevant law sources
            max_sources: Maximum number of sources to return

        Returns:
            JSON string containing an array of law sources with their content and metadata
        """
        try:
            if not self.index:
                raise ValueError("No index available. Please add documents first.")

            logger.info(f"Searching for sources with query: {query}")

            query_engine = self.index.as_query_engine(
                similarity_top_k=max_sources,
                node_postprocessors=[
                    SimilarityPostprocessor(similarity_cutoff=0.1)
                ],
                response_mode="no_text",
                streaming=False,
                verbose=True
            )

            response = query_engine.query(query)

            # Process sources
            processed_sources = []
            for node in response.source_nodes:
                if node.score < 0.1:
                    continue

                # Extract reference number once to avoid duplicate processing
                ref_number = self._extract_law_reference(node.node.text)

                # Create a dictionary with source information
                source_data = {
                    'content': node.node.text.strip(),
                    'file_path': node.node.metadata.get('file_path', ''),
                    'file_name': node.node.metadata.get('file_name', 'Unknown'),
                    'page_number': str(node.node.metadata.get('page_number', 'N/A')),
                    'reference_number': ref_number,
                    'score': float(node.score) if node.score else 0.0,
                    'metadata': {
                        'source_type': 'law_document',
                        'processing_method': node.node.metadata.get('processing_method', 'standard'),
                        'query': query,
                        'timestamp': str(datetime.now().isoformat())
                    }
                }
                processed_sources.append(source_data)

            # Sort sources by score
            processed_sources.sort(key=lambda x: x['score'], reverse=True)

            logger.info(f"Found {len(processed_sources)} relevant law sources for query: {query}")
            return json.dumps(processed_sources, indent=4, ensure_ascii=False)

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Source search failed: {error_msg}")
            error_response = {
                'error': error_msg,
                'query': query,
                'timestamp': str(datetime.now().isoformat()),
                'sources': []
            }
            return json.dumps(error_response, indent=4, ensure_ascii=False)

    def format_sources(self, sources: List[LawSource]) -> str:
        """Format a list of LawSource objects into a readable string."""
        if not sources:
            return "No relevant sources found in the documents."

        output = ["# Sources Found\n"]
        current_file = None

        for source in sources:
            if current_file != source.file_name:
                current_file = source.file_name
                output.append(f"\n## Document: {source.file_name}\n")

            # Format source information
            if source.reference_number:
                output.append(f"**Reference Number:** {source.reference_number}\n")
            output.append(f"**Page:** {source.page_number}\n")
            if source.score:
                output.append(f"**Relevance Score:** {round(source.score * 100, 2)}%\n")
            output.append(f"\n{source.content}\n")
            output.append("\n---\n")

        return "\n".join(output)

    def add_documents(self, document_paths: List[str]) -> bool:
        """Add new documents to the index."""
        try:
            new_index = self._create_index(document_paths)
            if new_index:
                self.index = new_index
                return True
            return False
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}")
            return False


class RagToolHf(RagToolHf_):
    """Enhanced RAG tool with hybrid BM25 + Embeddings search."""

    def __init__(
        self,
        name: str = "hybrid_rag_tool_hf",
        persist_dir: str = "./storage/hybrid_multilingual_rag",
        document_paths: Optional[List[str]] = None,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        use_ocr_for_pdfs: bool = False,
        ocr_model: str = "openai/gpt-4o-mini",
        embed_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        bm25_weight: float = 0.3,  # Weight for BM25 scores in hybrid ranking
        embedding_weight: float = 0.7,  # Weight for embedding scores in hybrid ranking
        force_reindex: bool = False
    ):
        """Initialize the hybrid RAG tool with both BM25 and embeddings capabilities.

        Args:
            bm25_weight: Weight for BM25 scores in hybrid ranking (0.0-1.0)
            embedding_weight: Weight for embedding scores in hybrid ranking (0.0-1.0)
            force_reindex: If True, forces reindexing even if embeddings exist
        """
        super().__init__(
            name=name,
            persist_dir=persist_dir,
            document_paths=document_paths,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            use_ocr_for_pdfs=use_ocr_for_pdfs,
            ocr_model=ocr_model,
            embed_model=embed_model,
            force_reindex=force_reindex
        )

        self.bm25_weight = bm25_weight
        self.embedding_weight = embedding_weight

        # Initialize BM25 index and document store
        self.bm25_index = None
        self.document_store = []

        # Build BM25 index if we have documents
        if document_paths:
            self._build_hybrid_index(document_paths)

    def _build_hybrid_index(self, document_paths: List[str]):
        """Build BM25 index and optionally rebuild embedding index."""
        # Load documents if needed
        if not self.document_store:
            documents = self._load_documents(document_paths)

            # Store documents and their text for BM25
            tokenized_corpus = []

            for doc in documents:
                # Process text for BM25
                text = doc.text.lower()
                tokens = text.split()

                # Store document info
                self.document_store.append({
                    'text': doc.text,
                    'metadata': doc.metadata,
                    'tokens': tokens
                })

                tokenized_corpus.append(tokens)

            # Create BM25 index
            self.bm25_index = BM25Okapi(tokenized_corpus)

        # Rebuild embedding index if needed
        if self.force_reindex or not self.index:
            self._create_index(document_paths)

    def _normalize_scores(self, scores: List[float]) -> List[float]:
        """Normalize scores to range [0, 1] using min-max scaling."""
        if not scores:
            return scores
        scaler = MinMaxScaler()
        normalized = scaler.fit_transform(np.array(scores).reshape(-1, 1))
        return normalized.flatten().tolist()

    def execute(self, query: str, max_sources: int = 5) -> str:
        """Execute hybrid search combining BM25 and embedding-based retrieval."""
        try:
            if not self.index or not self.bm25_index:
                raise ValueError("Indices not initialized. Please add documents first.")

            logger.info(f"Executing hybrid search for query: {query}")

            # 1. Get embedding-based results
            query_engine = self.index.as_query_engine(
                similarity_top_k=max_sources * 2,  # Get more results for reranking
                node_postprocessors=[
                    SimilarityPostprocessor(similarity_cutoff=0.1)
                ],
                response_mode="no_text",
                streaming=False,
                verbose=True
            )

            embedding_response = query_engine.query(query)

            # 2. Get BM25 results
            tokenized_query = query.lower().split()
            bm25_scores = self.bm25_index.get_scores(tokenized_query)

            # 3. Combine and rank results
            combined_results = []
            seen_texts = set()

            # Process embedding results
            for node in embedding_response.source_nodes:
                if node.score < 0.1:
                    continue

                text = node.node.text.strip()
                if text in seen_texts:
                    continue
                seen_texts.add(text)

                # Find corresponding BM25 score
                doc_idx = next(
                    (i for i, doc in enumerate(self.document_store)
                     if doc['text'].strip() == text),
                    None
                )

                bm25_score = bm25_scores[doc_idx] if doc_idx is not None else 0.0

                result = SearchResult(
                    content=text,
                    file_name=node.node.metadata.get('file_name', 'Unknown'),
                    page_number=str(node.node.metadata.get('page_number', 'N/A')),
                    reference_number=self._extract_law_reference(text),
                    bm25_score=bm25_score,
                    embedding_score=float(node.score) if node.score else 0.0,
                    metadata={
                        'source_type': 'law_document',
                        'processing_method': node.node.metadata.get('processing_method', 'standard'),
                        'query': query,
                        'timestamp': str(datetime.now().isoformat())
                    }
                )
                combined_results.append(result)

            # Normalize scores
            if combined_results:
                bm25_scores = [r.bm25_score for r in combined_results]
                embedding_scores = [r.embedding_score for r in combined_results]

                normalized_bm25 = self._normalize_scores(bm25_scores)
                normalized_embedding = self._normalize_scores(embedding_scores)

                # Calculate combined scores
                for i, result in enumerate(combined_results):
                    result.bm25_score = normalized_bm25[i]
                    result.embedding_score = normalized_embedding[i]
                    result.combined_score = (
                        self.bm25_weight * result.bm25_score +
                        self.embedding_weight * result.embedding_score
                    )

            # Sort by combined score and limit results
            combined_results.sort(key=lambda x: x.combined_score, reverse=True)
            combined_results = combined_results[:max_sources]

            # Format results for output
            output_results = []
            for result in combined_results:
                output_results.append({
                    'content': result.content,
                    'file_name': result.file_name,
                    'page_number': result.page_number,
                    'reference_number': result.reference_number,
                    'scores': {
                        'bm25_score': round(result.bm25_score, 4),
                        'embedding_score': round(result.embedding_score, 4),
                        'combined_score': round(result.combined_score, 4)
                    },
                    'metadata': result.metadata
                })

            logger.info(f"Found {len(output_results)} relevant sources using hybrid search")
            return json.dumps(output_results, indent=4, ensure_ascii=False)

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Hybrid search failed: {error_msg}")
            error_response = {
                'error': error_msg,
                'query': query,
                'timestamp': str(datetime.now().isoformat()),
                'sources': []
            }
            return json.dumps(error_response, indent=4, ensure_ascii=False)

if __name__ == "__main__":
    # Example usage
    if os.path.exists("./storage/multilingual_rag"):
        shutil.rmtree("./storage/multilingual_rag")

    tool = RagToolHf_(
        persist_dir="./storage/multilingual_rag",
        document_paths=[
            "./docs/test/F2015054.pdf",
            "./docs/test/F2015055.pdf"
        ],
        chunk_size=512,
        chunk_overlap=50,
        use_ocr_for_pdfs=False
    )

    # Test queries
    test_queries = [
        "Find articles related to environmental protection",
        "Search for traffic regulations",
        "Look for workplace safety laws"
    ]

    for query in test_queries:
        print(f"\nQuery: {query}")
        try:
            result = tool.execute(query, max_sources=2)
            print(result)
        except Exception as e:
            print(f"Error: {str(e)}")

    # Example usage of hybrid tool
    if os.path.exists("./storage/hybrid_multilingual_rag"):
        shutil.rmtree("./storage/hybrid_multilingual_rag")

    hybrid_tool = RagToolHf(
        persist_dir="./storage/hybrid_multilingual_rag",
        document_paths=[
            "./docs/test/code_civile.md",
            "./docs/test/code_procedure.md"
        ],
        chunk_size=512,
        chunk_overlap=50,
        use_ocr_for_pdfs=False,
        bm25_weight=0.3,
        embedding_weight=0.7
    )

    # Test queries
    test_queries = [
        "Find articles related to environmental protection",
        "Search for traffic regulations",
        "Look for workplace safety laws"
    ]

    for query in test_queries:
        print(f"\nQuery: {query}")
        try:
            result = hybrid_tool.execute(query, max_sources=2)
            print(result)
        except Exception as e:
            print(f"Error: {str(e)}")