ragmint 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ragmint might be problematic. Click here for more details.

Files changed (46) hide show
  1. ragmint/__init__.py +0 -0
  2. ragmint/__main__.py +28 -0
  3. ragmint/autotuner.py +138 -0
  4. ragmint/core/__init__.py +0 -0
  5. ragmint/core/chunking.py +86 -0
  6. ragmint/core/embeddings.py +55 -0
  7. ragmint/core/evaluation.py +38 -0
  8. ragmint/core/pipeline.py +62 -0
  9. ragmint/core/reranker.py +62 -0
  10. ragmint/core/retriever.py +165 -0
  11. ragmint/experiments/__init__.py +0 -0
  12. ragmint/experiments/validation_qa.json +14 -0
  13. ragmint/explainer.py +63 -0
  14. ragmint/integrations/__init__.py +0 -0
  15. ragmint/integrations/config_adapter.py +96 -0
  16. ragmint/integrations/langchain_prebuilder.py +99 -0
  17. ragmint/leaderboard.py +45 -0
  18. ragmint/optimization/__init__.py +0 -0
  19. ragmint/optimization/search.py +48 -0
  20. ragmint/tests/__init__.py +0 -0
  21. ragmint/tests/conftest.py +16 -0
  22. ragmint/tests/test_autotuner.py +51 -0
  23. ragmint/tests/test_config_adapter.py +39 -0
  24. ragmint/tests/test_embeddings.py +46 -0
  25. ragmint/tests/test_explainer.py +20 -0
  26. ragmint/tests/test_explainer_integration.py +18 -0
  27. ragmint/tests/test_integration_autotuner_ragmint.py +47 -0
  28. ragmint/tests/test_langchain_prebuilder.py +82 -0
  29. ragmint/tests/test_leaderboard.py +39 -0
  30. ragmint/tests/test_pipeline.py +20 -0
  31. ragmint/tests/test_retriever.py +15 -0
  32. ragmint/tests/test_search.py +17 -0
  33. ragmint/tests/test_tuner.py +71 -0
  34. ragmint/tuner.py +189 -0
  35. ragmint/utils/__init__.py +0 -0
  36. ragmint/utils/caching.py +37 -0
  37. ragmint/utils/data_loader.py +65 -0
  38. ragmint/utils/logger.py +36 -0
  39. ragmint/utils/metrics.py +27 -0
  40. ragmint-0.3.1.data/data/LICENSE +19 -0
  41. ragmint-0.3.1.data/data/README.md +397 -0
  42. ragmint-0.3.1.dist-info/METADATA +441 -0
  43. ragmint-0.3.1.dist-info/RECORD +46 -0
  44. ragmint-0.3.1.dist-info/WHEEL +5 -0
  45. ragmint-0.3.1.dist-info/licenses/LICENSE +19 -0
  46. ragmint-0.3.1.dist-info/top_level.txt +1 -0
ragmint/__init__.py ADDED
File without changes
ragmint/__main__.py ADDED
@@ -0,0 +1,28 @@
1
+ from pathlib import Path
2
+ from ragmint.tuner import RAGMint
3
+
4
def main():
    """Run a demo RAGMint optimization against files shipped inside the package.

    Paths are resolved relative to this module, so the demo does not depend on
    the current working directory.
    """
    experiments_dir = Path(__file__).resolve().parent / "experiments"

    # NOTE(review): the wheel's file list only shows experiments/validation_qa.json;
    # an experiments/corpus directory does not appear to be packaged — confirm.
    # NOTE(review): core/embeddings.py only accepts "huggingface"/"dummy" backends;
    # verify how RAGMint maps "openai/text-embedding-3-small".
    tuner = RAGMint(
        docs_path=str(experiments_dir / "corpus"),
        retrievers=["faiss"],
        embeddings=["openai/text-embedding-3-small"],
        rerankers=["mmr"],
    )

    best_config, _all_trials = tuner.optimize(
        validation_set=str(experiments_dir / "validation_qa.json"),
        metric="faithfulness",
        search_type="bayesian",
        trials=10,
    )

    print("Best config found:\n", best_config)


if __name__ == "__main__":
    main()
ragmint/autotuner.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Auto-RAG Tuner
3
+ --------------
4
+ Automatically recommends and optimizes RAG configurations based on corpus statistics.
5
+ Integrates with RAGMint to perform full end-to-end tuning.
6
+ """
7
+
8
+ import os
9
+ import logging
10
+ from statistics import mean
11
+ from typing import Dict, Any, Tuple, List
12
+
13
+ from .tuner import RAGMint
14
+ from .core.evaluation import evaluate_config
15
+
16
# Import-time side effect: configures the root logger used by the module's
# logging.info/warning calls. Per the logging docs, basicConfig does nothing
# if the root logger already has handlers.
logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)
17
+
18
+
19
class AutoRAGTuner:
    """Analyze a document corpus and run a RAGMint search seeded from heuristics.

    The corpus is scanned once at construction time. ``recommend`` maps the
    resulting statistics to one retriever / embedding / chunking configuration,
    and ``auto_tune`` passes that configuration to ``RAGMint.optimize``.
    """

    def __init__(self, docs_path: str):
        """
        AutoRAGTuner automatically analyzes a corpus and runs an optimized RAG tuning pipeline.

        Args:
            docs_path (str): Path to the directory containing documents (.txt, .md, .rst)
        """
        self.docs_path = docs_path
        self.corpus_stats = self._analyze_corpus()

    # -----------------------------
    # Corpus Analysis
    # -----------------------------
    def _analyze_corpus(self) -> Dict[str, Any]:
        """Compute total characters, mean document length and document count.

        Only regular files directly inside ``docs_path`` (no recursion) whose
        names end in .txt, .md or .rst are read. Files that are not valid UTF-8
        are skipped with a warning instead of aborting the analysis.

        Returns:
            dict with ``size`` (total characters), ``avg_len`` (mean characters
            per document, truncated to int) and ``num_docs``. All values are 0
            when ``docs_path`` is missing or is not a directory.
        """
        empty_stats = {"size": 0, "avg_len": 0, "num_docs": 0}

        if not os.path.exists(self.docs_path):
            logging.warning(f"⚠️ Corpus path not found: {self.docs_path}")
            return empty_stats
        if not os.path.isdir(self.docs_path):
            # os.listdir would raise NotADirectoryError on a plain file.
            logging.warning(f"⚠️ Corpus path is not a directory: {self.docs_path}")
            return empty_stats

        total_chars = 0
        num_docs = 0
        for name in os.listdir(self.docs_path):
            path = os.path.join(self.docs_path, name)
            # A sub-directory named e.g. "notes.md" would make open() fail.
            if not name.endswith((".txt", ".md", ".rst")) or not os.path.isfile(path):
                continue
            try:
                with open(path, "r", encoding="utf-8") as f:
                    content = f.read()
            except UnicodeDecodeError:
                logging.warning(f"⚠️ Skipping non-UTF-8 file: {path}")
                continue
            total_chars += len(content)
            num_docs += 1

        # Same value as truncating the mean length, without keeping every
        # document's text in memory.
        avg_len = total_chars // num_docs if num_docs else 0
        stats = {"size": total_chars, "avg_len": avg_len, "num_docs": num_docs}
        logging.info(f"📊 Corpus stats: {stats}")
        return stats

    # -----------------------------
    # Recommendation Logic
    # -----------------------------
    def recommend(self) -> Dict[str, Any]:
        """Recommend retriever, embedding model and chunking from corpus stats.

        Heuristics:
            - chunk size / overlap grow with the average document length
              (< 200 chars: 300/50, < 500: 500/100, otherwise 800/150);
            - retriever and embedding model scale with total corpus size
              (<= 2000 chars: BM25, <= 10000: Chroma, otherwise FAISS);
            - "fixed" chunking for short documents (< 400 chars), else "sentence".

        Returns:
            dict with keys ``retriever``, ``embedding_model``, ``chunk_size``,
            ``overlap`` and ``strategy``.
        """
        size = self.corpus_stats.get("size", 0)
        avg_len = self.corpus_stats.get("avg_len", 0)

        # Determine chunking heuristics first
        if avg_len < 200:
            chunk_size, overlap = 300, 50
        elif avg_len < 500:
            chunk_size, overlap = 500, 100
        else:
            chunk_size, overlap = 800, 150

        # Determine retriever–embedding based on corpus size
        if size <= 2000:
            retriever = "BM25"
            embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
        elif size <= 10000:
            retriever = "Chroma"
            embedding_model = "sentence-transformers/paraphrase-MiniLM-L6-v2"
        else:
            retriever = "FAISS"
            embedding_model = "sentence-transformers/all-mpnet-base-v2"

        strategy = "fixed" if avg_len < 400 else "sentence"

        recommendation = {
            "retriever": retriever,
            "embedding_model": embedding_model,
            "chunk_size": chunk_size,
            "overlap": overlap,
            "strategy": strategy,
        }

        logging.info(f"🔮 AutoRAG Recommendation: {recommendation}")
        return recommendation

    # -----------------------------
    # Full Auto-Tuning
    # -----------------------------
    def auto_tune(
        self,
        validation_set: str = None,
        metric: str = "faithfulness",
        trials: int = 5,
        search_type: str = "random",
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Run a full automatic optimization using RAGMint.

        The search space is the single recommended configuration (with the
        "mmr" reranker), so RAGMint explores only that combination.

        Args:
            validation_set: passed through to ``RAGMint.optimize``.
            metric: metric name passed through to ``RAGMint.optimize``.
            trials: number of optimization trials.
            search_type: search strategy name passed through to RAGMint.

        Returns:
            The ``(best, results)`` pair returned by ``RAGMint.optimize``.
        """
        rec = self.recommend()

        logging.info("🚀 Launching full AutoRAG optimization with RAGMint")

        tuner = RAGMint(
            docs_path=self.docs_path,
            retrievers=[rec["retriever"]],
            embeddings=[rec["embedding_model"]],
            rerankers=["mmr"],
            chunk_sizes=[rec["chunk_size"]],
            overlaps=[rec["overlap"]],
            strategies=[rec["strategy"]],
        )

        best, results = tuner.optimize(
            validation_set=validation_set,
            metric=metric,
            trials=trials,
            search_type=search_type,
        )

        logging.info(f"🏁 AutoRAG tuning complete. Best: {best}")
        return best, results
File without changes
@@ -0,0 +1,86 @@
1
+ from typing import List
2
+ import re
3
+
4
+ try:
5
+ import tiktoken
6
+ except ImportError:
7
+ tiktoken = None
8
+
9
+ try:
10
+ import nltk
11
+ nltk.download("punkt", quiet=True)
12
+ from nltk.tokenize import sent_tokenize
13
+ except ImportError:
14
+ sent_tokenize = None
15
+
16
+
17
class Chunker:
    """
    Handles text chunking strategies:
    - fixed: character-based sliding window
    - token: token-based sliding window (requires tiktoken)
    - sentence: packs whole sentences up to ``chunk_size`` characters (requires nltk)

    If the library a strategy needs is not importable, ``chunk_text`` falls back
    to fixed-size character chunks. ``overlap`` is used by the fixed and token
    strategies only.
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 100, strategy: str = "fixed"):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.strategy = strategy

    def chunk_text(self, text: str) -> List[str]:
        """Dispatch to the configured chunking strategy (see class docstring)."""
        # `and` short-circuits, so each optional-dependency name is only read
        # for the strategy that needs it.
        if self.strategy == "token" and tiktoken:
            return self._chunk_by_tokens(text)
        if self.strategy == "sentence" and sent_tokenize:
            return self._chunk_by_sentences(text)
        return self._chunk_fixed(text)

    def _window_step(self) -> int:
        """Return the sliding-window stride, rejecting settings that never advance.

        Raises:
            ValueError: if ``chunk_size <= 0`` or ``overlap >= chunk_size``
                (the fixed window would otherwise loop forever).
        """
        if self.chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {self.chunk_size}")
        step = self.chunk_size - self.overlap
        if step <= 0:
            raise ValueError(
                f"overlap ({self.overlap}) must be smaller than chunk_size ({self.chunk_size})"
            )
        return step

    # -------------------------------
    # Fixed-length (default)
    # -------------------------------
    def _chunk_fixed(self, text: str) -> List[str]:
        """Split ``text`` into ``chunk_size``-character windows sharing ``overlap`` characters."""
        step = self._window_step()
        chunks = []
        start = 0
        while start < len(text):
            end = start + self.chunk_size
            chunks.append(text[start:end])
            if end >= len(text):
                # This window already reaches the end; another one would only
                # repeat text contained in this chunk.
                break
            start += step
        return chunks

    # -------------------------------
    # Token-based (for LLM embedding)
    # -------------------------------
    def _chunk_by_tokens(self, text: str) -> List[str]:
        """Split ``text`` into windows of ``chunk_size`` cl100k_base tokens with ``overlap`` tokens shared."""
        if not tiktoken:
            raise ImportError("tiktoken is required for token-based chunking.")
        step = self._window_step()
        enc = tiktoken.get_encoding("cl100k_base")
        tokens = enc.encode(text)

        chunks = []
        for i in range(0, len(tokens), step):
            chunks.append(enc.decode(tokens[i:i + self.chunk_size]))
            if i + self.chunk_size >= len(tokens):
                # Last window reached the end of the token stream.
                break
        return chunks

    # -------------------------------
    # Sentence-based
    # -------------------------------
    def _chunk_by_sentences(self, text: str) -> List[str]:
        """Greedily pack sentences (joined by one space) into chunks of at most ``chunk_size`` chars.

        A single sentence longer than ``chunk_size`` becomes its own chunk.
        Never emits empty chunks.
        """
        if not sent_tokenize:
            raise ImportError("nltk is required for sentence-based chunking.")
        try:
            sentences = sent_tokenize(text)
        except LookupError:
            # nltk is importable but its tokenizer data is missing (e.g. the
            # quiet download failed offline); degrade like the missing-nltk case.
            return self._chunk_fixed(text)

        chunks = []
        current = ""
        for sentence in sentences:
            candidate = f"{current} {sentence}" if current else sentence
            # `not current`: an oversized first sentence starts a chunk instead
            # of flushing an empty one.
            if not current or len(candidate) <= self.chunk_size:
                current = candidate
            else:
                chunks.append(current.strip())
                current = sentence

        if current.strip():
            chunks.append(current.strip())

        return chunks
@@ -0,0 +1,55 @@
1
+ import numpy as np
2
+ from dotenv import load_dotenv
3
+
4
+ try:
5
+ from sentence_transformers import SentenceTransformer
6
+ except ImportError:
7
+ SentenceTransformer = None
8
+
9
+
10
class Embeddings:
    """
    Wrapper for embedding backends: HuggingFace (SentenceTransformers) or Dummy.

    The dummy backend produces float32 vectors with values in [0, 1). Each
    vector is seeded from the SHA-256 digest of its text, so identical texts
    always get identical vectors (across calls and processes), and numpy's
    global random state is not consumed.

    Example:
        model = Embeddings("huggingface", model_name="all-MiniLM-L6-v2")
        embeddings = model.encode(["example text"])
    """

    def __init__(self, backend: str = "huggingface", model_name: str = None):
        # Load variables from a .env file (python-dotenv) before building models.
        load_dotenv()
        self.backend = backend.lower()
        self.model_name = model_name or "all-MiniLM-L6-v2"

        if self.backend == "huggingface":
            if SentenceTransformer is None:
                raise ImportError("Please install `sentence-transformers` to use HuggingFace embeddings.")
            self.model = SentenceTransformer(self.model_name)
            self.dim = self.model.get_sentence_embedding_dimension()

        elif self.backend == "dummy":
            self.model = None
            self.dim = 768  # Default embedding dimension for dummy backend

        else:
            raise ValueError(f"Unsupported embedding backend: {backend}")

    @staticmethod
    def _dummy_vector(text, dim: int) -> np.ndarray:
        """Return a reproducible pseudo-random float32 vector of length ``dim`` for ``text``."""
        import hashlib

        digest = hashlib.sha256(str(text).encode("utf-8", "replace")).digest()
        seed = int.from_bytes(digest[:8], "little")
        return np.random.default_rng(seed).random(dim, dtype=np.float32)

    def encode(self, texts):
        """Encode one string or a sequence of strings.

        Args:
            texts: a single string or an iterable of strings.

        Returns:
            A NumPy array with one row per input text. Empty input yields an
            array of shape ``(0, dim)``.
        """
        texts = [texts] if isinstance(texts, str) else list(texts)

        if not texts:
            # Consistent 2-D shape for both backends when there is nothing to encode.
            return np.zeros((0, self.dim or 0), dtype=np.float32)

        if self.backend == "huggingface":
            embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=False)

        elif self.backend == "dummy":
            embeddings = np.stack([self._dummy_vector(t, self.dim) for t in texts])

        else:
            raise ValueError(f"Unknown embedding backend: {self.backend}")

        # ✅ Always ensure NumPy array output
        if not isinstance(embeddings, np.ndarray):
            embeddings = np.array(embeddings, dtype=np.float32)

        return embeddings
@@ -0,0 +1,38 @@
1
+ import time
2
+ from typing import Dict, Any
3
+ from difflib import SequenceMatcher
4
+
5
+
6
class Evaluator:
    """
    Simple evaluation of generated answers:
    - Faithfulness: difflib SequenceMatcher ratio between answer and context
      (0.0 when either side is empty)
    - Latency: seconds spent computing the metric above (not retrieval or
      generation time)
    """

    def __init__(self):
        pass

    def evaluate(self, query: str, answer: str, context: str) -> Dict[str, Any]:
        """Score one answer against its context.

        Args:
            query: accepted for interface compatibility; not used in scoring.
            answer: generated answer text.
            context: concatenated retrieved context.

        Returns:
            ``{"faithfulness": float in [0, 1], "latency": seconds}``.
        """
        # perf_counter is monotonic and high-resolution; time.time can jump
        # with clock adjustments and yield negative durations.
        start = time.perf_counter()
        faithfulness = self._similarity(answer, context)
        latency = time.perf_counter() - start

        return {
            "faithfulness": faithfulness,
            "latency": latency,
        }

    def _similarity(self, a: str, b: str) -> float:
        """Return the SequenceMatcher ratio of ``a`` and ``b``; 0.0 if either is empty/None."""
        # SequenceMatcher("", "").ratio() is 1.0, which would score an empty
        # retrieval (empty answer and context) as perfectly faithful.
        if not a or not b:
            return 0.0
        return SequenceMatcher(None, a, b).ratio()
28
+
29
def evaluate_config(config, validation_data):
    """Score every validation sample with a fresh ``Evaluator``.

    Each sample is read as a mapping with optional ``query``, ``answer`` and
    ``context`` keys; missing keys default to "". ``config`` is accepted but
    not used by this function.

    NOTE(review): the bundled experiments/validation_qa.json entries carry
    ``query`` and ``expected_answer`` keys, not ``answer``/``context`` — confirm
    the intended sample schema.

    Returns:
        list of per-sample metric dicts, in input order.
    """
    scorer = Evaluator()
    return [
        scorer.evaluate(
            sample.get("query", ""),
            sample.get("answer", ""),
            sample.get("context", ""),
        )
        for sample in validation_data
    ]
38
+
@@ -0,0 +1,62 @@
1
+ from typing import Any, Dict, Optional
2
+ from .retriever import Retriever
3
+ from .reranker import Reranker
4
+ from .evaluation import Evaluator
5
+ from .chunking import Chunker
6
+
7
+
8
class RAGPipeline:
    """
    Core Retrieval-Augmented Generation pipeline.

    Retrieves, reranks, and evaluates a query given the configured backends.
    Can chunk the retriever's documents (at most once per document list)
    before retrieval.
    """

    def __init__(
        self,
        retriever: "Retriever",
        reranker: "Reranker",
        evaluator: "Evaluator",
        chunk_size: int = 500,
        overlap: int = 100,
        chunking_strategy: str = "fixed"
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.evaluator = evaluator

        # Initialize chunker for preprocessing
        self.chunker = Chunker(chunk_size=chunk_size, overlap=overlap, strategy=chunking_strategy)
        # The list this pipeline last assigned to retriever.documents; lets
        # run() recognise already-chunked documents instead of re-chunking them.
        self._chunked_docs = None

    def preprocess_docs(self, documents):
        """Apply the selected chunking strategy to every document and flatten the result."""
        return [chunk for doc in documents for chunk in self.chunker.chunk_text(doc)]

    def run(self, query: str, top_k: int = 5, use_chunking: bool = True) -> Dict[str, Any]:
        """Retrieve, rerank and evaluate ``query``.

        Args:
            query: user query.
            top_k: number of documents requested from the retriever.
            use_chunking: chunk ``retriever.documents`` first, unless they are
                the list this pipeline already produced.

        Returns:
            dict with ``query``, ``answer`` (text of the top reranked doc, or ""),
            ``docs`` (reranked docs) and ``metrics`` (evaluator output).
        """
        documents = getattr(self.retriever, "documents", None)
        if use_chunking and documents and documents is not self._chunked_docs:
            # NOTE(review): only `documents` is replaced. core/retriever.py builds
            # embeddings / indices in its __init__ from the original documents and
            # they are not rebuilt here, so indices may point at pre-chunk
            # positions — confirm how callers expect this to interact.
            self._chunked_docs = self.preprocess_docs(documents)
            self.retriever.documents = self._chunked_docs

        # Retrieve documents
        retrieved_docs = self.retriever.retrieve(query, top_k=top_k)

        # Rerank
        reranked_docs = self.reranker.rerank(query, retrieved_docs)

        # Construct pseudo-answer
        answer = reranked_docs[0]["text"] if reranked_docs else ""
        context = "\n".join(d["text"] for d in reranked_docs)

        # Evaluate
        metrics = self.evaluator.evaluate(query, answer, context)

        return {
            "query": query,
            "answer": answer,
            "docs": reranked_docs,
            "metrics": metrics,
        }
@@ -0,0 +1,62 @@
1
+ from typing import List, Dict, Any
2
+ import numpy as np
3
+
4
+
5
class Reranker:
    """
    Supports:
    - MMR (Maximal Marginal Relevance) — used for any mode other than "crossencoder"
    - Dummy CrossEncoder (for demonstration): base score plus seeded noise in [0, 0.1)

    Randomness comes from a per-instance generator seeded with ``seed``; numpy's
    global random state is left untouched.
    """

    def __init__(self, mode: str = "mmr", lambda_param: float = 0.5, seed: int = 42):
        self.mode = mode
        self.lambda_param = lambda_param
        # Private generator: reseeding np.random globally would silently change
        # every other consumer of numpy's global RNG in the process.
        self._rng = np.random.default_rng(seed)

    def rerank(self, query: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Reorder ``docs`` (dicts with "text" and "score") according to ``mode``."""
        if not docs:
            return []

        if self.mode == "crossencoder":
            return self._crossencoder_rerank(query, docs)
        return self._mmr_rerank(query, docs)

    def _mmr_rerank(self, query: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Greedy MMR: trade off each doc's score against its max similarity to picks so far.

        The first pick is the highest-scoring doc; ties resolve to the earliest
        doc. Returns the same dict objects, reordered.
        """
        selected: List[Dict[str, Any]] = []
        remaining = list(docs)

        while remaining:
            if not selected:
                scores = [d["score"] for d in remaining]
            else:
                scores = [
                    self.lambda_param * d["score"]
                    - (1 - self.lambda_param)
                    * max(self._similarity(d["text"], s["text"]) for s in selected)
                    for d in remaining
                ]
            best_idx = max(range(len(scores)), key=scores.__getitem__)
            selected.append(remaining.pop(best_idx))

        return selected

    def _crossencoder_rerank(self, query: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Simulate cross-encoder reranking with small random score perturbations.

        Works on shallow copies so the caller's dicts keep their original scores.
        """
        perturbed = [
            {**d, "score": d["score"] + float(self._rng.uniform(0, 0.1))} for d in docs
        ]
        return sorted(perturbed, key=lambda d: d["score"], reverse=True)

    def _similarity(self, a: str, b: str) -> float:
        """Jaccard overlap of lowercase whitespace tokens, in [0, 1].

        Deterministic across processes and symmetric in its arguments (the
        built-in ``hash`` of str is salted per process).
        """
        tokens_a = set(a.lower().split())
        tokens_b = set(b.lower().split())
        union = tokens_a | tokens_b
        if not union:
            return 0.0
        return len(tokens_a & tokens_b) / len(union)
@@ -0,0 +1,165 @@
1
+ from typing import List, Dict, Any, Optional
2
+ import numpy as np
3
+ from .embeddings import Embeddings
4
+
5
+ # Optional imports
6
+ try:
7
+ import faiss
8
+ except ImportError:
9
+ faiss = None
10
+
11
+ try:
12
+ import chromadb
13
+ except ImportError:
14
+ chromadb = None
15
+
16
+ try:
17
+ from sklearn.neighbors import BallTree
18
+ except ImportError:
19
+ BallTree = None
20
+
21
+ try:
22
+ from rank_bm25 import BM25Okapi
23
+ except ImportError:
24
+ BM25Okapi = None
25
+
26
+
27
class Retriever:
    """
    Multi-backend retriever supporting:
    - "numpy" : basic cosine similarity (dense)
    - "faiss" : inner-product index over normalized vectors (dense)
    - "chroma" : in-process Chroma collection created via ``chromadb.Client()``
    - "sklearn": BallTree over normalized vectors
    - "bm25" : lexical retriever using Rank-BM25

    For numpy, faiss and sklearn the returned score is the cosine similarity of
    the normalized query and document vectors. For chroma it is ``1 - distance``
    using the collection's distance metric. For bm25 it is the raw BM25 score.
    ``retrieve`` never returns more results than there are indexed documents.

    Example:
        retriever = Retriever(embedder, documents=["A", "B", "C"], backend="bm25")
        results = retriever.retrieve("example query", top_k=3)
    """

    def __init__(
        self,
        embedder: Optional["Embeddings"] = None,
        documents: Optional[List[str]] = None,
        embeddings: Optional[np.ndarray] = None,
        backend: str = "numpy",
    ):
        self.embedder = embedder
        self.documents = documents or []
        self.backend = backend.lower()
        self.embeddings = None
        self.index = None
        self.client = None
        self.collection = None
        self.bm25 = None

        # Initialize embeddings for dense backends
        if self.backend != "bm25":
            if embeddings is not None:
                self.embeddings = np.array(embeddings)
            elif self.documents and self.embedder:
                self.embeddings = np.asarray(self.embedder.encode(self.documents))
            else:
                self.embeddings = np.zeros((0, getattr(self.embedder, "dim", 768)))

            # Normalize for cosine
            if self.embeddings.size > 0:
                self.embeddings = self._normalize(self.embeddings)

        # Initialize backend
        self._init_backend()

    # ------------------------
    # Backend Initialization
    # ------------------------
    def _init_backend(self):
        """Build the search structure for the selected backend.

        Empty corpora are allowed: structures that cannot be built from zero
        documents are skipped, because ``retrieve`` short-circuits in that case.

        Raises:
            ImportError: the backend's optional dependency is not installed.
            ValueError: the backend name is not recognised.
        """
        if self.backend == "faiss":
            if faiss is None:
                raise ImportError("faiss not installed. Run `pip install faiss-cpu`.")
            # Size the index from the stored vectors so it also works when
            # `embeddings` were supplied without an embedder.
            if self.embeddings.ndim == 2:
                dim = self.embeddings.shape[1]
            else:
                dim = getattr(self.embedder, "dim", 768)
            self.index = faiss.IndexFlatIP(dim)
            if self.embeddings.size > 0:
                self.index.add(np.ascontiguousarray(self.embeddings, dtype=np.float32))

        elif self.backend == "chroma":
            if chromadb is None:
                raise ImportError("chromadb not installed. Run `pip install chromadb`.")
            import uuid

            self.client = chromadb.Client()
            # Unique name per instance: in-process Chroma clients may share
            # state, so a fixed name can collide when another Retriever is built.
            self.collection = self.client.create_collection(
                name=f"ragmint_retriever_{uuid.uuid4().hex}"
            )
            batch = 1000  # bounded batches instead of one add() call per document
            total = len(self.documents)
            for start in range(0, total, batch):
                stop = min(start + batch, total)
                self.collection.add(
                    ids=[str(i) for i in range(start, stop)],
                    documents=list(self.documents[start:stop]),
                    embeddings=self.embeddings[start:stop].tolist(),
                )

        elif self.backend == "sklearn":
            if BallTree is None:
                raise ImportError("scikit-learn not installed. Run `pip install scikit-learn`.")
            if self.embeddings.shape[0] > 0:
                self.index = BallTree(self.embeddings)

        elif self.backend == "bm25":
            if BM25Okapi is None:
                raise ImportError("rank-bm25 not installed. Run `pip install rank-bm25`.")
            if self.documents:
                # BM25Okapi cannot be built from an empty corpus.
                tokenized_corpus = [doc.lower().split() for doc in self.documents]
                self.bm25 = BM25Okapi(tokenized_corpus)

        elif self.backend != "numpy":
            raise ValueError(f"Unsupported retriever backend: {self.backend}")

    # ------------------------
    # Retrieval
    # ------------------------
    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` ``{"text", "score"}`` dicts, best first.

        With no documents (or no dense vectors for a dense backend) a single
        placeholder ``{"text": "", "score": 0.0}`` is returned. ``top_k <= 0``
        yields an empty list.
        """
        if len(self.documents) == 0:
            return [{"text": "", "score": 0.0}]

        # BM25 retrieval (lexical)
        if self.backend == "bm25":
            k = min(top_k, len(self.documents))
            if k <= 0:
                return []
            scores = self.bm25.get_scores(query.lower().split())
            top_indices = np.argsort(scores)[::-1][:k]
            return [
                {"text": self.documents[int(i)], "score": float(scores[i])}
                for i in top_indices
            ]

        # Dense retrieval (others)
        if self.embeddings is None or self.embeddings.size == 0:
            return [{"text": "", "score": 0.0}]

        # Never ask an index for more neighbours than it holds: faiss pads with
        # -1 labels, BallTree raises.
        k = min(top_k, len(self.documents), self.embeddings.shape[0])
        if k <= 0:
            return []

        query_vec = self._normalize(np.asarray(self.embedder.encode([query])[0]))

        if self.backend == "numpy":
            scores = self.embeddings @ query_vec
            top_indices = np.argsort(scores)[::-1][:k]
            return [
                {"text": self.documents[int(i)], "score": float(scores[i])}
                for i in top_indices
            ]

        elif self.backend == "faiss":
            query_mat = np.ascontiguousarray(query_vec, dtype=np.float32).reshape(1, -1)
            scores, indices = self.index.search(query_mat, k)
            return [
                {"text": self.documents[int(i)], "score": float(s)}
                for s, i in zip(scores[0], indices[0])
                if i >= 0  # -1 marks "no result"
            ]

        elif self.backend == "chroma":
            # Query with our own vector; query_texts would embed the query with
            # Chroma's default embedding function instead of self.embedder.
            results = self.collection.query(
                query_embeddings=[query_vec.tolist()], n_results=k
            )
            docs = results["documents"][0]
            distances = results["distances"][0]
            return [{"text": d, "score": 1 - s} for d, s in zip(docs, distances)]

        elif self.backend == "sklearn":
            distances, indices = self.index.query(query_vec.reshape(1, -1), k=k)
            # For unit vectors ||u - v||^2 = 2 - 2*cos, so this recovers cosine,
            # matching the numpy/faiss scores.
            scores = 1.0 - distances[0] ** 2 / 2.0
            return [
                {"text": self.documents[int(i)], "score": float(scores[j])}
                for j, i in enumerate(indices[0])
            ]

        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    # ------------------------
    # Utils
    # ------------------------
    @staticmethod
    def _normalize(vectors: np.ndarray) -> np.ndarray:
        """L2-normalize a vector or each row of a matrix; zero rows stay zero."""
        vectors = np.asarray(vectors)
        if not np.issubdtype(vectors.dtype, np.floating):
            # np.divide(..., out=zeros_like(int_array)) cannot hold float results.
            vectors = vectors.astype(np.float64)
        if vectors.ndim == 1:
            norm = np.linalg.norm(vectors)
            return vectors / norm if norm > 0 else vectors
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return np.divide(vectors, norms, out=np.zeros_like(vectors), where=norms != 0)
File without changes
@@ -0,0 +1,14 @@
1
+ [
2
+ {
3
+ "query": "What is Retrieval-Augmented Generation?",
4
+ "expected_answer": "A technique that combines information retrieval with language generation to improve factual accuracy."
5
+ },
6
+ {
7
+ "query": "What is the role of embeddings in a RAG system?",
8
+ "expected_answer": "They represent text as numerical vectors for similarity-based retrieval."
9
+ },
10
+ {
11
+ "query": "What is Maximal Marginal Relevance used for?",
12
+ "expected_answer": "To select diverse and relevant documents during reranking."
13
+ }
14
+ ]