ragit 0.8.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ragit/assistant.py CHANGED
@@ -7,10 +7,16 @@ High-level RAG Assistant for document Q&A and code generation.
 
 Provides a simple interface for RAG-based tasks.
 
-Note: This class is NOT thread-safe. Do not share instances across threads.
+Thread Safety:
+    This class uses lock-free atomic operations for thread safety.
+    The IndexState is immutable, and all mutations create a new state
+    that is atomically swapped. Python's GIL ensures reference assignment
+    is atomic, making concurrent reads and writes safe.
 """
 
+import json
 from collections.abc import Callable
+from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -18,7 +24,9 @@ import numpy as np
 from numpy.typing import NDArray
 
 from ragit.core.experiment.experiment import Chunk, Document
+from ragit.exceptions import IndexingError
 from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
+from ragit.logging import log_operation, logger
 from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
 from ragit.providers.function_adapter import FunctionProvider
 
@@ -26,6 +34,23 @@ if TYPE_CHECKING:
     from numpy.typing import NDArray
 
 
+@dataclass(frozen=True)
+class IndexState:
+    """Immutable snapshot of index state for lock-free thread safety.
+
+    This class holds all mutable index data in a single immutable structure.
+    Updates create a new IndexState instance, and the reference swap is
+    atomic under Python's GIL, ensuring thread-safe reads and writes.
+
+    Attributes:
+        chunks: Tuple of indexed chunks (immutable).
+        embedding_matrix: Pre-normalized numpy array of embeddings, or None if empty.
+    """
+
+    chunks: tuple[Chunk, ...]
+    embedding_matrix: NDArray[np.float64] | None
+
+
 class RAGAssistant:
     """
     High-level RAG assistant for document Q&A and generation.
@@ -62,9 +87,12 @@ class RAGAssistant:
     ValueError
         If neither embed_fn nor provider is provided.
 
-    Note
-    ----
-    This class is NOT thread-safe. Each thread should have its own instance.
+    Thread Safety
+    -------------
+    This class uses lock-free atomic operations for thread safety.
+    Multiple threads can safely call retrieve() while another thread
+    calls add_documents(). The IndexState is immutable, and reference
+    swaps are atomic under Python's GIL.
 
     Examples
     --------
@@ -76,13 +104,13 @@ class RAGAssistant:
     >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
     >>> answer = assistant.ask("What is X?")
     >>>
-    >>> # With explicit provider
+    >>> # With Ollama provider (supports nomic-embed-text)
     >>> from ragit.providers import OllamaProvider
     >>> assistant = RAGAssistant(docs, provider=OllamaProvider())
     >>>
-    >>> # With SentenceTransformers (offline)
-    >>> from ragit.providers import SentenceTransformersProvider
-    >>> assistant = RAGAssistant(docs, provider=SentenceTransformersProvider())
+    >>> # Save and load index for persistence
+    >>> assistant.save_index("/path/to/index")
+    >>> loaded = RAGAssistant.load_index("/path/to/index", provider=OllamaProvider())
     """
 
     def __init__(
@@ -126,8 +154,7 @@ class RAGAssistant:
                 "Must provide embed_fn or provider for embeddings. "
                 "Examples:\n"
                 " RAGAssistant(docs, embed_fn=my_embed_function)\n"
-                " RAGAssistant(docs, provider=OllamaProvider())\n"
-                " RAGAssistant(docs, provider=SentenceTransformersProvider())"
+                " RAGAssistant(docs, provider=OllamaProvider())"
             )
 
         self.embedding_model = embedding_model or "default"
@@ -138,9 +165,8 @@ class RAGAssistant:
         # Load documents if path provided
         self.documents = self._load_documents(documents)
 
-        # Index chunks - embeddings stored as pre-normalized numpy matrix for fast search
-        self._chunks: tuple[Chunk, ...] = ()
-        self._embedding_matrix: NDArray[np.float64] | None = None  # Pre-normalized
+        # Thread-safe index state (immutable, atomic reference swap)
+        self._state: IndexState = IndexState(chunks=(), embedding_matrix=None)
         self._build_index()
 
     def _load_documents(self, documents: list[Document] | str | Path) -> list[Document]:
@@ -175,45 +201,70 @@ class RAGAssistant:
         raise ValueError(f"Invalid documents source: {documents}")
 
     def _build_index(self) -> None:
-        """Build vector index from documents using batch embedding."""
+        """Build vector index from documents using batch embedding.
+
+        Raises:
+            IndexingError: If embedding count doesn't match chunk count.
+        """
         all_chunks: list[Chunk] = []
 
         for doc in self.documents:
             # Use RST section chunking for .rst files, otherwise regular chunking
             if doc.metadata.get("filename", "").endswith(".rst"):
-                chunks = chunk_rst_sections(doc.content, doc.id, metadata=doc.metadata)
+                chunks = chunk_rst_sections(doc.content, doc.id)
             else:
                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
             all_chunks.extend(chunks)
 
         if not all_chunks:
-            self._chunks = ()
-            self._embedding_matrix = None
+            logger.warning("No chunks produced from documents - index will be empty")
+            self._state = IndexState(chunks=(), embedding_matrix=None)
             return
 
         # Batch embed all chunks at once (single API call)
         texts = [chunk.content for chunk in all_chunks]
         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
 
+        # CRITICAL: Validate embedding count matches chunk count
+        if len(responses) != len(all_chunks):
+            raise IndexingError(
+                f"Embedding count mismatch: expected {len(all_chunks)} embeddings, "
+                f"got {len(responses)}. Index may be corrupted."
+            )
+
         # Build embedding matrix directly (skip storing in chunks to avoid duplication)
         embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
 
+        # Additional validation: matrix shape
+        if embedding_matrix.shape[0] != len(all_chunks):
+            raise IndexingError(
+                f"Matrix row count {embedding_matrix.shape[0]} doesn't match chunk count {len(all_chunks)}"
+            )
+
         # Pre-normalize for fast cosine similarity (normalize once, use many times)
         norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
         norms[norms == 0] = 1  # Avoid division by zero
 
-        # Store as immutable tuple and pre-normalized numpy matrix
-        self._chunks = tuple(all_chunks)
-        self._embedding_matrix = embedding_matrix / norms
+        # Atomic state update (thread-safe under GIL)
+        self._state = IndexState(
+            chunks=tuple(all_chunks),
+            embedding_matrix=embedding_matrix / norms,
+        )
 
     def add_documents(self, documents: list[Document] | str | Path) -> int:
         """Add documents to the existing index incrementally.
 
+        This method is thread-safe. It creates a new IndexState and atomically
+        swaps the reference, ensuring readers always see a consistent state.
+
         Args:
             documents: Documents to add.
 
         Returns:
             Number of chunks added.
+
+        Raises:
+            IndexingError: If embedding count doesn't match chunk count.
         """
         new_docs = self._load_documents(documents)
         if not new_docs:
@@ -225,7 +276,7 @@ class RAGAssistant:
         new_chunks: list[Chunk] = []
         for doc in new_docs:
             if doc.metadata.get("filename", "").endswith(".rst"):
-                chunks = chunk_rst_sections(doc.content, doc.id, metadata=doc.metadata)
+                chunks = chunk_rst_sections(doc.content, doc.id)
             else:
                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
             new_chunks.extend(chunks)
@@ -237,6 +288,13 @@ class RAGAssistant:
         texts = [chunk.content for chunk in new_chunks]
         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
 
+        # Validate embedding count
+        if len(responses) != len(new_chunks):
+            raise IndexingError(
+                f"Embedding count mismatch: expected {len(new_chunks)} embeddings, "
+                f"got {len(responses)}. Index update aborted."
+            )
+
         new_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
 
         # Normalize
@@ -244,21 +302,27 @@ class RAGAssistant:
         norms[norms == 0] = 1
         new_matrix_norm = new_matrix / norms
 
-        # Update state
-        current_chunks = list(self._chunks)
-        current_chunks.extend(new_chunks)
-        self._chunks = tuple(current_chunks)
+        # Read current state (atomic read)
+        current_state = self._state
 
-        if self._embedding_matrix is None:
-            self._embedding_matrix = new_matrix_norm
+        # Build new state
+        combined_chunks = current_state.chunks + tuple(new_chunks)
+        if current_state.embedding_matrix is None:
+            combined_matrix = new_matrix_norm
         else:
-            self._embedding_matrix = np.vstack((self._embedding_matrix, new_matrix_norm))
+            combined_matrix = np.vstack((current_state.embedding_matrix, new_matrix_norm))
+
+        # Atomic state swap (thread-safe under GIL)
+        self._state = IndexState(chunks=combined_chunks, embedding_matrix=combined_matrix)
 
         return len(new_chunks)
 
     def remove_documents(self, source_path_pattern: str) -> int:
         """Remove documents matching a source path pattern.
 
+        This method is thread-safe. It creates a new IndexState and atomically
+        swaps the reference.
+
         Args:
             source_path_pattern: Glob pattern to match 'source' metadata.
 
@@ -267,14 +331,17 @@ class RAGAssistant:
         """
         import fnmatch
 
-        if not self._chunks:
+        # Read current state (atomic read)
+        current_state = self._state
+
+        if not current_state.chunks:
             return 0
 
         indices_to_keep = []
         kept_chunks = []
         removed_count = 0
 
-        for i, chunk in enumerate(self._chunks):
+        for i, chunk in enumerate(current_state.chunks):
             source = chunk.metadata.get("source", "")
             if not source or not fnmatch.fnmatch(source, source_path_pattern):
                 indices_to_keep.append(i)
@@ -285,13 +352,14 @@ class RAGAssistant:
         if removed_count == 0:
             return 0
 
-        self._chunks = tuple(kept_chunks)
+        # Build new embedding matrix
+        if current_state.embedding_matrix is not None:
+            new_matrix = None if not kept_chunks else current_state.embedding_matrix[indices_to_keep]
+        else:
+            new_matrix = None
 
-        if self._embedding_matrix is not None:
-            if not kept_chunks:
-                self._embedding_matrix = None
-            else:
-                self._embedding_matrix = self._embedding_matrix[indices_to_keep]
+        # Atomic state swap (thread-safe under GIL)
+        self._state = IndexState(chunks=tuple(kept_chunks), embedding_matrix=new_matrix)
 
         # Also remove from self.documents
         self.documents = [
@@ -334,6 +402,7 @@ class RAGAssistant:
         Retrieve relevant chunks for a query.
 
         Uses vectorized cosine similarity for fast search over all chunks.
+        This method is thread-safe - it reads a consistent snapshot of the index.
 
         Parameters
         ----------
@@ -353,7 +422,10 @@ class RAGAssistant:
         >>> for chunk, score in results:
         ...     print(f"{score:.2f}: {chunk.content[:100]}...")
         """
-        if not self._chunks or self._embedding_matrix is None:
+        # Atomic state read - get consistent snapshot
+        state = self._state
+
+        if not state.chunks or state.embedding_matrix is None:
             return []
 
         # Get query embedding and normalize
@@ -365,7 +437,7 @@ class RAGAssistant:
         query_normalized = query_vec / query_norm
 
         # Fast cosine similarity: matrix is pre-normalized, just dot product
-        similarities = self._embedding_matrix @ query_normalized
+        similarities = state.embedding_matrix @ query_normalized
 
         # Get top_k indices using argpartition (faster than full sort for large arrays)
         if len(similarities) <= top_k:
@@ -376,7 +448,197 @@ class RAGAssistant:
         # Sort the top_k by score
         top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
 
-        return [(self._chunks[i], float(similarities[i])) for i in top_indices]
+        return [(state.chunks[i], float(similarities[i])) for i in top_indices]
+
+    def retrieve_with_context(
+        self,
+        query: str,
+        top_k: int = 3,
+        window_size: int = 1,
+        min_score: float = 0.0,
+    ) -> list[tuple[Chunk, float]]:
+        """
+        Retrieve chunks with adjacent context expansion (window search).
+
+        For each retrieved chunk, also includes adjacent chunks from the
+        same document to provide more context. This is useful when relevant
+        information spans multiple chunks.
+
+        Pattern inspired by ai4rag window_search.
+
+        Parameters
+        ----------
+        query : str
+            Search query.
+        top_k : int
+            Number of initial chunks to retrieve (default: 3).
+        window_size : int
+            Number of adjacent chunks to include on each side (default: 1).
+            Set to 0 to disable window expansion.
+        min_score : float
+            Minimum similarity score threshold (default: 0.0).
+
+        Returns
+        -------
+        list[tuple[Chunk, float]]
+            List of (chunk, similarity_score) tuples, sorted by relevance.
+            Adjacent chunks have slightly lower scores.
+
+        Examples
+        --------
+        >>> # Get chunks with 1 adjacent chunk on each side
+        >>> results = assistant.retrieve_with_context("query", window_size=1)
+        >>> for chunk, score in results:
+        ...     print(f"{score:.2f}: {chunk.content[:50]}...")
+        """
+        # Get consistent state snapshot
+        state = self._state
+
+        with log_operation("retrieve_with_context", query_len=len(query), top_k=top_k, window_size=window_size) as ctx:
+            # Get initial results (more than top_k to account for filtering)
+            results = self.retrieve(query, top_k * 2)
+
+            # Apply minimum score threshold
+            if min_score > 0:
+                results = [(chunk, score) for chunk, score in results if score >= min_score]
+
+            if window_size == 0 or not results:
+                ctx["expanded_chunks"] = len(results)
+                return results[:top_k]
+
+            # Build chunk index for fast lookup
+            chunk_to_idx = {id(chunk): i for i, chunk in enumerate(state.chunks)}
+
+            expanded_results: list[tuple[Chunk, float]] = []
+            seen_indices: set[int] = set()
+
+            for chunk, score in results[:top_k]:
+                chunk_idx = chunk_to_idx.get(id(chunk))
+                if chunk_idx is None:
+                    expanded_results.append((chunk, score))
+                    continue
+
+                # Get window of adjacent chunks from same document
+                start_idx = max(0, chunk_idx - window_size)
+                end_idx = min(len(state.chunks), chunk_idx + window_size + 1)
+
+                for idx in range(start_idx, end_idx):
+                    if idx in seen_indices:
+                        continue
+
+                    adjacent_chunk = state.chunks[idx]
+                    # Only include adjacent chunks from same document
+                    if adjacent_chunk.doc_id == chunk.doc_id:
+                        seen_indices.add(idx)
+                        # Original chunk keeps full score, adjacent get 80%
+                        adj_score = score if idx == chunk_idx else score * 0.8
+                        expanded_results.append((adjacent_chunk, adj_score))
+
+            # Sort by score (highest first)
+            expanded_results.sort(key=lambda x: (-x[1], state.chunks.index(x[0]) if x[0] in state.chunks else 0))
+            ctx["expanded_chunks"] = len(expanded_results)
+
+            return expanded_results
+
+    def get_context_with_window(
+        self,
+        query: str,
+        top_k: int = 3,
+        window_size: int = 1,
+        min_score: float = 0.0,
+    ) -> str:
+        """
+        Get formatted context with adjacent chunk expansion.
+
+        Merges overlapping text from adjacent chunks intelligently.
+
+        Parameters
+        ----------
+        query : str
+            Search query.
+        top_k : int
+            Number of initial chunks to retrieve.
+        window_size : int
+            Number of adjacent chunks on each side.
+        min_score : float
+            Minimum similarity score threshold.
+
+        Returns
+        -------
+        str
+            Formatted context string with merged chunks.
+        """
+        # Get consistent state snapshot
+        state = self._state
+
+        results = self.retrieve_with_context(query, top_k, window_size, min_score)
+
+        if not results:
+            return ""
+
+        # Group chunks by document to merge properly
+        doc_chunks: dict[str, list[tuple[Chunk, float]]] = {}
+        for chunk, score in results:
+            doc_id = chunk.doc_id or "unknown"
+            if doc_id not in doc_chunks:
+                doc_chunks[doc_id] = []
+            doc_chunks[doc_id].append((chunk, score))
+
+        merged_sections: list[str] = []
+
+        for _doc_id, chunks in doc_chunks.items():
+            # Sort chunks by their position in the original list
+            chunks.sort(key=lambda x: state.chunks.index(x[0]) if x[0] in state.chunks else 0)
+
+            # Merge overlapping text
+            merged_content: list[str] = []
+            for chunk, _ in chunks:
+                if merged_content:
+                    # Check for overlap with previous chunk
+                    prev_content = merged_content[-1]
+                    non_overlapping = self._get_non_overlapping_text(prev_content, chunk.content)
+                    if non_overlapping != chunk.content:
+                        # Found overlap, extend previous chunk
+                        merged_content[-1] = prev_content + non_overlapping
+                    else:
+                        # No overlap, add as new section
+                        merged_content.append(chunk.content)
+                else:
+                    merged_content.append(chunk.content)
+
+            merged_sections.append("\n".join(merged_content))
+
+        return "\n\n---\n\n".join(merged_sections)
+
+    def _get_non_overlapping_text(self, str1: str, str2: str) -> str:
+        """
+        Find non-overlapping portion of str2 when appending after str1.
+
+        Detects overlap where the end of str1 matches the beginning of str2,
+        and returns only the non-overlapping portion of str2.
+
+        Pattern from ai4rag vector_store/utils.py.
+
+        Parameters
+        ----------
+        str1 : str
+            First string (previous content).
+        str2 : str
+            Second string (content to potentially append).
+
+        Returns
+        -------
+        str
+            Non-overlapping portion of str2, or full str2 if no overlap.
+        """
+        # Limit overlap search to avoid O(n^2) for large strings
+        max_overlap = min(len(str1), len(str2), 200)
+
+        for i in range(max_overlap, 0, -1):
+            if str1[-i:] == str2[:i]:
+                return str2[i:]
+
+        return str2
 
     def get_context(self, query: str, top_k: int = 3) -> str:
         """
@@ -564,7 +826,17 @@ Generate the {language} code:"""
     @property
     def num_chunks(self) -> int:
         """Return number of indexed chunks."""
-        return len(self._chunks)
+        return len(self._state.chunks)
+
+    @property
+    def chunk_count(self) -> int:
+        """Number of chunks in index (alias for num_chunks)."""
+        return len(self._state.chunks)
+
+    @property
+    def is_indexed(self) -> bool:
+        """Check if index has any documents."""
+        return len(self._state.chunks) > 0
 
     @property
     def num_documents(self) -> int:
@@ -575,3 +847,122 @@ Generate the {language} code:"""
     def has_llm(self) -> bool:
         """Check if LLM is configured."""
         return self._llm_provider is not None
+
+    def save_index(self, path: str | Path) -> None:
+        """Save index to disk for later restoration.
+
+        Saves the index in an efficient format:
+        - chunks.json: Chunk metadata and content
+        - embeddings.npy: Numpy array of embeddings (binary format)
+        - metadata.json: Index configuration
+
+        Args:
+            path: Directory path to save index files.
+
+        Example:
+            >>> assistant.save_index("/path/to/index")
+            >>> # Later...
+            >>> loaded = RAGAssistant.load_index("/path/to/index", provider=provider)
+        """
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        state = self._state
+
+        # Save chunks as JSON
+        chunks_data = [
+            {
+                "content": chunk.content,
+                "doc_id": chunk.doc_id,
+                "chunk_index": chunk.chunk_index,
+                "metadata": chunk.metadata,
+            }
+            for chunk in state.chunks
+        ]
+        (path / "chunks.json").write_text(json.dumps(chunks_data, indent=2))
+
+        # Save embeddings as numpy binary (efficient for large arrays)
+        if state.embedding_matrix is not None:
+            np.save(path / "embeddings.npy", state.embedding_matrix)
+
+        # Save metadata for validation and configuration restoration
+        metadata = {
+            "chunk_count": len(state.chunks),
+            "embedding_model": self.embedding_model,
+            "chunk_size": self.chunk_size,
+            "chunk_overlap": self.chunk_overlap,
+            "version": "1.0",
+        }
+        (path / "metadata.json").write_text(json.dumps(metadata, indent=2))
+
+        logger.info(f"Index saved to {path} ({len(state.chunks)} chunks)")
+
+    @classmethod
+    def load_index(
+        cls,
+        path: str | Path,
+        provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
+    ) -> "RAGAssistant":
+        """Load a previously saved index.
+
+        Args:
+            path: Directory path containing saved index files.
+            provider: Provider for embeddings/LLM (required for new queries).
+
+        Returns:
+            RAGAssistant instance with loaded index.
+
+        Raises:
+            IndexingError: If loaded index is corrupted (count mismatch).
+            FileNotFoundError: If index files don't exist.
+
+        Example:
+            >>> loaded = RAGAssistant.load_index("/path/to/index", provider=OllamaProvider())
+            >>> results = loaded.retrieve("query")
+        """
+        path = Path(path)
+
+        # Load metadata
+        metadata = json.loads((path / "metadata.json").read_text())
+
+        # Load chunks
+        chunks_data = json.loads((path / "chunks.json").read_text())
+        chunks = tuple(
+            Chunk(
+                content=c["content"],
+                doc_id=c.get("doc_id", ""),
+                chunk_index=c.get("chunk_index", 0),
+                metadata=c.get("metadata", {}),
+            )
+            for c in chunks_data
+        )
+
+        # Load embeddings
+        embeddings_path = path / "embeddings.npy"
+        embedding_matrix: NDArray[np.float64] | None = None
+        if embeddings_path.exists():
+            embedding_matrix = np.load(embeddings_path)
+
+        # Validate consistency
+        if embedding_matrix is not None and embedding_matrix.shape[0] != len(chunks):
+            raise IndexingError(
+                f"Loaded index corrupted: {embedding_matrix.shape[0]} embeddings but {len(chunks)} chunks"
+            )
+
+        # Create instance without calling __init__ (skip indexing)
+        instance = object.__new__(cls)
+
+        # Initialize required attributes
+        instance._state = IndexState(chunks=chunks, embedding_matrix=embedding_matrix)
+        instance.embedding_model = metadata.get("embedding_model", "default")
+        instance.llm_model = metadata.get("llm_model", "default")
+        instance.chunk_size = metadata.get("chunk_size", 512)
+        instance.chunk_overlap = metadata.get("chunk_overlap", 50)
+        instance.documents = []  # Original docs not saved
+
+        # Set up providers
+        instance._embedding_provider = provider if isinstance(provider, BaseEmbeddingProvider) else None  # type: ignore
+        instance._llm_provider = provider if isinstance(provider, BaseLLMProvider) else None
+
+        logger.info(f"Index loaded from {path} ({len(chunks)} chunks)")
+        return instance
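
A minimal usage sketch of the API surface shown in this diff (windowed retrieval, index persistence, and the thread-safe index). Only the method and property names come from the diff above; the import paths, provider setup, document directory, query text, and index path are illustrative assumptions and may differ in the released package.

    from ragit.assistant import RAGAssistant   # module path assumed from ragit/assistant.py
    from ragit.providers import OllamaProvider  # provider referenced in the docstrings above

    # Build an index from a directory of documents (hypothetical path).
    # Per the docstrings, retrieve() may safely run in other threads
    # while add_documents() mutates the index.
    assistant = RAGAssistant("./docs", provider=OllamaProvider())

    # Plain top-k retrieval, and the window-expanded variant introduced here.
    hits = assistant.retrieve("How do I configure the pipeline?", top_k=3)
    context = assistant.get_context_with_window(
        "How do I configure the pipeline?", top_k=3, window_size=1, min_score=0.2
    )

    # Persist the index and restore it later without re-embedding.
    assistant.save_index("./rag_index")
    restored = RAGAssistant.load_index("./rag_index", provider=OllamaProvider())
    print(restored.chunk_count, restored.is_indexed)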