ragit 0.10.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/assistant.py +250 -39
- ragit/config.py +3 -4
- ragit/exceptions.py +2 -2
- ragit/loaders.py +1 -1
- ragit/providers/ollama.py +47 -13
- ragit/utils/__init__.py +0 -22
- ragit/version.py +1 -1
- {ragit-0.10.1.dist-info → ragit-0.11.1.dist-info}/METADATA +41 -5
- {ragit-0.10.1.dist-info → ragit-0.11.1.dist-info}/RECORD +12 -12
- {ragit-0.10.1.dist-info → ragit-0.11.1.dist-info}/WHEEL +0 -0
- {ragit-0.10.1.dist-info → ragit-0.11.1.dist-info}/licenses/LICENSE +0 -0
- {ragit-0.10.1.dist-info → ragit-0.11.1.dist-info}/top_level.txt +0 -0
ragit/assistant.py
CHANGED
|
@@ -7,10 +7,16 @@ High-level RAG Assistant for document Q&A and code generation.
|
|
|
7
7
|
|
|
8
8
|
Provides a simple interface for RAG-based tasks.
|
|
9
9
|
|
|
10
|
-
|
|
10
|
+
Thread Safety:
|
|
11
|
+
This class uses lock-free atomic operations for thread safety.
|
|
12
|
+
The IndexState is immutable, and all mutations create a new state
|
|
13
|
+
that is atomically swapped. Python's GIL ensures reference assignment
|
|
14
|
+
is atomic, making concurrent reads and writes safe.
|
|
11
15
|
"""
|
|
12
16
|
|
|
17
|
+
import json
|
|
13
18
|
from collections.abc import Callable
|
|
19
|
+
from dataclasses import dataclass
|
|
14
20
|
from pathlib import Path
|
|
15
21
|
from typing import TYPE_CHECKING
|
|
16
22
|
|
|
@@ -18,8 +24,9 @@ import numpy as np
|
|
|
18
24
|
from numpy.typing import NDArray
|
|
19
25
|
|
|
20
26
|
from ragit.core.experiment.experiment import Chunk, Document
|
|
27
|
+
from ragit.exceptions import IndexingError
|
|
21
28
|
from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
|
|
22
|
-
from ragit.logging import log_operation
|
|
29
|
+
from ragit.logging import log_operation, logger
|
|
23
30
|
from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
|
|
24
31
|
from ragit.providers.function_adapter import FunctionProvider
|
|
25
32
|
|
|
@@ -27,6 +34,23 @@ if TYPE_CHECKING:
|
|
|
27
34
|
from numpy.typing import NDArray
|
|
28
35
|
|
|
29
36
|
|
|
37
|
+
@dataclass(frozen=True)
|
|
38
|
+
class IndexState:
|
|
39
|
+
"""Immutable snapshot of index state for lock-free thread safety.
|
|
40
|
+
|
|
41
|
+
This class holds all mutable index data in a single immutable structure.
|
|
42
|
+
Updates create a new IndexState instance, and the reference swap is
|
|
43
|
+
atomic under Python's GIL, ensuring thread-safe reads and writes.
|
|
44
|
+
|
|
45
|
+
Attributes:
|
|
46
|
+
chunks: Tuple of indexed chunks (immutable).
|
|
47
|
+
embedding_matrix: Pre-normalized numpy array of embeddings, or None if empty.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
chunks: tuple[Chunk, ...]
|
|
51
|
+
embedding_matrix: NDArray[np.float64] | None
|
|
52
|
+
|
|
53
|
+
|
|
30
54
|
class RAGAssistant:
|
|
31
55
|
"""
|
|
32
56
|
High-level RAG assistant for document Q&A and generation.
|
|
@@ -63,9 +87,12 @@ class RAGAssistant:
|
|
|
63
87
|
ValueError
|
|
64
88
|
If neither embed_fn nor provider is provided.
|
|
65
89
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
This class
|
|
90
|
+
Thread Safety
|
|
91
|
+
-------------
|
|
92
|
+
This class uses lock-free atomic operations for thread safety.
|
|
93
|
+
Multiple threads can safely call retrieve() while another thread
|
|
94
|
+
calls add_documents(). The IndexState is immutable, and reference
|
|
95
|
+
swaps are atomic under Python's GIL.
|
|
69
96
|
|
|
70
97
|
Examples
|
|
71
98
|
--------
|
|
@@ -80,6 +107,10 @@ class RAGAssistant:
|
|
|
80
107
|
>>> # With Ollama provider (supports nomic-embed-text)
|
|
81
108
|
>>> from ragit.providers import OllamaProvider
|
|
82
109
|
>>> assistant = RAGAssistant(docs, provider=OllamaProvider())
|
|
110
|
+
>>>
|
|
111
|
+
>>> # Save and load index for persistence
|
|
112
|
+
>>> assistant.save_index("/path/to/index")
|
|
113
|
+
>>> loaded = RAGAssistant.load_index("/path/to/index", provider=OllamaProvider())
|
|
83
114
|
"""
|
|
84
115
|
|
|
85
116
|
def __init__(
|
|
@@ -134,9 +165,8 @@ class RAGAssistant:
|
|
|
134
165
|
# Load documents if path provided
|
|
135
166
|
self.documents = self._load_documents(documents)
|
|
136
167
|
|
|
137
|
-
#
|
|
138
|
-
self.
|
|
139
|
-
self._embedding_matrix: NDArray[np.float64] | None = None # Pre-normalized
|
|
168
|
+
# Thread-safe index state (immutable, atomic reference swap)
|
|
169
|
+
self._state: IndexState = IndexState(chunks=(), embedding_matrix=None)
|
|
140
170
|
self._build_index()
|
|
141
171
|
|
|
142
172
|
def _load_documents(self, documents: list[Document] | str | Path) -> list[Document]:
|
|
@@ -171,7 +201,11 @@ class RAGAssistant:
|
|
|
171
201
|
raise ValueError(f"Invalid documents source: {documents}")
|
|
172
202
|
|
|
173
203
|
def _build_index(self) -> None:
|
|
174
|
-
"""Build vector index from documents using batch embedding.
|
|
204
|
+
"""Build vector index from documents using batch embedding.
|
|
205
|
+
|
|
206
|
+
Raises:
|
|
207
|
+
IndexingError: If embedding count doesn't match chunk count.
|
|
208
|
+
"""
|
|
175
209
|
all_chunks: list[Chunk] = []
|
|
176
210
|
|
|
177
211
|
for doc in self.documents:
|
|
@@ -183,33 +217,54 @@ class RAGAssistant:
|
|
|
183
217
|
all_chunks.extend(chunks)
|
|
184
218
|
|
|
185
219
|
if not all_chunks:
|
|
186
|
-
|
|
187
|
-
self.
|
|
220
|
+
logger.warning("No chunks produced from documents - index will be empty")
|
|
221
|
+
self._state = IndexState(chunks=(), embedding_matrix=None)
|
|
188
222
|
return
|
|
189
223
|
|
|
190
224
|
# Batch embed all chunks at once (single API call)
|
|
191
225
|
texts = [chunk.content for chunk in all_chunks]
|
|
192
226
|
responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
|
|
193
227
|
|
|
228
|
+
# CRITICAL: Validate embedding count matches chunk count
|
|
229
|
+
if len(responses) != len(all_chunks):
|
|
230
|
+
raise IndexingError(
|
|
231
|
+
f"Embedding count mismatch: expected {len(all_chunks)} embeddings, "
|
|
232
|
+
f"got {len(responses)}. Index may be corrupted."
|
|
233
|
+
)
|
|
234
|
+
|
|
194
235
|
# Build embedding matrix directly (skip storing in chunks to avoid duplication)
|
|
195
236
|
embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
|
|
196
237
|
|
|
238
|
+
# Additional validation: matrix shape
|
|
239
|
+
if embedding_matrix.shape[0] != len(all_chunks):
|
|
240
|
+
raise IndexingError(
|
|
241
|
+
f"Matrix row count {embedding_matrix.shape[0]} doesn't match chunk count {len(all_chunks)}"
|
|
242
|
+
)
|
|
243
|
+
|
|
197
244
|
# Pre-normalize for fast cosine similarity (normalize once, use many times)
|
|
198
245
|
norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
|
|
199
246
|
norms[norms == 0] = 1 # Avoid division by zero
|
|
200
247
|
|
|
201
|
-
#
|
|
202
|
-
self.
|
|
203
|
-
|
|
248
|
+
# Atomic state update (thread-safe under GIL)
|
|
249
|
+
self._state = IndexState(
|
|
250
|
+
chunks=tuple(all_chunks),
|
|
251
|
+
embedding_matrix=embedding_matrix / norms,
|
|
252
|
+
)
|
|
204
253
|
|
|
205
254
|
def add_documents(self, documents: list[Document] | str | Path) -> int:
|
|
206
255
|
"""Add documents to the existing index incrementally.
|
|
207
256
|
|
|
257
|
+
This method is thread-safe. It creates a new IndexState and atomically
|
|
258
|
+
swaps the reference, ensuring readers always see a consistent state.
|
|
259
|
+
|
|
208
260
|
Args:
|
|
209
261
|
documents: Documents to add.
|
|
210
262
|
|
|
211
263
|
Returns:
|
|
212
264
|
Number of chunks added.
|
|
265
|
+
|
|
266
|
+
Raises:
|
|
267
|
+
IndexingError: If embedding count doesn't match chunk count.
|
|
213
268
|
"""
|
|
214
269
|
new_docs = self._load_documents(documents)
|
|
215
270
|
if not new_docs:
|
|
@@ -233,6 +288,13 @@ class RAGAssistant:
|
|
|
233
288
|
texts = [chunk.content for chunk in new_chunks]
|
|
234
289
|
responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
|
|
235
290
|
|
|
291
|
+
# Validate embedding count
|
|
292
|
+
if len(responses) != len(new_chunks):
|
|
293
|
+
raise IndexingError(
|
|
294
|
+
f"Embedding count mismatch: expected {len(new_chunks)} embeddings, "
|
|
295
|
+
f"got {len(responses)}. Index update aborted."
|
|
296
|
+
)
|
|
297
|
+
|
|
236
298
|
new_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
|
|
237
299
|
|
|
238
300
|
# Normalize
|
|
@@ -240,21 +302,27 @@ class RAGAssistant:
|
|
|
240
302
|
norms[norms == 0] = 1
|
|
241
303
|
new_matrix_norm = new_matrix / norms
|
|
242
304
|
|
|
243
|
-
#
|
|
244
|
-
|
|
245
|
-
current_chunks.extend(new_chunks)
|
|
246
|
-
self._chunks = tuple(current_chunks)
|
|
305
|
+
# Read current state (atomic read)
|
|
306
|
+
current_state = self._state
|
|
247
307
|
|
|
248
|
-
|
|
249
|
-
|
|
308
|
+
# Build new state
|
|
309
|
+
combined_chunks = current_state.chunks + tuple(new_chunks)
|
|
310
|
+
if current_state.embedding_matrix is None:
|
|
311
|
+
combined_matrix = new_matrix_norm
|
|
250
312
|
else:
|
|
251
|
-
|
|
313
|
+
combined_matrix = np.vstack((current_state.embedding_matrix, new_matrix_norm))
|
|
314
|
+
|
|
315
|
+
# Atomic state swap (thread-safe under GIL)
|
|
316
|
+
self._state = IndexState(chunks=combined_chunks, embedding_matrix=combined_matrix)
|
|
252
317
|
|
|
253
318
|
return len(new_chunks)
|
|
254
319
|
|
|
255
320
|
def remove_documents(self, source_path_pattern: str) -> int:
|
|
256
321
|
"""Remove documents matching a source path pattern.
|
|
257
322
|
|
|
323
|
+
This method is thread-safe. It creates a new IndexState and atomically
|
|
324
|
+
swaps the reference.
|
|
325
|
+
|
|
258
326
|
Args:
|
|
259
327
|
source_path_pattern: Glob pattern to match 'source' metadata.
|
|
260
328
|
|
|
@@ -263,14 +331,17 @@ class RAGAssistant:
|
|
|
263
331
|
"""
|
|
264
332
|
import fnmatch
|
|
265
333
|
|
|
266
|
-
|
|
334
|
+
# Read current state (atomic read)
|
|
335
|
+
current_state = self._state
|
|
336
|
+
|
|
337
|
+
if not current_state.chunks:
|
|
267
338
|
return 0
|
|
268
339
|
|
|
269
340
|
indices_to_keep = []
|
|
270
341
|
kept_chunks = []
|
|
271
342
|
removed_count = 0
|
|
272
343
|
|
|
273
|
-
for i, chunk in enumerate(
|
|
344
|
+
for i, chunk in enumerate(current_state.chunks):
|
|
274
345
|
source = chunk.metadata.get("source", "")
|
|
275
346
|
if not source or not fnmatch.fnmatch(source, source_path_pattern):
|
|
276
347
|
indices_to_keep.append(i)
|
|
@@ -281,13 +352,14 @@ class RAGAssistant:
|
|
|
281
352
|
if removed_count == 0:
|
|
282
353
|
return 0
|
|
283
354
|
|
|
284
|
-
|
|
355
|
+
# Build new embedding matrix
|
|
356
|
+
if current_state.embedding_matrix is not None:
|
|
357
|
+
new_matrix = None if not kept_chunks else current_state.embedding_matrix[indices_to_keep]
|
|
358
|
+
else:
|
|
359
|
+
new_matrix = None
|
|
285
360
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
self._embedding_matrix = None
|
|
289
|
-
else:
|
|
290
|
-
self._embedding_matrix = self._embedding_matrix[indices_to_keep]
|
|
361
|
+
# Atomic state swap (thread-safe under GIL)
|
|
362
|
+
self._state = IndexState(chunks=tuple(kept_chunks), embedding_matrix=new_matrix)
|
|
291
363
|
|
|
292
364
|
# Also remove from self.documents
|
|
293
365
|
self.documents = [
|
|
@@ -330,6 +402,7 @@ class RAGAssistant:
|
|
|
330
402
|
Retrieve relevant chunks for a query.
|
|
331
403
|
|
|
332
404
|
Uses vectorized cosine similarity for fast search over all chunks.
|
|
405
|
+
This method is thread-safe - it reads a consistent snapshot of the index.
|
|
333
406
|
|
|
334
407
|
Parameters
|
|
335
408
|
----------
|
|
@@ -349,7 +422,10 @@ class RAGAssistant:
|
|
|
349
422
|
>>> for chunk, score in results:
|
|
350
423
|
... print(f"{score:.2f}: {chunk.content[:100]}...")
|
|
351
424
|
"""
|
|
352
|
-
|
|
425
|
+
# Atomic state read - get consistent snapshot
|
|
426
|
+
state = self._state
|
|
427
|
+
|
|
428
|
+
if not state.chunks or state.embedding_matrix is None:
|
|
353
429
|
return []
|
|
354
430
|
|
|
355
431
|
# Get query embedding and normalize
|
|
@@ -361,7 +437,7 @@ class RAGAssistant:
|
|
|
361
437
|
query_normalized = query_vec / query_norm
|
|
362
438
|
|
|
363
439
|
# Fast cosine similarity: matrix is pre-normalized, just dot product
|
|
364
|
-
similarities =
|
|
440
|
+
similarities = state.embedding_matrix @ query_normalized
|
|
365
441
|
|
|
366
442
|
# Get top_k indices using argpartition (faster than full sort for large arrays)
|
|
367
443
|
if len(similarities) <= top_k:
|
|
@@ -372,7 +448,7 @@ class RAGAssistant:
|
|
|
372
448
|
# Sort the top_k by score
|
|
373
449
|
top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
|
|
374
450
|
|
|
375
|
-
return [(
|
|
451
|
+
return [(state.chunks[i], float(similarities[i])) for i in top_indices]
|
|
376
452
|
|
|
377
453
|
def retrieve_with_context(
|
|
378
454
|
self,
|
|
@@ -415,6 +491,9 @@ class RAGAssistant:
|
|
|
415
491
|
>>> for chunk, score in results:
|
|
416
492
|
... print(f"{score:.2f}: {chunk.content[:50]}...")
|
|
417
493
|
"""
|
|
494
|
+
# Get consistent state snapshot
|
|
495
|
+
state = self._state
|
|
496
|
+
|
|
418
497
|
with log_operation("retrieve_with_context", query_len=len(query), top_k=top_k, window_size=window_size) as ctx:
|
|
419
498
|
# Get initial results (more than top_k to account for filtering)
|
|
420
499
|
results = self.retrieve(query, top_k * 2)
|
|
@@ -428,7 +507,7 @@ class RAGAssistant:
|
|
|
428
507
|
return results[:top_k]
|
|
429
508
|
|
|
430
509
|
# Build chunk index for fast lookup
|
|
431
|
-
chunk_to_idx = {id(chunk): i for i, chunk in enumerate(
|
|
510
|
+
chunk_to_idx = {id(chunk): i for i, chunk in enumerate(state.chunks)}
|
|
432
511
|
|
|
433
512
|
expanded_results: list[tuple[Chunk, float]] = []
|
|
434
513
|
seen_indices: set[int] = set()
|
|
@@ -441,13 +520,13 @@ class RAGAssistant:
|
|
|
441
520
|
|
|
442
521
|
# Get window of adjacent chunks from same document
|
|
443
522
|
start_idx = max(0, chunk_idx - window_size)
|
|
444
|
-
end_idx = min(len(
|
|
523
|
+
end_idx = min(len(state.chunks), chunk_idx + window_size + 1)
|
|
445
524
|
|
|
446
525
|
for idx in range(start_idx, end_idx):
|
|
447
526
|
if idx in seen_indices:
|
|
448
527
|
continue
|
|
449
528
|
|
|
450
|
-
adjacent_chunk =
|
|
529
|
+
adjacent_chunk = state.chunks[idx]
|
|
451
530
|
# Only include adjacent chunks from same document
|
|
452
531
|
if adjacent_chunk.doc_id == chunk.doc_id:
|
|
453
532
|
seen_indices.add(idx)
|
|
@@ -456,7 +535,7 @@ class RAGAssistant:
|
|
|
456
535
|
expanded_results.append((adjacent_chunk, adj_score))
|
|
457
536
|
|
|
458
537
|
# Sort by score (highest first)
|
|
459
|
-
expanded_results.sort(key=lambda x: (-x[1],
|
|
538
|
+
expanded_results.sort(key=lambda x: (-x[1], state.chunks.index(x[0]) if x[0] in state.chunks else 0))
|
|
460
539
|
ctx["expanded_chunks"] = len(expanded_results)
|
|
461
540
|
|
|
462
541
|
return expanded_results
|
|
@@ -489,6 +568,9 @@ class RAGAssistant:
|
|
|
489
568
|
str
|
|
490
569
|
Formatted context string with merged chunks.
|
|
491
570
|
"""
|
|
571
|
+
# Get consistent state snapshot
|
|
572
|
+
state = self._state
|
|
573
|
+
|
|
492
574
|
results = self.retrieve_with_context(query, top_k, window_size, min_score)
|
|
493
575
|
|
|
494
576
|
if not results:
|
|
@@ -506,10 +588,10 @@ class RAGAssistant:
|
|
|
506
588
|
|
|
507
589
|
for _doc_id, chunks in doc_chunks.items():
|
|
508
590
|
# Sort chunks by their position in the original list
|
|
509
|
-
chunks.sort(key=lambda x:
|
|
591
|
+
chunks.sort(key=lambda x: state.chunks.index(x[0]) if x[0] in state.chunks else 0)
|
|
510
592
|
|
|
511
593
|
# Merge overlapping text
|
|
512
|
-
merged_content = []
|
|
594
|
+
merged_content: list[str] = []
|
|
513
595
|
for chunk, _ in chunks:
|
|
514
596
|
if merged_content:
|
|
515
597
|
# Check for overlap with previous chunk
|
|
@@ -744,7 +826,17 @@ Generate the {language} code:"""
|
|
|
744
826
|
@property
|
|
745
827
|
def num_chunks(self) -> int:
|
|
746
828
|
"""Return number of indexed chunks."""
|
|
747
|
-
return len(self.
|
|
829
|
+
return len(self._state.chunks)
|
|
830
|
+
|
|
831
|
+
@property
|
|
832
|
+
def chunk_count(self) -> int:
|
|
833
|
+
"""Number of chunks in index (alias for num_chunks)."""
|
|
834
|
+
return len(self._state.chunks)
|
|
835
|
+
|
|
836
|
+
@property
|
|
837
|
+
def is_indexed(self) -> bool:
|
|
838
|
+
"""Check if index has any documents."""
|
|
839
|
+
return len(self._state.chunks) > 0
|
|
748
840
|
|
|
749
841
|
@property
|
|
750
842
|
def num_documents(self) -> int:
|
|
@@ -755,3 +847,122 @@ Generate the {language} code:"""
|
|
|
755
847
|
def has_llm(self) -> bool:
|
|
756
848
|
"""Check if LLM is configured."""
|
|
757
849
|
return self._llm_provider is not None
|
|
850
|
+
|
|
851
|
+
def save_index(self, path: str | Path) -> None:
|
|
852
|
+
"""Save index to disk for later restoration.
|
|
853
|
+
|
|
854
|
+
Saves the index in an efficient format:
|
|
855
|
+
- chunks.json: Chunk metadata and content
|
|
856
|
+
- embeddings.npy: Numpy array of embeddings (binary format)
|
|
857
|
+
- metadata.json: Index configuration
|
|
858
|
+
|
|
859
|
+
Args:
|
|
860
|
+
path: Directory path to save index files.
|
|
861
|
+
|
|
862
|
+
Example:
|
|
863
|
+
>>> assistant.save_index("/path/to/index")
|
|
864
|
+
>>> # Later...
|
|
865
|
+
>>> loaded = RAGAssistant.load_index("/path/to/index", provider=provider)
|
|
866
|
+
"""
|
|
867
|
+
path = Path(path)
|
|
868
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
869
|
+
|
|
870
|
+
state = self._state
|
|
871
|
+
|
|
872
|
+
# Save chunks as JSON
|
|
873
|
+
chunks_data = [
|
|
874
|
+
{
|
|
875
|
+
"content": chunk.content,
|
|
876
|
+
"doc_id": chunk.doc_id,
|
|
877
|
+
"chunk_index": chunk.chunk_index,
|
|
878
|
+
"metadata": chunk.metadata,
|
|
879
|
+
}
|
|
880
|
+
for chunk in state.chunks
|
|
881
|
+
]
|
|
882
|
+
(path / "chunks.json").write_text(json.dumps(chunks_data, indent=2))
|
|
883
|
+
|
|
884
|
+
# Save embeddings as numpy binary (efficient for large arrays)
|
|
885
|
+
if state.embedding_matrix is not None:
|
|
886
|
+
np.save(path / "embeddings.npy", state.embedding_matrix)
|
|
887
|
+
|
|
888
|
+
# Save metadata for validation and configuration restoration
|
|
889
|
+
metadata = {
|
|
890
|
+
"chunk_count": len(state.chunks),
|
|
891
|
+
"embedding_model": self.embedding_model,
|
|
892
|
+
"chunk_size": self.chunk_size,
|
|
893
|
+
"chunk_overlap": self.chunk_overlap,
|
|
894
|
+
"version": "1.0",
|
|
895
|
+
}
|
|
896
|
+
(path / "metadata.json").write_text(json.dumps(metadata, indent=2))
|
|
897
|
+
|
|
898
|
+
logger.info(f"Index saved to {path} ({len(state.chunks)} chunks)")
|
|
899
|
+
|
|
900
|
+
@classmethod
|
|
901
|
+
def load_index(
|
|
902
|
+
cls,
|
|
903
|
+
path: str | Path,
|
|
904
|
+
provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
|
|
905
|
+
) -> "RAGAssistant":
|
|
906
|
+
"""Load a previously saved index.
|
|
907
|
+
|
|
908
|
+
Args:
|
|
909
|
+
path: Directory path containing saved index files.
|
|
910
|
+
provider: Provider for embeddings/LLM (required for new queries).
|
|
911
|
+
|
|
912
|
+
Returns:
|
|
913
|
+
RAGAssistant instance with loaded index.
|
|
914
|
+
|
|
915
|
+
Raises:
|
|
916
|
+
IndexingError: If loaded index is corrupted (count mismatch).
|
|
917
|
+
FileNotFoundError: If index files don't exist.
|
|
918
|
+
|
|
919
|
+
Example:
|
|
920
|
+
>>> loaded = RAGAssistant.load_index("/path/to/index", provider=OllamaProvider())
|
|
921
|
+
>>> results = loaded.retrieve("query")
|
|
922
|
+
"""
|
|
923
|
+
path = Path(path)
|
|
924
|
+
|
|
925
|
+
# Load metadata
|
|
926
|
+
metadata = json.loads((path / "metadata.json").read_text())
|
|
927
|
+
|
|
928
|
+
# Load chunks
|
|
929
|
+
chunks_data = json.loads((path / "chunks.json").read_text())
|
|
930
|
+
chunks = tuple(
|
|
931
|
+
Chunk(
|
|
932
|
+
content=c["content"],
|
|
933
|
+
doc_id=c.get("doc_id", ""),
|
|
934
|
+
chunk_index=c.get("chunk_index", 0),
|
|
935
|
+
metadata=c.get("metadata", {}),
|
|
936
|
+
)
|
|
937
|
+
for c in chunks_data
|
|
938
|
+
)
|
|
939
|
+
|
|
940
|
+
# Load embeddings
|
|
941
|
+
embeddings_path = path / "embeddings.npy"
|
|
942
|
+
embedding_matrix: NDArray[np.float64] | None = None
|
|
943
|
+
if embeddings_path.exists():
|
|
944
|
+
embedding_matrix = np.load(embeddings_path)
|
|
945
|
+
|
|
946
|
+
# Validate consistency
|
|
947
|
+
if embedding_matrix is not None and embedding_matrix.shape[0] != len(chunks):
|
|
948
|
+
raise IndexingError(
|
|
949
|
+
f"Loaded index corrupted: {embedding_matrix.shape[0]} embeddings but {len(chunks)} chunks"
|
|
950
|
+
)
|
|
951
|
+
|
|
952
|
+
# Create instance without calling __init__ (skip indexing)
|
|
953
|
+
instance = object.__new__(cls)
|
|
954
|
+
|
|
955
|
+
# Initialize required attributes
|
|
956
|
+
instance._state = IndexState(chunks=chunks, embedding_matrix=embedding_matrix)
|
|
957
|
+
instance.embedding_model = metadata.get("embedding_model", "default")
|
|
958
|
+
instance.llm_model = metadata.get("llm_model", "default")
|
|
959
|
+
instance.chunk_size = metadata.get("chunk_size", 512)
|
|
960
|
+
instance.chunk_overlap = metadata.get("chunk_overlap", 50)
|
|
961
|
+
instance.documents = [] # Original docs not saved
|
|
962
|
+
|
|
963
|
+
# Set up providers
|
|
964
|
+
instance._embedding_provider = provider if isinstance(provider, BaseEmbeddingProvider) else None # type: ignore
|
|
965
|
+
instance._llm_provider = provider if isinstance(provider, BaseLLMProvider) else None
|
|
966
|
+
|
|
967
|
+
logger.info(f"Index loaded from {path} ({len(chunks)} chunks)")
|
|
968
|
+
return instance
|
ragit/config.py
CHANGED
|
@@ -153,16 +153,15 @@ def _safe_get_env(key: str, default: str | None = None) -> str | None:
|
|
|
153
153
|
return value
|
|
154
154
|
|
|
155
155
|
|
|
156
|
-
def _safe_get_int_env(key: str, default: int) -> int
|
|
157
|
-
"""Get environment variable as int,
|
|
156
|
+
def _safe_get_int_env(key: str, default: int) -> int:
|
|
157
|
+
"""Get environment variable as int, raising on invalid values."""
|
|
158
158
|
value = os.getenv(key)
|
|
159
159
|
if value is None:
|
|
160
160
|
return default
|
|
161
161
|
try:
|
|
162
162
|
return int(value)
|
|
163
163
|
except ValueError:
|
|
164
|
-
|
|
165
|
-
return value
|
|
164
|
+
raise ConfigValidationError(f"Invalid integer value for {key}: {value!r}") from None
|
|
166
165
|
|
|
167
166
|
|
|
168
167
|
def load_config() -> RagitConfig:
|
ragit/exceptions.py
CHANGED
|
@@ -24,7 +24,7 @@ class RagitError(Exception):
|
|
|
24
24
|
----------
|
|
25
25
|
message : str
|
|
26
26
|
Human-readable error message.
|
|
27
|
-
original_exception :
|
|
27
|
+
original_exception : BaseException, optional
|
|
28
28
|
The underlying exception that caused this error.
|
|
29
29
|
|
|
30
30
|
Examples
|
|
@@ -37,7 +37,7 @@ class RagitError(Exception):
|
|
|
37
37
|
... print(f"Caused by: {e.original_exception}")
|
|
38
38
|
"""
|
|
39
39
|
|
|
40
|
-
def __init__(self, message: str, original_exception:
|
|
40
|
+
def __init__(self, message: str, original_exception: BaseException | None = None):
|
|
41
41
|
self.message = message
|
|
42
42
|
self.original_exception = original_exception
|
|
43
43
|
super().__init__(self._format_message())
|
ragit/loaders.py
CHANGED
ragit/providers/ollama.py
CHANGED
|
@@ -10,7 +10,7 @@ Configuration is loaded from environment variables.
|
|
|
10
10
|
|
|
11
11
|
Performance optimizations:
|
|
12
12
|
- Connection pooling via requests.Session()
|
|
13
|
-
- Async parallel embedding via
|
|
13
|
+
- Async parallel embedding via httpx
|
|
14
14
|
- LRU cache for repeated embedding queries
|
|
15
15
|
|
|
16
16
|
Resilience features (via resilient-circuit):
|
|
@@ -216,22 +216,42 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
216
216
|
|
|
217
217
|
@property
|
|
218
218
|
def session(self) -> requests.Session:
|
|
219
|
-
"""Lazy-initialized session for connection pooling.
|
|
219
|
+
"""Lazy-initialized session for connection pooling.
|
|
220
|
+
|
|
221
|
+
Note: API key is NOT stored in session headers to prevent
|
|
222
|
+
potential exposure in logs or error messages. Authentication
|
|
223
|
+
is handled per-request via _get_headers().
|
|
224
|
+
"""
|
|
220
225
|
if self._session is None:
|
|
221
226
|
self._session = requests.Session()
|
|
222
227
|
self._session.headers.update({"Content-Type": "application/json"})
|
|
223
|
-
|
|
224
|
-
|
|
228
|
+
# Security: API key is injected per-request via _get_headers()
|
|
229
|
+
# rather than stored in session headers to prevent log exposure
|
|
225
230
|
return self._session
|
|
226
231
|
|
|
227
232
|
def close(self) -> None:
|
|
228
233
|
"""Close the session and release resources."""
|
|
229
|
-
|
|
230
|
-
|
|
234
|
+
session = getattr(self, "_session", None)
|
|
235
|
+
if session is not None:
|
|
236
|
+
session.close()
|
|
231
237
|
self._session = None
|
|
232
238
|
|
|
239
|
+
def __enter__(self) -> "OllamaProvider":
|
|
240
|
+
"""Context manager entry - returns self for use in 'with' statements.
|
|
241
|
+
|
|
242
|
+
Example:
|
|
243
|
+
with OllamaProvider() as provider:
|
|
244
|
+
result = provider.generate("Hello", model="llama3")
|
|
245
|
+
# Session automatically closed here
|
|
246
|
+
"""
|
|
247
|
+
return self
|
|
248
|
+
|
|
249
|
+
def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: object) -> None:
|
|
250
|
+
"""Context manager exit - ensures cleanup regardless of exceptions."""
|
|
251
|
+
self.close()
|
|
252
|
+
|
|
233
253
|
def __del__(self) -> None:
|
|
234
|
-
"""Cleanup on garbage collection."""
|
|
254
|
+
"""Cleanup on garbage collection (fallback, prefer context manager)."""
|
|
235
255
|
self.close()
|
|
236
256
|
|
|
237
257
|
def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
|
|
@@ -254,6 +274,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
254
274
|
try:
|
|
255
275
|
response = self.session.get(
|
|
256
276
|
f"{self.base_url}/api/tags",
|
|
277
|
+
headers=self._get_headers(),
|
|
257
278
|
timeout=self._timeouts["health"],
|
|
258
279
|
)
|
|
259
280
|
return bool(response.status_code == 200)
|
|
@@ -265,6 +286,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
265
286
|
try:
|
|
266
287
|
response = self.session.get(
|
|
267
288
|
f"{self.base_url}/api/tags",
|
|
289
|
+
headers=self._get_headers(),
|
|
268
290
|
timeout=self._timeouts["list_models"],
|
|
269
291
|
)
|
|
270
292
|
response.raise_for_status()
|
|
@@ -326,6 +348,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
326
348
|
try:
|
|
327
349
|
response = self.session.post(
|
|
328
350
|
f"{self.base_url}/api/generate",
|
|
351
|
+
headers=self._get_headers(),
|
|
329
352
|
json=payload,
|
|
330
353
|
timeout=self._timeouts["generate"],
|
|
331
354
|
)
|
|
@@ -385,6 +408,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
385
408
|
# Direct call without cache
|
|
386
409
|
response = self.session.post(
|
|
387
410
|
f"{self.embedding_url}/api/embed",
|
|
411
|
+
headers=self._get_headers(),
|
|
388
412
|
json={"model": model, "input": truncated_text},
|
|
389
413
|
timeout=self._timeouts["embed"],
|
|
390
414
|
)
|
|
@@ -446,6 +470,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
446
470
|
try:
|
|
447
471
|
response = self.session.post(
|
|
448
472
|
f"{self.embedding_url}/api/embed",
|
|
473
|
+
headers=self._get_headers(),
|
|
449
474
|
json={"model": model, "input": truncated_texts},
|
|
450
475
|
timeout=self._timeouts["embed_batch"],
|
|
451
476
|
)
|
|
@@ -504,8 +529,8 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
504
529
|
|
|
505
530
|
Examples
|
|
506
531
|
--------
|
|
507
|
-
>>> import
|
|
508
|
-
>>> embeddings =
|
|
532
|
+
>>> import asyncio
|
|
533
|
+
>>> embeddings = asyncio.run(provider.embed_batch_async(texts, "mxbai-embed-large"))
|
|
509
534
|
"""
|
|
510
535
|
self._current_embed_model = model
|
|
511
536
|
self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
|
|
@@ -611,6 +636,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
611
636
|
try:
|
|
612
637
|
response = self.session.post(
|
|
613
638
|
f"{self.base_url}/api/chat",
|
|
639
|
+
headers=self._get_headers(),
|
|
614
640
|
json=payload,
|
|
615
641
|
timeout=self._timeouts["chat"],
|
|
616
642
|
)
|
|
@@ -638,16 +664,24 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
|
|
|
638
664
|
if not self.use_resilience or self._generate_policy is None:
|
|
639
665
|
return "disabled"
|
|
640
666
|
# Access the circuit protector (second policy in SafetyNet)
|
|
641
|
-
|
|
642
|
-
|
|
667
|
+
policies = getattr(self._generate_policy, "policies", None)
|
|
668
|
+
if policies is None or len(policies) < 2:
|
|
669
|
+
return "unknown"
|
|
670
|
+
circuit = policies[1]
|
|
671
|
+
status = getattr(circuit, "status", None)
|
|
672
|
+
return str(getattr(status, "name", "unknown"))
|
|
643
673
|
|
|
644
674
|
@property
|
|
645
675
|
def embed_circuit_status(self) -> str:
|
|
646
676
|
"""Get embed circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
|
|
647
677
|
if not self.use_resilience or self._embed_policy is None:
|
|
648
678
|
return "disabled"
|
|
649
|
-
|
|
650
|
-
|
|
679
|
+
policies = getattr(self._embed_policy, "policies", None)
|
|
680
|
+
if policies is None or len(policies) < 2:
|
|
681
|
+
return "unknown"
|
|
682
|
+
circuit = policies[1]
|
|
683
|
+
status = getattr(circuit, "status", None)
|
|
684
|
+
return str(getattr(status, "name", "unknown"))
|
|
651
685
|
|
|
652
686
|
@staticmethod
|
|
653
687
|
def clear_embedding_cache() -> None:
|
ragit/utils/__init__.py
CHANGED
|
@@ -12,8 +12,6 @@ from datetime import datetime
|
|
|
12
12
|
from math import floor
|
|
13
13
|
from typing import Any
|
|
14
14
|
|
|
15
|
-
import pandas as pd
|
|
16
|
-
|
|
17
15
|
|
|
18
16
|
def get_hashable_repr(dct: dict[str, object]) -> tuple[tuple[str, object, float, int | None], ...]:
|
|
19
17
|
"""
|
|
@@ -62,26 +60,6 @@ def remove_duplicates(items: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
|
62
60
|
return deduplicated_items
|
|
63
61
|
|
|
64
62
|
|
|
65
|
-
def handle_missing_values_in_combinations(df: pd.DataFrame) -> pd.DataFrame:
|
|
66
|
-
"""
|
|
67
|
-
Handle missing values in experiment data combinations.
|
|
68
|
-
|
|
69
|
-
Parameters
|
|
70
|
-
----------
|
|
71
|
-
df : pd.DataFrame
|
|
72
|
-
Experiment data with combinations being explored.
|
|
73
|
-
|
|
74
|
-
Returns
|
|
75
|
-
-------
|
|
76
|
-
pd.DataFrame
|
|
77
|
-
Data with NaN values properly replaced.
|
|
78
|
-
"""
|
|
79
|
-
if "chunk_overlap" in df.columns:
|
|
80
|
-
df["chunk_overlap"] = df["chunk_overlap"].map(lambda el: 0 if pd.isna(el) else el)
|
|
81
|
-
|
|
82
|
-
return df
|
|
83
|
-
|
|
84
|
-
|
|
85
63
|
def datetime_str_to_epoch_time(timestamp: str | int) -> str | int:
|
|
86
64
|
"""
|
|
87
65
|
Convert datetime string to epoch time.
|
ragit/version.py
CHANGED
{ragit-0.10.1.dist-info → ragit-0.11.1.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ragit
|
|
3
|
-
Version: 0.10.1
|
|
3
|
+
Version: 0.11.1
|
|
4
4
|
Summary: Automatic RAG Pattern Optimization Engine
|
|
5
5
|
Author: RODMENA LIMITED
|
|
6
6
|
Maintainer-email: RODMENA LIMITED <info@rodmena.co.uk>
|
|
@@ -20,13 +20,10 @@ Requires-Python: >=3.12
|
|
|
20
20
|
Description-Content-Type: text/markdown
|
|
21
21
|
License-File: LICENSE
|
|
22
22
|
Requires-Dist: requests>=2.31.0
|
|
23
|
-
Requires-Dist: numpy
|
|
24
|
-
Requires-Dist: pandas>=2.2.0
|
|
23
|
+
Requires-Dist: numpy==2.4.1
|
|
25
24
|
Requires-Dist: pydantic>=2.0.0
|
|
26
25
|
Requires-Dist: python-dotenv>=1.0.0
|
|
27
|
-
Requires-Dist: scikit-learn>=1.5.0
|
|
28
26
|
Requires-Dist: tqdm>=4.66.0
|
|
29
|
-
Requires-Dist: trio>=0.24.0
|
|
30
27
|
Requires-Dist: httpx>=0.27.0
|
|
31
28
|
Requires-Dist: resilient-circuit>=0.4.7
|
|
32
29
|
Provides-Extra: dev
|
|
@@ -115,6 +112,45 @@ answer = assistant.ask(question, top_k=3) # Requires generate_fn/LLM
|
|
|
115
112
|
code = assistant.generate_code(request) # Requires generate_fn/LLM
|
|
116
113
|
```
|
|
117
114
|
|
|
115
|
+
## Index Persistence
|
|
116
|
+
|
|
117
|
+
Save and load indexes to avoid re-computing embeddings:
|
|
118
|
+
|
|
119
|
+
```python
|
|
120
|
+
# Save index to disk
|
|
121
|
+
assistant.save_index("./my_index")
|
|
122
|
+
|
|
123
|
+
# Load index later (much faster than re-indexing)
|
|
124
|
+
loaded = RAGAssistant.load_index("./my_index", provider=OllamaProvider())
|
|
125
|
+
results = loaded.retrieve("query")
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
## Thread Safety
|
|
129
|
+
|
|
130
|
+
RAGAssistant is thread-safe. Multiple threads can safely read while another writes:
|
|
131
|
+
|
|
132
|
+
```python
|
|
133
|
+
import threading
|
|
134
|
+
|
|
135
|
+
assistant = RAGAssistant("docs/", provider=OllamaProvider())
|
|
136
|
+
|
|
137
|
+
# Safe: concurrent reads and writes
|
|
138
|
+
threading.Thread(target=lambda: assistant.retrieve("query")).start()
|
|
139
|
+
threading.Thread(target=lambda: assistant.add_documents([new_doc])).start()
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
## Resource Management
|
|
143
|
+
|
|
144
|
+
Use context managers for automatic cleanup:
|
|
145
|
+
|
|
146
|
+
```python
|
|
147
|
+
from ragit.providers import OllamaProvider
|
|
148
|
+
|
|
149
|
+
with OllamaProvider() as provider:
|
|
150
|
+
response = provider.generate("Hello", model="llama3")
|
|
151
|
+
# Session automatically closed
|
|
152
|
+
```
|
|
153
|
+
|
|
118
154
|
## Document Loading
|
|
119
155
|
|
|
120
156
|
```python
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
ragit/__init__.py,sha256=54z3-xCkEa4_P4eonrweSu3Lbig1BWLIGOGT3QUJ4N8,3263
|
|
2
|
-
ragit/assistant.py,sha256=
|
|
3
|
-
ragit/config.py,sha256=
|
|
4
|
-
ragit/exceptions.py,sha256=
|
|
5
|
-
ragit/loaders.py,sha256=
|
|
2
|
+
ragit/assistant.py,sha256=pjB58KyHGD7PwpwLE-lDyXxMhaehDe3IFiO9j7yewxk,33252
|
|
3
|
+
ragit/config.py,sha256=M3YCyogalJ-_cNbY3vAnKIknNsBmqeUFH6lhknuPKV4,6399
|
|
4
|
+
ragit/exceptions.py,sha256=2nBdAWbeLxTkykmwJBTn6BFBNib2dgPfr_Z58p1IwlY,7215
|
|
5
|
+
ragit/loaders.py,sha256=r9hDPTpnVHs9-nMeL2IhEfjIda-TCwYmG3RvnpDcs70,11042
|
|
6
6
|
ragit/logging.py,sha256=YnvhOfnOE3nTd-fR9LKPUHrWdh8fcSHIBEBS5iWDMs8,5739
|
|
7
7
|
ragit/monitor.py,sha256=ajYTdQKM4QlYhlzjiKbSiks4kQj94v0pOhW4q16vJWY,10272
|
|
8
|
-
ragit/version.py,sha256=
|
|
8
|
+
ragit/version.py,sha256=T4eF_UU9MtokI6Rs4IpazY60vfP-oX4SYll80WXMEA0,98
|
|
9
9
|
ragit/core/__init__.py,sha256=j53PFfoSMXwSbK1rRHpMbo8mX2i4R1LJ5kvTxBd7-0w,100
|
|
10
10
|
ragit/core/experiment/__init__.py,sha256=4vAPOOYlY5Dcr2gOolyhBSPGIUxZKwEkgQffxS9BodA,452
|
|
11
11
|
ragit/core/experiment/experiment.py,sha256=Ydf3jz5AXbttc2xcvIMecfc3lh4MKgCtCtyNCsFsn9c,19573
|
|
@@ -13,10 +13,10 @@ ragit/core/experiment/results.py,sha256=KHpN3YSLJ83_JUfIMccRPS-q7LEt0S9p8ehDRawk
|
|
|
13
13
|
ragit/providers/__init__.py,sha256=DSdv2-N9kJwrF6PymKYiktKbjc7g22J_7MD1Rm2ep4g,919
|
|
14
14
|
ragit/providers/base.py,sha256=MJ8mVeXuGWhkX2XGTbkWIY3cVoTOPr4h5XBXw8rAX2Q,3434
|
|
15
15
|
ragit/providers/function_adapter.py,sha256=A-TQhBgBWbuO_w1sy795Dxep1FOCBpAlWpXCKVQD8rc,7778
|
|
16
|
-
ragit/providers/ollama.py,sha256=
|
|
17
|
-
ragit/utils/__init__.py,sha256
|
|
18
|
-
ragit-0.
|
|
19
|
-
ragit-0.
|
|
20
|
-
ragit-0.
|
|
21
|
-
ragit-0.
|
|
22
|
-
ragit-0.
|
|
16
|
+
ragit/providers/ollama.py,sha256=oV6_FojbMrxYyh-g5x77EM1vhzFT4aF98aj2TybWrlw,27600
|
|
17
|
+
ragit/utils/__init__.py,sha256=6oQm2KwXFWIMtAE-0TgcDB6WwKyMy736UPnhG3bFFK4,2531
|
|
18
|
+
ragit-0.11.1.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
|
|
19
|
+
ragit-0.11.1.dist-info/METADATA,sha256=ashTbdQR3yyr9pMv4CdVmMn4tCl2NdSAlzWrJTXuwDM,5299
|
|
20
|
+
ragit-0.11.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
21
|
+
ragit-0.11.1.dist-info/top_level.txt,sha256=pkPbG7yrw61wt9_y_xcLE2vq2a55fzockASD0yq0g4s,6
|
|
22
|
+
ragit-0.11.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|