ragit 0.8.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/__init__.py +27 -15
- ragit/assistant.py +431 -40
- ragit/config.py +165 -22
- ragit/core/experiment/experiment.py +7 -1
- ragit/exceptions.py +271 -0
- ragit/loaders.py +200 -44
- ragit/logging.py +194 -0
- ragit/monitor.py +307 -0
- ragit/providers/__init__.py +1 -13
- ragit/providers/ollama.py +379 -121
- ragit/utils/__init__.py +0 -22
- ragit/version.py +1 -1
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/METADATA +48 -25
- ragit-0.11.0.dist-info/RECORD +22 -0
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/WHEEL +1 -1
- ragit/providers/sentence_transformers.py +0 -225
- ragit-0.8.2.dist-info/RECORD +0 -20
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/licenses/LICENSE +0 -0
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/top_level.txt +0 -0
ragit/assistant.py
CHANGED

@@ -7,10 +7,16 @@ High-level RAG Assistant for document Q&A and code generation.
 
 Provides a simple interface for RAG-based tasks.
 
-
+Thread Safety:
+    This class uses lock-free atomic operations for thread safety.
+    The IndexState is immutable, and all mutations create a new state
+    that is atomically swapped. Python's GIL ensures reference assignment
+    is atomic, making concurrent reads and writes safe.
 """
 
+import json
 from collections.abc import Callable
+from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -18,7 +24,9 @@ import numpy as np
 from numpy.typing import NDArray
 
 from ragit.core.experiment.experiment import Chunk, Document
+from ragit.exceptions import IndexingError
 from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
+from ragit.logging import log_operation, logger
 from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
 from ragit.providers.function_adapter import FunctionProvider
 
@@ -26,6 +34,23 @@ if TYPE_CHECKING:
     from numpy.typing import NDArray
 
 
+@dataclass(frozen=True)
+class IndexState:
+    """Immutable snapshot of index state for lock-free thread safety.
+
+    This class holds all mutable index data in a single immutable structure.
+    Updates create a new IndexState instance, and the reference swap is
+    atomic under Python's GIL, ensuring thread-safe reads and writes.
+
+    Attributes:
+        chunks: Tuple of indexed chunks (immutable).
+        embedding_matrix: Pre-normalized numpy array of embeddings, or None if empty.
+    """
+
+    chunks: tuple[Chunk, ...]
+    embedding_matrix: NDArray[np.float64] | None
+
+
 class RAGAssistant:
     """
     High-level RAG assistant for document Q&A and generation.
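The IndexState added here is the copy-on-write, atomic-reference-swap idiom: readers grab one reference and work only against that snapshot, while a writer builds a complete replacement and rebinds a single attribute. A minimal sketch of the idiom with one writer thread and concurrent readers (Snapshot and Store are illustrative names, not ragit's API):

```python
# Minimal sketch of the immutable-snapshot / atomic-swap idiom described above.
# Snapshot and Store are illustrative names, not part of ragit.
import threading
from dataclasses import dataclass


@dataclass(frozen=True)
class Snapshot:
    items: tuple[str, ...]


class Store:
    def __init__(self) -> None:
        self._state = Snapshot(items=())

    def add(self, new_items: list[str]) -> None:
        # Writer: never mutate in place; build a full replacement, then rebind.
        current = self._state                                       # atomic read
        self._state = Snapshot(current.items + tuple(new_items))    # atomic rebind

    def search(self, needle: str) -> list[str]:
        # Reader: take one reference and use only that consistent snapshot.
        state = self._state
        return [item for item in state.items if needle in item]


store = Store()
writer = threading.Thread(target=store.add, args=(["doc-1", "doc-2"],))
writer.start()
print(store.search("doc"))  # sees either the old or the new snapshot, never a torn one
writer.join()
print(store.search("doc"))  # ['doc-1', 'doc-2']
```

Note that the GIL only makes the rebind itself atomic; this matches the single-writer, many-reader usage the docstrings describe, and multiple concurrent writers would still need external coordination.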
@@ -62,9 +87,12 @@ class RAGAssistant:
     ValueError
         If neither embed_fn nor provider is provided.
 
-
-
-    This class
+    Thread Safety
+    -------------
+    This class uses lock-free atomic operations for thread safety.
+    Multiple threads can safely call retrieve() while another thread
+    calls add_documents(). The IndexState is immutable, and reference
+    swaps are atomic under Python's GIL.
 
     Examples
     --------
@@ -76,13 +104,13 @@
     >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
     >>> answer = assistant.ask("What is X?")
     >>>
-    >>> # With
+    >>> # With Ollama provider (supports nomic-embed-text)
    >>> from ragit.providers import OllamaProvider
     >>> assistant = RAGAssistant(docs, provider=OllamaProvider())
     >>>
-    >>> #
-    >>>
-    >>>
+    >>> # Save and load index for persistence
+    >>> assistant.save_index("/path/to/index")
+    >>> loaded = RAGAssistant.load_index("/path/to/index", provider=OllamaProvider())
     """
 
     def __init__(
@@ -126,8 +154,7 @@
                 "Must provide embed_fn or provider for embeddings. "
                 "Examples:\n"
                 " RAGAssistant(docs, embed_fn=my_embed_function)\n"
-                " RAGAssistant(docs, provider=OllamaProvider())
-                " RAGAssistant(docs, provider=SentenceTransformersProvider())"
+                " RAGAssistant(docs, provider=OllamaProvider())"
             )
 
         self.embedding_model = embedding_model or "default"
@@ -138,9 +165,8 @@
         # Load documents if path provided
         self.documents = self._load_documents(documents)
 
-        #
-        self.
-        self._embedding_matrix: NDArray[np.float64] | None = None  # Pre-normalized
+        # Thread-safe index state (immutable, atomic reference swap)
+        self._state: IndexState = IndexState(chunks=(), embedding_matrix=None)
         self._build_index()
 
     def _load_documents(self, documents: list[Document] | str | Path) -> list[Document]:
@@ -175,45 +201,70 @@
             raise ValueError(f"Invalid documents source: {documents}")
 
     def _build_index(self) -> None:
-        """Build vector index from documents using batch embedding.
+        """Build vector index from documents using batch embedding.
+
+        Raises:
+            IndexingError: If embedding count doesn't match chunk count.
+        """
         all_chunks: list[Chunk] = []
 
         for doc in self.documents:
             # Use RST section chunking for .rst files, otherwise regular chunking
             if doc.metadata.get("filename", "").endswith(".rst"):
-                chunks = chunk_rst_sections(doc.content, doc.id
+                chunks = chunk_rst_sections(doc.content, doc.id)
             else:
                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
             all_chunks.extend(chunks)
 
         if not all_chunks:
-
-            self.
+            logger.warning("No chunks produced from documents - index will be empty")
+            self._state = IndexState(chunks=(), embedding_matrix=None)
             return
 
         # Batch embed all chunks at once (single API call)
         texts = [chunk.content for chunk in all_chunks]
         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
 
+        # CRITICAL: Validate embedding count matches chunk count
+        if len(responses) != len(all_chunks):
+            raise IndexingError(
+                f"Embedding count mismatch: expected {len(all_chunks)} embeddings, "
+                f"got {len(responses)}. Index may be corrupted."
+            )
+
         # Build embedding matrix directly (skip storing in chunks to avoid duplication)
         embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
 
+        # Additional validation: matrix shape
+        if embedding_matrix.shape[0] != len(all_chunks):
+            raise IndexingError(
+                f"Matrix row count {embedding_matrix.shape[0]} doesn't match chunk count {len(all_chunks)}"
+            )
+
         # Pre-normalize for fast cosine similarity (normalize once, use many times)
         norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
         norms[norms == 0] = 1  # Avoid division by zero
 
-        #
-        self.
-
+        # Atomic state update (thread-safe under GIL)
+        self._state = IndexState(
+            chunks=tuple(all_chunks),
+            embedding_matrix=embedding_matrix / norms,
+        )
 
     def add_documents(self, documents: list[Document] | str | Path) -> int:
         """Add documents to the existing index incrementally.
 
+        This method is thread-safe. It creates a new IndexState and atomically
+        swaps the reference, ensuring readers always see a consistent state.
+
         Args:
             documents: Documents to add.
 
         Returns:
             Number of chunks added.
+
+        Raises:
+            IndexingError: If embedding count doesn't match chunk count.
         """
         new_docs = self._load_documents(documents)
         if not new_docs:
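The validation added to _build_index guards the invariant that the provider returns exactly one embedding per chunk before the matrix is built, and the pre-normalization means cosine similarity later collapses to a dot product. A rough standalone sketch of that build step (embed_batch here is a stand-in stub, not ragit's provider interface):

```python
# Sketch of the build step shown above: validate the batch result,
# then pre-normalize rows so cosine similarity is a plain dot product later.
# embed_batch is a synthetic stub, not ragit's provider API.
import numpy as np


def embed_batch(texts: list[str]) -> list[list[float]]:
    rng = np.random.default_rng(0)
    return [rng.normal(size=8).tolist() for _ in texts]


def build_matrix(chunks: list[str]) -> np.ndarray:
    vectors = embed_batch(chunks)
    if len(vectors) != len(chunks):  # the invariant the diff enforces with IndexingError
        raise RuntimeError(f"expected {len(chunks)} embeddings, got {len(vectors)}")
    matrix = np.array(vectors, dtype=np.float64)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero rows
    return matrix / norms  # every row now has unit length


matrix = build_matrix(["alpha", "beta", "gamma"])
print(np.allclose(np.linalg.norm(matrix, axis=1), 1.0))  # True
```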
@@ -225,7 +276,7 @@
         new_chunks: list[Chunk] = []
         for doc in new_docs:
             if doc.metadata.get("filename", "").endswith(".rst"):
-                chunks = chunk_rst_sections(doc.content, doc.id
+                chunks = chunk_rst_sections(doc.content, doc.id)
             else:
                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
             new_chunks.extend(chunks)
@@ -237,6 +288,13 @@
         texts = [chunk.content for chunk in new_chunks]
         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
 
+        # Validate embedding count
+        if len(responses) != len(new_chunks):
+            raise IndexingError(
+                f"Embedding count mismatch: expected {len(new_chunks)} embeddings, "
+                f"got {len(responses)}. Index update aborted."
+            )
+
         new_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
 
         # Normalize
@@ -244,21 +302,27 @@
         norms[norms == 0] = 1
         new_matrix_norm = new_matrix / norms
 
-        #
-
-        current_chunks.extend(new_chunks)
-        self._chunks = tuple(current_chunks)
+        # Read current state (atomic read)
+        current_state = self._state
 
-
-
+        # Build new state
+        combined_chunks = current_state.chunks + tuple(new_chunks)
+        if current_state.embedding_matrix is None:
+            combined_matrix = new_matrix_norm
         else:
-
+            combined_matrix = np.vstack((current_state.embedding_matrix, new_matrix_norm))
+
+        # Atomic state swap (thread-safe under GIL)
+        self._state = IndexState(chunks=combined_chunks, embedding_matrix=combined_matrix)
 
         return len(new_chunks)
 
     def remove_documents(self, source_path_pattern: str) -> int:
         """Remove documents matching a source path pattern.
 
+        This method is thread-safe. It creates a new IndexState and atomically
+        swaps the reference.
+
         Args:
             source_path_pattern: Glob pattern to match 'source' metadata.
 
@@ -267,14 +331,17 @@
         """
         import fnmatch
 
-
+        # Read current state (atomic read)
+        current_state = self._state
+
+        if not current_state.chunks:
             return 0
 
         indices_to_keep = []
         kept_chunks = []
         removed_count = 0
 
-        for i, chunk in enumerate(
+        for i, chunk in enumerate(current_state.chunks):
             source = chunk.metadata.get("source", "")
             if not source or not fnmatch.fnmatch(source, source_path_pattern):
                 indices_to_keep.append(i)
@@ -285,13 +352,14 @@
         if removed_count == 0:
             return 0
 
-
+        # Build new embedding matrix
+        if current_state.embedding_matrix is not None:
+            new_matrix = None if not kept_chunks else current_state.embedding_matrix[indices_to_keep]
+        else:
+            new_matrix = None
 
-
-
-            self._embedding_matrix = None
-        else:
-            self._embedding_matrix = self._embedding_matrix[indices_to_keep]
+        # Atomic state swap (thread-safe under GIL)
+        self._state = IndexState(chunks=tuple(kept_chunks), embedding_matrix=new_matrix)
 
         # Also remove from self.documents
         self.documents = [
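remove_documents matches the 'source' metadata with fnmatch-style glob patterns and keeps only the matrix rows whose chunks survive. A small illustration of that glob filtering (paths and pattern are invented for the example):

```python
# Illustration of the fnmatch-based source filtering used by remove_documents.
# The paths and pattern below are made up.
import fnmatch

sources = [
    "docs/guide/intro.rst",
    "docs/guide/setup.rst",
    "docs/api/reference.md",
]
pattern = "docs/guide/*.rst"

removed = [s for s in sources if fnmatch.fnmatch(s, pattern)]
kept = [s for s in sources if not fnmatch.fnmatch(s, pattern)]
print(removed)  # ['docs/guide/intro.rst', 'docs/guide/setup.rst']
print(kept)     # ['docs/api/reference.md']
```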
@@ -334,6 +402,7 @@
         Retrieve relevant chunks for a query.
 
         Uses vectorized cosine similarity for fast search over all chunks.
+        This method is thread-safe - it reads a consistent snapshot of the index.
 
         Parameters
         ----------
@@ -353,7 +422,10 @@
         >>> for chunk, score in results:
         ...     print(f"{score:.2f}: {chunk.content[:100]}...")
         """
-
+        # Atomic state read - get consistent snapshot
+        state = self._state
+
+        if not state.chunks or state.embedding_matrix is None:
             return []
 
         # Get query embedding and normalize
@@ -365,7 +437,7 @@
         query_normalized = query_vec / query_norm
 
         # Fast cosine similarity: matrix is pre-normalized, just dot product
-        similarities =
+        similarities = state.embedding_matrix @ query_normalized
 
         # Get top_k indices using argpartition (faster than full sort for large arrays)
         if len(similarities) <= top_k:
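Because the stored matrix is unit-normalized, only the query needs normalizing at search time, and scoring is a single matrix-vector product; argpartition then selects the top_k without sorting everything. A compact sketch of that scoring path on synthetic data (not ragit's types):

```python
# Sketch of the scoring path: dot product against a pre-normalized matrix,
# then argpartition for the top_k. Data is synthetic.
import numpy as np

rng = np.random.default_rng(1)
matrix = rng.normal(size=(1000, 64))
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)  # pre-normalized rows

query = rng.normal(size=64)
query /= np.linalg.norm(query)

similarities = matrix @ query                           # cosine similarity in one product
top_k = 3
top = np.argpartition(similarities, -top_k)[-top_k:]    # unordered top_k in O(n)
top = top[np.argsort(similarities[top])[::-1]]          # order just those k by score
print(top, similarities[top])
```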
@@ -376,7 +448,197 @@
         # Sort the top_k by score
         top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
 
-        return [(
+        return [(state.chunks[i], float(similarities[i])) for i in top_indices]
+
+    def retrieve_with_context(
+        self,
+        query: str,
+        top_k: int = 3,
+        window_size: int = 1,
+        min_score: float = 0.0,
+    ) -> list[tuple[Chunk, float]]:
+        """
+        Retrieve chunks with adjacent context expansion (window search).
+
+        For each retrieved chunk, also includes adjacent chunks from the
+        same document to provide more context. This is useful when relevant
+        information spans multiple chunks.
+
+        Pattern inspired by ai4rag window_search.
+
+        Parameters
+        ----------
+        query : str
+            Search query.
+        top_k : int
+            Number of initial chunks to retrieve (default: 3).
+        window_size : int
+            Number of adjacent chunks to include on each side (default: 1).
+            Set to 0 to disable window expansion.
+        min_score : float
+            Minimum similarity score threshold (default: 0.0).
+
+        Returns
+        -------
+        list[tuple[Chunk, float]]
+            List of (chunk, similarity_score) tuples, sorted by relevance.
+            Adjacent chunks have slightly lower scores.
+
+        Examples
+        --------
+        >>> # Get chunks with 1 adjacent chunk on each side
+        >>> results = assistant.retrieve_with_context("query", window_size=1)
+        >>> for chunk, score in results:
+        ...     print(f"{score:.2f}: {chunk.content[:50]}...")
+        """
+        # Get consistent state snapshot
+        state = self._state
+
+        with log_operation("retrieve_with_context", query_len=len(query), top_k=top_k, window_size=window_size) as ctx:
+            # Get initial results (more than top_k to account for filtering)
+            results = self.retrieve(query, top_k * 2)
+
+            # Apply minimum score threshold
+            if min_score > 0:
+                results = [(chunk, score) for chunk, score in results if score >= min_score]
+
+            if window_size == 0 or not results:
+                ctx["expanded_chunks"] = len(results)
+                return results[:top_k]
+
+            # Build chunk index for fast lookup
+            chunk_to_idx = {id(chunk): i for i, chunk in enumerate(state.chunks)}
+
+            expanded_results: list[tuple[Chunk, float]] = []
+            seen_indices: set[int] = set()
+
+            for chunk, score in results[:top_k]:
+                chunk_idx = chunk_to_idx.get(id(chunk))
+                if chunk_idx is None:
+                    expanded_results.append((chunk, score))
+                    continue
+
+                # Get window of adjacent chunks from same document
+                start_idx = max(0, chunk_idx - window_size)
+                end_idx = min(len(state.chunks), chunk_idx + window_size + 1)
+
+                for idx in range(start_idx, end_idx):
+                    if idx in seen_indices:
+                        continue
+
+                    adjacent_chunk = state.chunks[idx]
+                    # Only include adjacent chunks from same document
+                    if adjacent_chunk.doc_id == chunk.doc_id:
+                        seen_indices.add(idx)
+                        # Original chunk keeps full score, adjacent get 80%
+                        adj_score = score if idx == chunk_idx else score * 0.8
+                        expanded_results.append((adjacent_chunk, adj_score))
+
+            # Sort by score (highest first)
+            expanded_results.sort(key=lambda x: (-x[1], state.chunks.index(x[0]) if x[0] in state.chunks else 0))
+            ctx["expanded_chunks"] = len(expanded_results)
+
+            return expanded_results
+
+    def get_context_with_window(
+        self,
+        query: str,
+        top_k: int = 3,
+        window_size: int = 1,
+        min_score: float = 0.0,
+    ) -> str:
+        """
+        Get formatted context with adjacent chunk expansion.
+
+        Merges overlapping text from adjacent chunks intelligently.
+
+        Parameters
+        ----------
+        query : str
+            Search query.
+        top_k : int
+            Number of initial chunks to retrieve.
+        window_size : int
+            Number of adjacent chunks on each side.
+        min_score : float
+            Minimum similarity score threshold.
+
+        Returns
+        -------
+        str
+            Formatted context string with merged chunks.
+        """
+        # Get consistent state snapshot
+        state = self._state
+
+        results = self.retrieve_with_context(query, top_k, window_size, min_score)
+
+        if not results:
+            return ""
+
+        # Group chunks by document to merge properly
+        doc_chunks: dict[str, list[tuple[Chunk, float]]] = {}
+        for chunk, score in results:
+            doc_id = chunk.doc_id or "unknown"
+            if doc_id not in doc_chunks:
+                doc_chunks[doc_id] = []
+            doc_chunks[doc_id].append((chunk, score))
+
+        merged_sections: list[str] = []
+
+        for _doc_id, chunks in doc_chunks.items():
+            # Sort chunks by their position in the original list
+            chunks.sort(key=lambda x: state.chunks.index(x[0]) if x[0] in state.chunks else 0)
+
+            # Merge overlapping text
+            merged_content: list[str] = []
+            for chunk, _ in chunks:
+                if merged_content:
+                    # Check for overlap with previous chunk
+                    prev_content = merged_content[-1]
+                    non_overlapping = self._get_non_overlapping_text(prev_content, chunk.content)
+                    if non_overlapping != chunk.content:
+                        # Found overlap, extend previous chunk
+                        merged_content[-1] = prev_content + non_overlapping
+                    else:
+                        # No overlap, add as new section
+                        merged_content.append(chunk.content)
+                else:
+                    merged_content.append(chunk.content)
+
+            merged_sections.append("\n".join(merged_content))
+
+        return "\n\n---\n\n".join(merged_sections)
+
+    def _get_non_overlapping_text(self, str1: str, str2: str) -> str:
+        """
+        Find non-overlapping portion of str2 when appending after str1.
+
+        Detects overlap where the end of str1 matches the beginning of str2,
+        and returns only the non-overlapping portion of str2.
+
+        Pattern from ai4rag vector_store/utils.py.
+
+        Parameters
+        ----------
+        str1 : str
+            First string (previous content).
+        str2 : str
+            Second string (content to potentially append).
+
+        Returns
+        -------
+        str
+            Non-overlapping portion of str2, or full str2 if no overlap.
+        """
+        # Limit overlap search to avoid O(n^2) for large strings
+        max_overlap = min(len(str1), len(str2), 200)
+
+        for i in range(max_overlap, 0, -1):
+            if str1[-i:] == str2[:i]:
+                return str2[i:]
+
+        return str2
 
     def get_context(self, query: str, top_k: int = 3) -> str:
         """
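_get_non_overlapping_text trims the part of a chunk that repeats the tail of the previous chunk (the chunk_overlap region), so merged context reads continuously. A standalone sketch of that suffix/prefix trim and the merge loop it supports, simplified from the hunk above and capped at 200 characters as in the diff:

```python
# Sketch of the suffix/prefix overlap trim used when merging adjacent chunks.
# Simplified standalone version of the logic shown in the hunk above.
def non_overlapping(prev: str, nxt: str) -> str:
    max_overlap = min(len(prev), len(nxt), 200)  # cap the scan, as in the diff
    for i in range(max_overlap, 0, -1):
        if prev[-i:] == nxt[:i]:  # longest suffix of prev that prefixes nxt
            return nxt[i:]
    return nxt


chunks = [
    "The index is rebuilt whenever documents change. Rebuilds are",
    "Rebuilds are batched so the provider is called once per update.",
]
merged = chunks[0]
for chunk in chunks[1:]:
    merged += non_overlapping(merged, chunk)
print(merged)
# The index is rebuilt whenever documents change. Rebuilds are batched so the provider is called once per update.
```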
@@ -564,7 +826,17 @@ Generate the {language} code:"""
     @property
     def num_chunks(self) -> int:
         """Return number of indexed chunks."""
-        return len(self.
+        return len(self._state.chunks)
+
+    @property
+    def chunk_count(self) -> int:
+        """Number of chunks in index (alias for num_chunks)."""
+        return len(self._state.chunks)
+
+    @property
+    def is_indexed(self) -> bool:
+        """Check if index has any documents."""
+        return len(self._state.chunks) > 0
 
     @property
     def num_documents(self) -> int:
@@ -575,3 +847,122 @@ Generate the {language} code:"""
     def has_llm(self) -> bool:
         """Check if LLM is configured."""
         return self._llm_provider is not None
+
+    def save_index(self, path: str | Path) -> None:
+        """Save index to disk for later restoration.
+
+        Saves the index in an efficient format:
+        - chunks.json: Chunk metadata and content
+        - embeddings.npy: Numpy array of embeddings (binary format)
+        - metadata.json: Index configuration
+
+        Args:
+            path: Directory path to save index files.
+
+        Example:
+            >>> assistant.save_index("/path/to/index")
+            >>> # Later...
+            >>> loaded = RAGAssistant.load_index("/path/to/index", provider=provider)
+        """
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        state = self._state
+
+        # Save chunks as JSON
+        chunks_data = [
+            {
+                "content": chunk.content,
+                "doc_id": chunk.doc_id,
+                "chunk_index": chunk.chunk_index,
+                "metadata": chunk.metadata,
+            }
+            for chunk in state.chunks
+        ]
+        (path / "chunks.json").write_text(json.dumps(chunks_data, indent=2))
+
+        # Save embeddings as numpy binary (efficient for large arrays)
+        if state.embedding_matrix is not None:
+            np.save(path / "embeddings.npy", state.embedding_matrix)
+
+        # Save metadata for validation and configuration restoration
+        metadata = {
+            "chunk_count": len(state.chunks),
+            "embedding_model": self.embedding_model,
+            "chunk_size": self.chunk_size,
+            "chunk_overlap": self.chunk_overlap,
+            "version": "1.0",
+        }
+        (path / "metadata.json").write_text(json.dumps(metadata, indent=2))
+
+        logger.info(f"Index saved to {path} ({len(state.chunks)} chunks)")
+
+    @classmethod
+    def load_index(
+        cls,
+        path: str | Path,
+        provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
+    ) -> "RAGAssistant":
+        """Load a previously saved index.
+
+        Args:
+            path: Directory path containing saved index files.
+            provider: Provider for embeddings/LLM (required for new queries).
+
+        Returns:
+            RAGAssistant instance with loaded index.
+
+        Raises:
+            IndexingError: If loaded index is corrupted (count mismatch).
+            FileNotFoundError: If index files don't exist.
+
+        Example:
+            >>> loaded = RAGAssistant.load_index("/path/to/index", provider=OllamaProvider())
+            >>> results = loaded.retrieve("query")
+        """
+        path = Path(path)
+
+        # Load metadata
+        metadata = json.loads((path / "metadata.json").read_text())
+
+        # Load chunks
+        chunks_data = json.loads((path / "chunks.json").read_text())
+        chunks = tuple(
+            Chunk(
+                content=c["content"],
+                doc_id=c.get("doc_id", ""),
+                chunk_index=c.get("chunk_index", 0),
+                metadata=c.get("metadata", {}),
+            )
+            for c in chunks_data
+        )
+
+        # Load embeddings
+        embeddings_path = path / "embeddings.npy"
+        embedding_matrix: NDArray[np.float64] | None = None
+        if embeddings_path.exists():
+            embedding_matrix = np.load(embeddings_path)
+
+        # Validate consistency
+        if embedding_matrix is not None and embedding_matrix.shape[0] != len(chunks):
+            raise IndexingError(
+                f"Loaded index corrupted: {embedding_matrix.shape[0]} embeddings but {len(chunks)} chunks"
+            )
+
+        # Create instance without calling __init__ (skip indexing)
+        instance = object.__new__(cls)
+
+        # Initialize required attributes
+        instance._state = IndexState(chunks=chunks, embedding_matrix=embedding_matrix)
+        instance.embedding_model = metadata.get("embedding_model", "default")
+        instance.llm_model = metadata.get("llm_model", "default")
+        instance.chunk_size = metadata.get("chunk_size", 512)
+        instance.chunk_overlap = metadata.get("chunk_overlap", 50)
+        instance.documents = []  # Original docs not saved
+
+        # Set up providers
+        instance._embedding_provider = provider if isinstance(provider, BaseEmbeddingProvider) else None  # type: ignore
+        instance._llm_provider = provider if isinstance(provider, BaseLLMProvider) else None
+
+        logger.info(f"Index loaded from {path} ({len(chunks)} chunks)")
+        return instance
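Taken together, the persistence additions let an index be built once and restored later without re-embedding the corpus. A usage sketch assembled from the docstring examples in this hunk (OllamaProvider and the paths are placeholders; a running Ollama instance is assumed):

```python
# Usage sketch based on the docstring examples above.
# OllamaProvider and the paths are placeholders; adjust to your setup.
from ragit.assistant import RAGAssistant
from ragit.providers import OllamaProvider

provider = OllamaProvider()
assistant = RAGAssistant("docs/", provider=provider)

# Persists chunks.json, embeddings.npy and metadata.json to the directory.
assistant.save_index("./rag_index")

# Later (or in another process): restore without re-embedding the corpus.
loaded = RAGAssistant.load_index("./rag_index", provider=provider)
for chunk, score in loaded.retrieve("How do I configure the provider?", top_k=3):
    print(f"{score:.2f}  {chunk.content[:60]}")
```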