ragmint 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ragmint might be problematic.
- ragmint/__init__.py +0 -0
- ragmint/__main__.py +28 -0
- ragmint/autotuner.py +138 -0
- ragmint/core/__init__.py +0 -0
- ragmint/core/chunking.py +86 -0
- ragmint/core/embeddings.py +55 -0
- ragmint/core/evaluation.py +38 -0
- ragmint/core/pipeline.py +62 -0
- ragmint/core/reranker.py +62 -0
- ragmint/core/retriever.py +165 -0
- ragmint/experiments/__init__.py +0 -0
- ragmint/experiments/validation_qa.json +14 -0
- ragmint/explainer.py +63 -0
- ragmint/integrations/__init__.py +0 -0
- ragmint/integrations/config_adapter.py +96 -0
- ragmint/integrations/langchain_prebuilder.py +99 -0
- ragmint/leaderboard.py +45 -0
- ragmint/optimization/__init__.py +0 -0
- ragmint/optimization/search.py +48 -0
- ragmint/tests/__init__.py +0 -0
- ragmint/tests/conftest.py +16 -0
- ragmint/tests/test_autotuner.py +51 -0
- ragmint/tests/test_config_adapter.py +39 -0
- ragmint/tests/test_embeddings.py +46 -0
- ragmint/tests/test_explainer.py +20 -0
- ragmint/tests/test_explainer_integration.py +18 -0
- ragmint/tests/test_integration_autotuner_ragmint.py +47 -0
- ragmint/tests/test_langchain_prebuilder.py +82 -0
- ragmint/tests/test_leaderboard.py +39 -0
- ragmint/tests/test_pipeline.py +20 -0
- ragmint/tests/test_retriever.py +15 -0
- ragmint/tests/test_search.py +17 -0
- ragmint/tests/test_tuner.py +71 -0
- ragmint/tuner.py +189 -0
- ragmint/utils/__init__.py +0 -0
- ragmint/utils/caching.py +37 -0
- ragmint/utils/data_loader.py +65 -0
- ragmint/utils/logger.py +36 -0
- ragmint/utils/metrics.py +27 -0
- ragmint-0.3.1.data/data/LICENSE +19 -0
- ragmint-0.3.1.data/data/README.md +397 -0
- ragmint-0.3.1.dist-info/METADATA +441 -0
- ragmint-0.3.1.dist-info/RECORD +46 -0
- ragmint-0.3.1.dist-info/WHEEL +5 -0
- ragmint-0.3.1.dist-info/licenses/LICENSE +19 -0
- ragmint-0.3.1.dist-info/top_level.txt +1 -0
ragmint/__init__.py
ADDED
File without changes
ragmint/__main__.py
ADDED
@@ -0,0 +1,28 @@
from pathlib import Path
from ragmint.tuner import RAGMint


def main():
    # Dynamically resolve the path to the installed ragmint package
    base_dir = Path(__file__).resolve().parent

    docs_path = base_dir / "experiments" / "corpus"
    validation_file = base_dir / "experiments" / "validation_qa.json"

    rag = RAGMint(
        docs_path=str(docs_path),
        retrievers=["faiss"],
        embeddings=["openai/text-embedding-3-small"],
        rerankers=["mmr"],
    )

    best, results = rag.optimize(
        validation_set=str(validation_file),
        metric="faithfulness",
        search_type="bayesian",
        trials=10,
    )

    print("Best config found:\n", best)


if __name__ == "__main__":
    main()
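Since this module is the package's __main__, installing the wheel would presumably let the example above be run directly with `python -m ragmint`; note that it also expects an experiments/corpus directory inside the installed package and access to the referenced OpenAI embedding model.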
ragmint/autotuner.py
ADDED
@@ -0,0 +1,138 @@
"""
Auto-RAG Tuner
--------------
Automatically recommends and optimizes RAG configurations based on corpus statistics.
Integrates with RAGMint to perform full end-to-end tuning.
"""

import os
import logging
from statistics import mean
from typing import Dict, Any, Tuple, List

from .tuner import RAGMint
from .core.evaluation import evaluate_config

logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


class AutoRAGTuner:
    def __init__(self, docs_path: str):
        """
        AutoRAGTuner automatically analyzes a corpus and runs an optimized RAG tuning pipeline.

        Args:
            docs_path (str): Path to the directory containing documents (.txt, .md, .rst)
        """
        self.docs_path = docs_path
        self.corpus_stats = self._analyze_corpus()

    # -----------------------------
    # Corpus Analysis
    # -----------------------------
    def _analyze_corpus(self) -> Dict[str, Any]:
        """Compute corpus size, average length, and number of documents."""
        docs = []
        total_chars = 0
        num_docs = 0

        if not os.path.exists(self.docs_path):
            logging.warning(f"⚠️ Corpus path not found: {self.docs_path}")
            return {"size": 0, "avg_len": 0, "num_docs": 0}

        for file in os.listdir(self.docs_path):
            if file.endswith((".txt", ".md", ".rst")):
                with open(os.path.join(self.docs_path, file), "r", encoding="utf-8") as f:
                    content = f.read()
                    docs.append(content)
                    total_chars += len(content)
                    num_docs += 1

        avg_len = int(mean([len(d) for d in docs])) if docs else 0
        stats = {"size": total_chars, "avg_len": avg_len, "num_docs": num_docs}
        logging.info(f"📊 Corpus stats: {stats}")
        return stats

    # -----------------------------
    # Recommendation Logic
    # -----------------------------
    def recommend(self) -> Dict[str, Any]:
        """Recommend retriever, embedding, and chunking based on corpus stats."""
        size = self.corpus_stats.get("size", 0)
        avg_len = self.corpus_stats.get("avg_len", 0)
        num_docs = self.corpus_stats.get("num_docs", 0)

        # Heuristic-based tuning
        # Determine chunking heuristics first
        if avg_len < 200:
            chunk_size, overlap = 300, 50
        elif avg_len < 500:
            chunk_size, overlap = 500, 100
        else:
            chunk_size, overlap = 800, 150

        # Determine retriever–embedding based on corpus size
        if size <= 2000:
            retriever = "BM25"
            embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
        elif size <= 10000:
            retriever = "Chroma"
            embedding_model = "sentence-transformers/paraphrase-MiniLM-L6-v2"
        else:
            retriever = "FAISS"
            embedding_model = "sentence-transformers/all-mpnet-base-v2"

        strategy = "fixed" if avg_len < 400 else "sentence"

        recommendation = {
            "retriever": retriever,
            "embedding_model": embedding_model,
            "chunk_size": chunk_size,
            "overlap": overlap,
            "strategy": strategy,
        }

        logging.info(f"🔮 AutoRAG Recommendation: {recommendation}")
        return recommendation

    # -----------------------------
    # Full Auto-Tuning
    # -----------------------------
    def auto_tune(
        self,
        validation_set: str = None,
        metric: str = "faithfulness",
        trials: int = 5,
        search_type: str = "random",
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Run a full automatic optimization using RAGMint.

        Automatically:
        - Recommends initial config (retriever, embedding, chunking)
        - Launches RAGMint optimization trials
        - Returns best configuration and results
        """
        rec = self.recommend()

        logging.info("🚀 Launching full AutoRAG optimization with RAGMint")

        tuner = RAGMint(
            docs_path=self.docs_path,
            retrievers=[rec["retriever"]],
            embeddings=[rec["embedding_model"]],
            rerankers=["mmr"],
            chunk_sizes=[rec["chunk_size"]],
            overlaps=[rec["overlap"]],
            strategies=[rec["strategy"]],
        )

        best, results = tuner.optimize(
            validation_set=validation_set,
            metric=metric,
            trials=trials,
            search_type=search_type,
        )

        logging.info(f"🏁 AutoRAG tuning complete. Best: {best}")
        return best, results
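A minimal usage sketch for the class above (illustrative only, not shipped in the wheel; the corpus and validation paths are placeholders):

# Illustrative sketch: exercising AutoRAGTuner as defined above.
from ragmint.autotuner import AutoRAGTuner

tuner = AutoRAGTuner("data/corpus")              # placeholder directory of .txt/.md/.rst files
print(tuner.recommend())                         # heuristic retriever/embedding/chunking choice
best, results = tuner.auto_tune(
    validation_set="data/validation_qa.json",    # placeholder QA file, same shape as the bundled one
    metric="faithfulness",
    trials=5,
    search_type="random",
)
print(best)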
ragmint/core/__init__.py
ADDED
File without changes
ragmint/core/chunking.py
ADDED
@@ -0,0 +1,86 @@
from typing import List
import re

try:
    import tiktoken
except ImportError:
    tiktoken = None

try:
    import nltk
    nltk.download("punkt", quiet=True)
    from nltk.tokenize import sent_tokenize
except ImportError:
    sent_tokenize = None


class Chunker:
    """
    Handles text chunking strategies:
    - fixed: character-based
    - token: token-based (requires tiktoken)
    - sentence: splits by full sentences (requires nltk)
    """

    def __init__(self, chunk_size: int = 500, overlap: int = 100, strategy: str = "fixed"):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.strategy = strategy

    def chunk_text(self, text: str) -> List[str]:
        """Dispatches to the correct chunking strategy."""
        if self.strategy == "token" and tiktoken:
            return self._chunk_by_tokens(text)
        elif self.strategy == "sentence" and sent_tokenize:
            return self._chunk_by_sentences(text)
        else:
            return self._chunk_fixed(text)

    # -------------------------------
    # Fixed-length (default)
    # -------------------------------
    def _chunk_fixed(self, text: str) -> List[str]:
        chunks = []
        start = 0
        while start < len(text):
            end = start + self.chunk_size
            chunks.append(text[start:end])
            start += self.chunk_size - self.overlap
        return chunks

    # -------------------------------
    # Token-based (for LLM embedding)
    # -------------------------------
    def _chunk_by_tokens(self, text: str) -> List[str]:
        if not tiktoken:
            raise ImportError("tiktoken is required for token-based chunking.")
        enc = tiktoken.get_encoding("cl100k_base")
        tokens = enc.encode(text)

        chunks = []
        for i in range(0, len(tokens), self.chunk_size - self.overlap):
            chunk_tokens = tokens[i:i + self.chunk_size]
            chunks.append(enc.decode(chunk_tokens))
        return chunks

    # -------------------------------
    # Sentence-based
    # -------------------------------
    def _chunk_by_sentences(self, text: str) -> List[str]:
        if not sent_tokenize:
            raise ImportError("nltk is required for sentence-based chunking.")
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= self.chunk_size:
                current_chunk += " " + sentence
            else:
                chunks.append(current_chunk.strip())
                current_chunk = sentence

        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks
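As a quick illustration of the fixed strategy above (not part of the package): each window is chunk_size characters long and the start position advances by chunk_size - overlap, so a 50-character text with chunk_size=20 and overlap=5 yields windows starting at 0, 15, 30 and 45.

# Illustrative sketch: fixed-window chunking with the Chunker defined above.
from ragmint.core.chunking import Chunker

chunker = Chunker(chunk_size=20, overlap=5, strategy="fixed")
chunks = chunker.chunk_text("x" * 50)
print([len(c) for c in chunks])  # [20, 20, 20, 5]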
ragmint/core/embeddings.py
ADDED
@@ -0,0 +1,55 @@
import numpy as np
from dotenv import load_dotenv

try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    SentenceTransformer = None


class Embeddings:
    """
    Wrapper for embedding backends: HuggingFace (SentenceTransformers) or Dummy.

    Example:
        model = Embeddings("huggingface", model_name="all-MiniLM-L6-v2")
        embeddings = model.encode(["example text"])
    """

    def __init__(self, backend: str = "huggingface", model_name: str = None):
        load_dotenv()
        self.backend = backend.lower()
        self.model_name = model_name or "all-MiniLM-L6-v2"

        if self.backend == "huggingface":
            if SentenceTransformer is None:
                raise ImportError("Please install `sentence-transformers` to use HuggingFace embeddings.")
            self.model = SentenceTransformer(self.model_name)
            self.dim = self.model.get_sentence_embedding_dimension()

        elif self.backend == "dummy":
            self.model = None
            self.dim = 768  # Default embedding dimension for dummy backend

        else:
            raise ValueError(f"Unsupported embedding backend: {backend}")

    def encode(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        if self.backend == "huggingface":
            embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=False)

        elif self.backend == "dummy":
            # Return a NumPy array of shape (len(texts), dim)
            embeddings = np.random.rand(len(texts), self.dim).astype(np.float32)

        else:
            raise ValueError(f"Unknown embedding backend: {self.backend}")

        # ✅ Always ensure NumPy array output
        if not isinstance(embeddings, np.ndarray):
            embeddings = np.array(embeddings, dtype=np.float32)

        return embeddings
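The "dummy" backend above needs no model download, which makes a quick smoke test possible (illustrative, not part of the package):

# Illustrative sketch: dummy embeddings return random float32 vectors of shape (n, 768).
import numpy as np
from ragmint.core.embeddings import Embeddings

emb = Embeddings(backend="dummy")
vectors = emb.encode(["hello world", "retrieval augmented generation"])
print(type(vectors), vectors.shape)  # <class 'numpy.ndarray'> (2, 768)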
ragmint/core/evaluation.py
ADDED
@@ -0,0 +1,38 @@
import time
from typing import Dict, Any
from difflib import SequenceMatcher


class Evaluator:
    """
    Simple evaluation of generated answers:
    - Faithfulness (similarity between answer and context)
    - Latency
    """

    def __init__(self):
        pass

    def evaluate(self, query: str, answer: str, context: str) -> Dict[str, Any]:
        start = time.time()
        faithfulness = self._similarity(answer, context)
        latency = time.time() - start

        return {
            "faithfulness": faithfulness,
            "latency": latency,
        }

    def _similarity(self, a: str, b: str) -> float:
        return SequenceMatcher(None, a, b).ratio()


def evaluate_config(config, validation_data):
    evaluator = Evaluator()
    results = []
    for sample in validation_data:
        query = sample.get("query", "")
        answer = sample.get("answer", "")
        context = sample.get("context", "")
        results.append(evaluator.evaluate(query, answer, context))
    return results
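Since faithfulness here is just a SequenceMatcher ratio between answer and context, a small check looks like this (illustrative, not part of the package):

# Illustrative sketch: Evaluator scores the overlap between answer and context.
from ragmint.core.evaluation import Evaluator

metrics = Evaluator().evaluate(
    query="What is RAG?",
    answer="RAG combines retrieval with generation.",
    context="RAG combines retrieval with generation to ground answers in documents.",
)
print(metrics)  # {'faithfulness': <0..1 ratio>, 'latency': <seconds>}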
ragmint/core/pipeline.py
ADDED
@@ -0,0 +1,62 @@
from typing import Any, Dict, Optional
from .retriever import Retriever
from .reranker import Reranker
from .evaluation import Evaluator
from .chunking import Chunker


class RAGPipeline:
    """
    Core Retrieval-Augmented Generation pipeline.
    Retrieves, reranks, and evaluates a query given the configured backends.
    Supports text chunking for optimal retrieval performance.
    """

    def __init__(
        self,
        retriever: Retriever,
        reranker: Reranker,
        evaluator: Evaluator,
        chunk_size: int = 500,
        overlap: int = 100,
        chunking_strategy: str = "fixed"
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.evaluator = evaluator

        # Initialize chunker for preprocessing
        self.chunker = Chunker(chunk_size=chunk_size, overlap=overlap, strategy=chunking_strategy)

    def preprocess_docs(self, documents):
        """Applies the selected chunking strategy to the document set."""
        all_chunks = []
        for doc in documents:
            chunks = self.chunker.chunk_text(doc)
            all_chunks.extend(chunks)
        return all_chunks

    def run(self, query: str, top_k: int = 5, use_chunking: bool = True) -> Dict[str, Any]:
        # Optional preprocessing step
        if use_chunking and hasattr(self.retriever, "documents") and self.retriever.documents:
            self.retriever.documents = self.preprocess_docs(self.retriever.documents)

        # Retrieve documents
        retrieved_docs = self.retriever.retrieve(query, top_k=top_k)

        # Rerank
        reranked_docs = self.reranker.rerank(query, retrieved_docs)

        # Construct pseudo-answer
        answer = reranked_docs[0]["text"] if reranked_docs else ""
        context = "\n".join([d["text"] for d in reranked_docs])

        # Evaluate
        metrics = self.evaluator.evaluate(query, answer, context)

        return {
            "query": query,
            "answer": answer,
            "docs": reranked_docs,
            "metrics": metrics,
        }
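Wiring the pieces together with the dependency-free backends gives an end-to-end run without any external services (illustrative, not part of the package; with documents shorter than chunk_size each one remains a single chunk, so the precomputed embeddings stay aligned with the chunked documents):

# Illustrative sketch: RAGPipeline over dummy embeddings and the numpy retriever.
from ragmint.core.embeddings import Embeddings
from ragmint.core.retriever import Retriever
from ragmint.core.reranker import Reranker
from ragmint.core.evaluation import Evaluator
from ragmint.core.pipeline import RAGPipeline

docs = [
    "RAG combines retrieval with generation.",
    "Embeddings map text to numerical vectors.",
    "MMR balances relevance and diversity.",
]
embedder = Embeddings(backend="dummy")
retriever = Retriever(embedder=embedder, documents=docs, backend="numpy")
pipeline = RAGPipeline(retriever, Reranker(mode="mmr"), Evaluator(), chunk_size=200, overlap=20)

result = pipeline.run("What does MMR do?", top_k=2)
print(result["answer"])
print(result["metrics"])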
ragmint/core/reranker.py
ADDED
@@ -0,0 +1,62 @@
from typing import List, Dict, Any
import numpy as np


class Reranker:
    """
    Supports:
    - MMR (Maximal Marginal Relevance)
    - Dummy CrossEncoder (for demonstration)
    """

    def __init__(self, mode: str = "mmr", lambda_param: float = 0.5, seed: int = 42):
        self.mode = mode
        self.lambda_param = lambda_param
        np.random.seed(seed)

    def rerank(self, query: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        if not docs:
            return []

        if self.mode == "crossencoder":
            return self._crossencoder_rerank(query, docs)
        return self._mmr_rerank(query, docs)

    def _mmr_rerank(self, query: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Perform MMR reranking using dummy similarity scores."""
        selected = []
        remaining = docs.copy()

        while remaining and len(selected) < len(docs):
            if not selected:
                # pick doc with highest base score
                best = max(remaining, key=lambda d: d["score"])
            else:
                # MMR balancing between relevance and diversity
                mmr_scores = []
                for d in remaining:
                    max_div = max(
                        [self._similarity(d["text"], s["text"]) for s in selected],
                        default=0,
                    )
                    mmr_score = (
                        self.lambda_param * d["score"]
                        - (1 - self.lambda_param) * max_div
                    )
                    mmr_scores.append(mmr_score)
                best = remaining[int(np.argmax(mmr_scores))]
            selected.append(best)
            remaining.remove(best)

        return selected

    def _crossencoder_rerank(self, query: str, docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Adds a small random perturbation to simulate crossencoder reranking."""
        for d in docs:
            d["score"] += np.random.uniform(0, 0.1)
        return sorted(docs, key=lambda d: d["score"], reverse=True)

    def _similarity(self, a: str, b: str) -> float:
        """Dummy similarity function between two strings."""
        # Deterministic pseudo-similarity based on hash
        return abs(hash(a + b)) % 100 / 100.0
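Both rerank modes operate on the same doc dicts the retriever returns, so the reranker can be tried in isolation (illustrative, not part of the package):

# Illustrative sketch: MMR reranking over pre-scored documents.
from ragmint.core.reranker import Reranker

docs = [
    {"text": "RAG overview", "score": 0.90},
    {"text": "RAG overview, slightly reworded", "score": 0.88},
    {"text": "Chunking strategies", "score": 0.70},
]
ordered = Reranker(mode="mmr", lambda_param=0.5).rerank("What is RAG?", docs)
print([d["text"] for d in ordered])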
ragmint/core/retriever.py
ADDED
@@ -0,0 +1,165 @@
from typing import List, Dict, Any, Optional
import numpy as np
from .embeddings import Embeddings

# Optional imports
try:
    import faiss
except ImportError:
    faiss = None

try:
    import chromadb
except ImportError:
    chromadb = None

try:
    from sklearn.neighbors import BallTree
except ImportError:
    BallTree = None

try:
    from rank_bm25 import BM25Okapi
except ImportError:
    BM25Okapi = None


class Retriever:
    """
    Multi-backend retriever supporting:
    - "numpy" : basic cosine similarity (dense)
    - "faiss" : high-performance dense retriever
    - "chroma" : persistent vector DB
    - "sklearn": BallTree (cosine or Euclidean)
    - "bm25" : lexical retriever using Rank-BM25

    Example:
        retriever = Retriever(embedder, documents=["A", "B", "C"], backend="bm25")
        results = retriever.retrieve("example query", top_k=3)
    """

    def __init__(
        self,
        embedder: Optional[Embeddings] = None,
        documents: Optional[List[str]] = None,
        embeddings: Optional[np.ndarray] = None,
        backend: str = "numpy",
    ):
        self.embedder = embedder
        self.documents = documents or []
        self.backend = backend.lower()
        self.embeddings = None
        self.index = None
        self.client = None
        self.bm25 = None

        # Initialize embeddings for dense backends
        if self.backend not in ["bm25"]:
            if embeddings is not None:
                self.embeddings = np.array(embeddings)
            elif self.documents and self.embedder:
                self.embeddings = self.embedder.encode(self.documents)
            else:
                self.embeddings = np.zeros((0, getattr(self.embedder, "dim", 768)))

            # Normalize for cosine
            if self.embeddings.size > 0:
                self.embeddings = self._normalize(self.embeddings)

        # Initialize backend
        self._init_backend()

    # ------------------------
    # Backend Initialization
    # ------------------------
    def _init_backend(self):
        if self.backend == "faiss":
            if faiss is None:
                raise ImportError("faiss not installed. Run `pip install faiss-cpu`.")
            self.index = faiss.IndexFlatIP(self.embedder.dim)
            self.index.add(self.embeddings.astype("float32"))

        elif self.backend == "chroma":
            if chromadb is None:
                raise ImportError("chromadb not installed. Run `pip install chromadb`.")
            self.client = chromadb.Client()
            self.collection = self.client.create_collection(name="ragmint_retriever")
            for i, doc in enumerate(self.documents):
                self.collection.add(
                    ids=[str(i)],
                    documents=[doc],
                    embeddings=[self.embeddings[i].tolist()],
                )

        elif self.backend == "sklearn":
            if BallTree is None:
                raise ImportError("scikit-learn not installed. Run `pip install scikit-learn`.")
            self.index = BallTree(self.embeddings)

        elif self.backend == "bm25":
            if BM25Okapi is None:
                raise ImportError("rank-bm25 not installed. Run `pip install rank-bm25`.")
            tokenized_corpus = [doc.lower().split() for doc in self.documents]
            self.bm25 = BM25Okapi(tokenized_corpus)

        elif self.backend != "numpy":
            raise ValueError(f"Unsupported retriever backend: {self.backend}")

    # ------------------------
    # Retrieval
    # ------------------------
    def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        if len(self.documents) == 0:
            return [{"text": "", "score": 0.0}]

        # BM25 retrieval (lexical)
        if self.backend == "bm25":
            tokenized_query = query.lower().split()
            scores = self.bm25.get_scores(tokenized_query)
            top_indices = np.argsort(scores)[::-1][:top_k]
            return [
                {"text": self.documents[i], "score": float(scores[i])}
                for i in top_indices
            ]

        # Dense retrieval (others)
        if self.embeddings is None or self.embeddings.size == 0:
            return [{"text": "", "score": 0.0}]

        query_vec = self.embedder.encode([query])[0]
        query_vec = self._normalize(query_vec)

        if self.backend == "numpy":
            scores = np.dot(self.embeddings, query_vec)
            top_indices = np.argsort(scores)[::-1][:top_k]
            return [{"text": self.documents[i], "score": float(scores[i])} for i in top_indices]

        elif self.backend == "faiss":
            query_vec = np.expand_dims(query_vec.astype("float32"), axis=0)
            scores, indices = self.index.search(query_vec, top_k)
            return [{"text": self.documents[int(i)], "score": float(scores[0][j])} for j, i in enumerate(indices[0])]

        elif self.backend == "chroma":
            results = self.collection.query(query_texts=[query], n_results=top_k)
            docs = results["documents"][0]
            scores = results["distances"][0]
            return [{"text": d, "score": 1 - s} for d, s in zip(docs, scores)]

        elif self.backend == "sklearn":
            distances, indices = self.index.query([query_vec], k=top_k)
            scores = 1 - distances[0]
            return [{"text": self.documents[int(i)], "score": float(scores[j])} for j, i in enumerate(indices[0])]

        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    # ------------------------
    # Utils
    # ------------------------
    @staticmethod
    def _normalize(vectors: np.ndarray) -> np.ndarray:
        if vectors.ndim == 1:
            norm = np.linalg.norm(vectors)
            return vectors / norm if norm > 0 else vectors
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return np.divide(vectors, norms, out=np.zeros_like(vectors), where=norms != 0)
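Following the class docstring above, the lexical BM25 backend needs no embedder at all (illustrative, not part of the package; assumes rank-bm25 is installed):

# Illustrative sketch: BM25 retrieval without an embedding model.
from ragmint.core.retriever import Retriever

docs = [
    "the cat sat on the mat",
    "dogs chase cats around the yard",
    "retrieval augmented generation pipelines",
]
retriever = Retriever(documents=docs, backend="bm25")
print(retriever.retrieve("cat", top_k=2))  # list of {"text": ..., "score": ...} dicts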
ragmint/experiments/__init__.py
ADDED
File without changes
ragmint/experiments/validation_qa.json
ADDED
@@ -0,0 +1,14 @@
[
    {
        "query": "What is Retrieval-Augmented Generation?",
        "expected_answer": "A technique that combines information retrieval with language generation to improve factual accuracy."
    },
    {
        "query": "What is the role of embeddings in a RAG system?",
        "expected_answer": "They represent text as numerical vectors for similarity-based retrieval."
    },
    {
        "query": "What is Maximal Marginal Relevance used for?",
        "expected_answer": "To select diverse and relevant documents during reranking."
    }
]
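This bundled file is the default validation set passed to optimize()/auto_tune() via the validation_set argument; since it ships inside the wheel as package data, it can be inspected directly (illustrative, not part of the package):

# Illustrative sketch: reading the bundled QA pairs from the installed wheel.
import json
from importlib.resources import files

qa = json.loads(files("ragmint.experiments").joinpath("validation_qa.json").read_text(encoding="utf-8"))
print(qa[0]["query"], "->", qa[0]["expected_answer"])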