ragmint 0.2.3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragmint/app.py +512 -0
- ragmint/autotuner.py +201 -17
- ragmint/core/chunking.py +68 -4
- ragmint/core/embeddings.py +46 -10
- ragmint/core/evaluation.py +33 -14
- ragmint/core/pipeline.py +34 -10
- ragmint/core/retriever.py +152 -20
- ragmint/experiments/validation_qa.json +1 -14
- ragmint/explainer.py +47 -20
- ragmint/integrations/__init__.py +0 -0
- ragmint/integrations/config_adapter.py +96 -0
- ragmint/integrations/langchain_prebuilder.py +99 -0
- ragmint/leaderboard.py +41 -35
- ragmint/qa_generator.py +190 -0
- ragmint/tests/test_autotuner.py +52 -30
- ragmint/tests/test_config_adapter.py +39 -0
- ragmint/tests/test_embeddings.py +46 -0
- ragmint/tests/test_explainer.py +28 -12
- ragmint/tests/test_integration_autotuner_ragmint.py +39 -52
- ragmint/tests/test_langchain_prebuilder.py +82 -0
- ragmint/tests/test_leaderboard.py +78 -25
- ragmint/tests/test_pipeline.py +3 -2
- ragmint/tests/test_qa_generator.py +66 -0
- ragmint/tests/test_retriever.py +3 -2
- ragmint/tests/test_tuner.py +1 -1
- ragmint/tuner.py +109 -22
- ragmint-0.4.6.data/data/README.md +485 -0
- ragmint-0.4.6.dist-info/METADATA +530 -0
- ragmint-0.4.6.dist-info/RECORD +48 -0
- ragmint/tests/test_explainer_integration.py +0 -18
- ragmint-0.2.3.data/data/README.md +0 -284
- ragmint-0.2.3.dist-info/METADATA +0 -312
- ragmint-0.2.3.dist-info/RECORD +0 -40
- {ragmint-0.2.3.data → ragmint-0.4.6.data}/data/LICENSE +0 -0
- {ragmint-0.2.3.dist-info → ragmint-0.4.6.dist-info}/WHEEL +0 -0
- {ragmint-0.2.3.dist-info → ragmint-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {ragmint-0.2.3.dist-info → ragmint-0.4.6.dist-info}/top_level.txt +0 -0
ragmint/qa_generator.py
ADDED
@@ -0,0 +1,190 @@
+"""
+Batched Validation QA Generator for Ragmint (Functional Version)
+
+Generates a JSON QA dataset from a large corpus using an LLM.
+Processes documents in batches to avoid token limits and API errors.
+Uses topic-aware dynamic question count estimation.
+"""
+
+import os
+import re
+import json
+import math
+import time
+import argparse
+from pathlib import Path
+from dotenv import load_dotenv
+import numpy as np
+from sklearn.cluster import KMeans
+from sentence_transformers import SentenceTransformer
+
+# --- Load .env if available ---
+load_dotenv()
+
+
+# ---------- Utility functions ----------
+
+def extract_json_from_markdown(text: str):
+    """Extract JSON from a markdown-style code block."""
+    match = re.search(r"```(?:json)?\s*(\[\s*[\s\S]*?\s*\])\s*```", text, re.MULTILINE)
+    if match:
+        json_str = match.group(1)
+        return json.loads(json_str)
+    else:
+        return json.loads(text.strip())
+
+
+def read_corpus(docs_path: str):
+    """Load all text documents from a folder."""
+    docs = []
+    for file in Path(docs_path).glob("**/*.txt"):
+        with open(file, "r", encoding="utf-8") as f:
+            text = f.read().strip()
+        if text:
+            docs.append({"filename": file.name, "text": text})
+    return docs
+
+
+def determine_question_count(text: str, embedder, min_q=3, max_q=25):
+    """Estimate number of questions dynamically based on text length and topic diversity."""
+    sentences = [s.strip() for s in text.split('.') if len(s.strip().split()) > 3]
+    word_count = len(text.split())
+
+    if word_count == 0:
+        return min_q
+
+    base_q = math.log1p(word_count / 150)
+
+    # Topic diversity via clustering
+    n_sent = len(sentences)
+    if n_sent < 5:
+        topic_factor = 1.0
+    else:
+        try:
+            emb = embedder.encode(sentences, normalize_embeddings=True)
+            n_clusters = min(max(2, n_sent // 10), 8)
+            km = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
+            labels = km.fit_predict(emb)
+            topic_factor = len(set(labels)) / n_clusters
+        except Exception as e:
+            print(f"[WARN] Clustering failed ({type(e).__name__}): {e}")
+            topic_factor = 1.0
+
+    score = base_q * (1 + 0.8 * topic_factor)
+    question_count = round(min_q + score)
+    return int(max(min_q, min(question_count, max_q)))
+
+
+def setup_llm(llm_model="gemini-2.5-flash-lite"):
+    """Configure Gemini or Claude based on available environment keys."""
+    google_key = os.getenv("GOOGLE_API_KEY")
+    anthropic_key = os.getenv("ANTHROPIC_API_KEY")
+
+    if google_key:
+        import google.generativeai as genai
+        genai.configure(api_key=google_key)
+        llm = genai.GenerativeModel(llm_model)
+        return llm, "gemini"
+
+    elif anthropic_key:
+        from anthropic import Anthropic
+        llm = Anthropic(api_key=anthropic_key)
+        return llm, "claude"
+
+    else:
+        raise ValueError("Set ANTHROPIC_API_KEY or GOOGLE_API_KEY in your environment.")
+
+
+def generate_qa_for_batch(batch, llm, backend, embedder, min_q=3, max_q=25):
+    """Send one LLM call for a batch of documents."""
+    prompt_texts = []
+    for doc in batch:
+        n_questions = determine_question_count(doc["text"], embedder, min_q, max_q)
+        prompt_texts.append(
+            f"Document: {doc['text'][:1000]}\n"
+            f"Generate {n_questions} factual question-answer pairs in JSON format."
+        )
+
+    prompt = "\n\n".join(prompt_texts)
+    prompt += "\n\nReturn a single JSON array of objects like:\n" \
+              '[{"query": "string", "expected_answer": "string"}]'
+
+    try:
+        if backend == "gemini":
+            response = llm.generate_content(prompt)
+            text_out = getattr(response, "text", None)
+            if not text_out and hasattr(response, "candidates"):
+                text_out = response.candidates[0].content.parts[0].text
+            return extract_json_from_markdown(text_out)
+
+        elif backend == "claude":
+            response = llm.messages.create(
+                model="claude-3-opus-20240229",
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=2000,
+            )
+            return json.loads(response.content[0].text)
+
+    except Exception as e:
+        print(f"[WARN] Failed to parse batch: {e}")
+        return []
+
+
+def save_json(output_path, data):
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(data, f, indent=2, ensure_ascii=False)
+    print(f"[INFO] Saved {len(data)} QAs → {output_path}")
+
+
+def generate_validation_qa(
+    docs_path="data/docs",
+    output_path="experiments/validation_qa.json",
+    llm_model="gemini-2.5-flash-lite",
+    batch_size=5,
+    sleep_between_batches=2,
+    min_q=3,
+    max_q=25,
+):
+    """Main pipeline to generate QAs."""
+    embedder = SentenceTransformer("all-MiniLM-L6-v2")
+    llm, backend = setup_llm(llm_model)
+    all_qa = []
+
+    corpus = read_corpus(docs_path)
+    print(f"[INFO] Loaded {len(corpus)} documents from {docs_path}")
+
+    for i in range(0, len(corpus), batch_size):
+        batch = corpus[i: i + batch_size]
+        batch_qa = generate_qa_for_batch(batch, llm, backend, embedder, min_q, max_q)
+        all_qa.extend(batch_qa)
+        print(f"[INFO] Batch {i // batch_size + 1}: {len(batch_qa)} QAs (Total: {len(all_qa)})")
+        time.sleep(sleep_between_batches)
+
+    save_json(output_path, all_qa)
+
+
+# ---------- CLI entry point ----------
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate validation QA dataset for Ragmint.")
+    parser.add_argument("--docs_path", type=str, default="data/docs")
+    parser.add_argument("--output", type=str, default="experiments/validation_qa.json")
+    parser.add_argument("--batch_size", type=int, default=5)
+    parser.add_argument("--sleep", type=int, default=2)
+    parser.add_argument("--min_q", type=int, default=3)
+    parser.add_argument("--max_q", type=int, default=25)
+    args = parser.parse_args()
+
+    generate_validation_qa(
+        docs_path=args.docs_path,
+        output_path=args.output,
+        batch_size=args.batch_size,
+        sleep_between_batches=args.sleep,
+        min_q=args.min_q,
+        max_q=args.max_q,
+    )
+
+
+if __name__ == "__main__":
+    main()
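Note: a minimal usage sketch of the module above, not part of the diff. It assumes ragmint 0.4.6 is installed, one of GOOGLE_API_KEY or ANTHROPIC_API_KEY is set, and the dotenv, scikit-learn, and sentence-transformers dependencies are available; the keyword arguments mirror the generate_validation_qa signature shown in the diff.

    from ragmint.qa_generator import generate_validation_qa

    # Batch the corpus into LLM calls of 5 documents each and
    # write the resulting QA dataset to disk as a JSON array.
    generate_validation_qa(
        docs_path="data/docs",
        output_path="experiments/validation_qa.json",
        batch_size=5,
        min_q=3,
        max_q=25,
    )

The same flow is exposed on the command line via the argparse entry point, e.g. python -m ragmint.qa_generator --docs_path data/docs --output experiments/validation_qa.json.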
ragmint/tests/test_autotuner.py
CHANGED
@@ -1,42 +1,64 @@
+import os
+import json
 import pytest
 from ragmint.autotuner import AutoRAGTuner
 
 
-def
-"""
-
-
-
-
+def setup_docs(tmp_path):
+    """Create a temporary corpus with multiple text files for testing."""
+    corpus = tmp_path / "corpus"
+    corpus.mkdir()
+    (corpus / "short_doc.txt").write_text("AI is changing the world.")
+    (corpus / "long_doc.txt").write_text("Machine learning enables RAG pipelines to optimize retrievals. " * 50)
+    return str(corpus)
 
 
-def
-"""
-
-
-
-    assert rec["embedding_model"] == "SentenceTransformers"
+def test_analyze_corpus(tmp_path):
+    """Ensure AutoRAGTuner analyzes corpus correctly."""
+    docs_path = setup_docs(tmp_path)
+    tuner = AutoRAGTuner(docs_path)
+    stats = tuner.corpus_stats
 
+    assert stats["num_docs"] == 2, "Should detect all documents"
+    assert stats["size"] > 0, "Corpus size should be positive"
+    assert stats["avg_len"] > 0, "Average document length should be computed"
 
-def test_autorag_recommend_large():
-    """Large corpus should trigger FAISS + InstructorXL."""
-    tuner = AutoRAGTuner({"size": 50000, "avg_len": 300})
-    rec = tuner.recommend()
-    assert rec["retriever"] == "FAISS"
-    assert rec["embedding_model"] == "InstructorXL"
 
+@pytest.mark.parametrize("size,expected_retriever", [
+    (10_000, "Chroma"),
+    (500_000, "FAISS"),
+    (1_000, "BM25"),
+])
+def test_recommendation_logic(tmp_path, monkeypatch, size, expected_retriever):
+    """Validate retriever recommendation based on corpus size and chunk suggestion."""
+    docs_path = setup_docs(tmp_path)
+    tuner = AutoRAGTuner(docs_path)
 
-
-""
-    tuner = AutoRAGTuner({"size": 12000, "avg_len": 250})
+    # Mock corpus stats manually
+    tuner.corpus_stats = {"size": size, "avg_len": 300, "num_docs": 10}
 
-#
-
-    def mock_eval(config, data):
-        return {"faithfulness": 0.9, "latency": 0.01}
-    monkeypatch.setattr(autotuner, "evaluate_config", mock_eval)
+    # Provide mandatory num_chunk_pairs
+    rec = tuner.recommend(num_chunk_pairs=3)
 
-
-    assert "
-    assert "
-    assert
+    assert "retriever" in rec and "embedding_model" in rec
+    assert rec["retriever"] == expected_retriever, f"Expected {expected_retriever}"
+    assert rec["chunk_size"] > 0 and rec["overlap"] >= 0
+    assert "chunk_candidates" in rec, "Should include suggested chunk pairs"
+    assert len(rec["chunk_candidates"]) == 3, "Should generate correct number of chunk pairs"
+
+
+def test_invalid_corpus_path(tmp_path):
+    """Should handle missing directories gracefully."""
+    missing_path = tmp_path / "nonexistent"
+    tuner = AutoRAGTuner(str(missing_path))
+    assert tuner.corpus_stats["size"] == 0
+    assert tuner.corpus_stats["num_docs"] == 0
+
+
+def test_suggest_chunk_sizes_requires_num_pairs(tmp_path):
+    """Ensure suggest_chunk_sizes raises error if num_pairs is not provided."""
+    docs_path = setup_docs(tmp_path)
+    tuner = AutoRAGTuner(docs_path)
+
+    with pytest.raises(ValueError):
+        tuner.suggest_chunk_sizes()
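Note: the new tests document the 0.4.x AutoRAGTuner interface, which now takes a corpus path instead of a stats dict and requires num_chunk_pairs when recommending. A sketch of that flow, based only on the calls exercised above; the path is illustrative.

    from ragmint.autotuner import AutoRAGTuner

    tuner = AutoRAGTuner("data/docs")         # corpus_stats is computed on construction
    print(tuner.corpus_stats)                 # {"size": ..., "avg_len": ..., "num_docs": ...}

    rec = tuner.recommend(num_chunk_pairs=3)  # omitting num_chunk_pairs raises ValueError
    print(rec["retriever"], rec["chunk_size"], rec["chunk_candidates"])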
ragmint/tests/test_config_adapter.py
ADDED
@@ -0,0 +1,39 @@
+import pytest
+from ragmint.integrations.config_adapter import LangchainConfigAdapter
+
+def test_default_conversion():
+    """Test that default config values are applied correctly."""
+    cfg = {
+        "retriever": "FAISS",
+        "embedding_model": "all-MiniLM-L6-v2",
+        "chunk_size": 500,
+        "overlap": 100
+    }
+
+    adapter = LangchainConfigAdapter(cfg)
+    result = adapter.to_standard_config()
+
+    assert result["retriever"].lower() == "faiss"
+    assert result["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
+    assert result["chunk_size"] == 500
+    assert result["overlap"] == 100
+
+
+def test_missing_fields_are_defaulted():
+    """Ensure missing optional fields (e.g. chunk params) are filled in."""
+    cfg = {"retriever": "BM25", "embedding_model": "all-MiniLM-L6-v2"}
+    adapter = LangchainConfigAdapter(cfg)
+    result = adapter.to_standard_config()
+
+    assert "chunk_size" in result
+    assert "overlap" in result
+    assert result["chunk_size"] > 0
+    assert result["overlap"] >= 0
+
+
+def test_validation_of_invalid_retriever():
+    """Ensure invalid retriever names raise an informative error."""
+    cfg = {"retriever": "InvalidBackend", "embedding_model": "all-MiniLM-L6-v2"}
+
+    with pytest.raises(ValueError, match="Unsupported retriever backend"):
+        LangchainConfigAdapter(cfg).to_standard_config()
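Note: per the assertions above, the adapter normalizes a Ragmint config for LangChain use: short embedding names gain the sentence-transformers/ prefix, missing chunk parameters are defaulted, and unknown retrievers are rejected. A sketch grounded in those tests:

    from ragmint.integrations.config_adapter import LangchainConfigAdapter

    cfg = {"retriever": "BM25", "embedding_model": "all-MiniLM-L6-v2"}
    result = LangchainConfigAdapter(cfg).to_standard_config()
    # result now carries defaulted "chunk_size"/"overlap" values and the
    # fully qualified "sentence-transformers/all-MiniLM-L6-v2" model name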
ragmint/tests/test_embeddings.py
ADDED
@@ -0,0 +1,46 @@
+import numpy as np
+import pytest
+from ragmint.core.embeddings import Embeddings
+
+
+def test_dummy_backend_output_shape():
+    model = Embeddings(backend="dummy")
+    texts = ["hello", "world"]
+    embeddings = model.encode(texts)
+
+    # Expect 2x768 array
+    assert isinstance(embeddings, np.ndarray)
+    assert embeddings.shape == (2, 768)
+    assert embeddings.dtype == np.float32
+
+
+def test_dummy_backend_single_string():
+    model = Embeddings(backend="dummy")
+    text = "test"
+    embeddings = model.encode(text)
+
+    assert embeddings.shape == (1, 768)
+    assert isinstance(embeddings, np.ndarray)
+
+
+'''@pytest.mark.skipif(
+    not hasattr(__import__('importlib').util.find_spec("sentence_transformers"), "loader"),
+    reason="sentence-transformers not installed"
+)
+def test_huggingface_backend_output_shape():
+    model = Embeddings(backend="huggingface", model_name="all-MiniLM-L6-v2")
+    texts = ["This is a test.", "Another sentence."]
+    embeddings = model.encode(texts)
+
+    # Expect 2x384 for MiniLM-L6-v2
+    assert isinstance(embeddings, np.ndarray)
+    assert embeddings.ndim == 2
+    assert embeddings.shape[0] == len(texts)
+    assert embeddings.dtype == np.float32
+'''
+
+def test_invalid_backend():
+    try:
+        Embeddings(backend="unknown")
+    except ValueError as e:
+        assert "Unsupported embedding backend" in str(e)
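Note: as these tests indicate, Embeddings always returns a 2-D float32 array, even for a single input string, and the "dummy" backend emits 768-dimensional vectors without downloading any model. A sketch of that behavior:

    from ragmint.core.embeddings import Embeddings

    model = Embeddings(backend="dummy")
    vecs = model.encode(["hello", "world"])  # shape (2, 768), dtype float32
    single = model.encode("hello")           # still 2-D: shape (1, 768)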
ragmint/tests/test_explainer.py
CHANGED
@@ -1,20 +1,36 @@
 import pytest
+import sys
+import types
 from ragmint.explainer import explain_results
 
 
-def
-"""
-
-
-
-
-
+def test_explain_results_with_claude(monkeypatch):
+    """Claude explanation should use Anthropic API path when ANTHROPIC_API_KEY is set."""
+    monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key")
+    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
+
+    # Create a fake anthropic module with the required interface
+    mock_anthropic = types.ModuleType("anthropic")
+
+    class MockContent:
+        text = "Claude: The best configuration performs well due to optimized chunk size."
+
+    class MockMessages:
+        def create(self, *args, **kwargs):
+            return type("MockResponse", (), {"content": [MockContent()]})()
+
+    class MockClient:
+        def __init__(self, api_key):
+            self.messages = MockMessages()
+
+    mock_anthropic.Anthropic = MockClient
+    sys.modules["anthropic"] = mock_anthropic  # Inject fake module
+
+    best = {"retriever": "Chroma", "metric": 0.9}
+    all_results = [{"retriever": "FAISS", "metric": 0.85}]
+    corpus_stats = {"size": 10000, "avg_len": 400, "num_docs": 20}
 
+    result = explain_results(best, all_results, corpus_stats, model="claude-3-opus-20240229")
 
-def test_explain_results_claude():
-    """Claude explanation should contain model-specific phrasing."""
-    config_a = {"retriever": "FAISS"}
-    config_b = {"retriever": "Chroma"}
-    result = explain_results(config_a, config_b, model="claude")
     assert isinstance(result, str)
     assert "Claude" in result or "claude" in result
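Note: the updated test pins down the new explain_results signature, which takes the best config, all results, and corpus stats, plus a concrete model id rather than a bare "claude" alias. A sketch with illustrative values, grounded in the test above:

    from ragmint.explainer import explain_results

    best = {"retriever": "Chroma", "metric": 0.9}
    all_results = [{"retriever": "FAISS", "metric": 0.85}]
    corpus_stats = {"size": 10000, "avg_len": 400, "num_docs": 20}

    # Assumes ANTHROPIC_API_KEY is set; the test's removal of GOOGLE_API_KEY
    # suggests a Gemini code path is selected when that key is present instead.
    text = explain_results(best, all_results, corpus_stats, model="claude-3-opus-20240229")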
ragmint/tests/test_integration_autotuner_ragmint.py
CHANGED
@@ -1,60 +1,47 @@
+import os
+import json
 import pytest
-from ragmint.tuner import RAGMint
 from ragmint.autotuner import AutoRAGTuner
+from ragmint.tuner import RAGMint
 
 
-def
-"""
-    Smoke test for integration between AutoRAGTuner and RAGMint.
-    Ensures end-to-end flow runs without real retrievers or embeddings.
-    """
-
-    # --- Mock corpus and validation data ---
+def setup_docs(tmp_path):
+    """Create a temporary corpus for integration testing."""
     corpus = tmp_path / "docs"
     corpus.mkdir()
-    (corpus / "doc1.txt").write_text("This
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    (corpus / "doc1.txt").write_text("This document discusses Artificial Intelligence and Machine Learning.")
+    (corpus / "doc2.txt").write_text("Retrieval-Augmented Generation combines retrievers and LLMs effectively.")
+    return str(corpus)
+
+
+def setup_validation_file(tmp_path):
+    """Create a temporary validation QA dataset."""
+    data = [
+        {"question": "What is AI?", "answer": "Artificial Intelligence"},
+        {"question": "Define RAG", "answer": "Retrieval-Augmented Generation"},
+    ]
+    file = tmp_path / "validation_qa.json"
+    with open(file, "w", encoding="utf-8") as f:
+        json.dump(data, f)
+    return str(file)
+
+
+def test_autotune_integration(tmp_path):
+    """Test that AutoRAGTuner can fully run a RAGMint optimization."""
+    docs_path = setup_docs(tmp_path)
+    val_file = setup_validation_file(tmp_path)
+
+    tuner = AutoRAGTuner(docs_path)
+    best, results = tuner.auto_tune(
+        validation_set=val_file,
+        metric="faithfulness",
+        trials=2,
+        search_type="random",
     )
 
-
-
-
-
-    assert "retriever" in
-    assert "
-
-    tuning_results = tuner.auto_tune(validation_data)
-    assert "results" in tuning_results
-    assert isinstance(tuning_results["results"], dict)
-
-    # --- Run RAGMint optimization flow (mocked) ---
-    best_config, results = ragmint.optimize(validation_set=validation_data, trials=2)
-    assert isinstance(best_config, dict)
-    assert "score" in best_config
-    assert isinstance(results, list)
-
-    # --- Integration Success ---
-    print(f"Integration OK: AutoRAG recommended {recommendation}, RAGMint best {best_config}")
+    # Assertions on the results
+    assert isinstance(best, dict), "Best configuration should be a dict"
+    assert isinstance(results, list), "Results should be a list"
+    assert len(results) > 0, "Optimization should produce results"
+    assert "retriever" in best and "embedding_model" in best
+    assert best.get("faithfulness", 0.0) >= 0.0, "Metric value should be non-negative"
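Note: taken together, this test fixes the end-to-end 0.4.x entry point: auto_tune takes a path to a JSON validation set of {"question", "answer"} records and returns a (best, results) pair. A sketch against a user corpus, with illustrative paths:

    from ragmint.autotuner import AutoRAGTuner

    tuner = AutoRAGTuner("data/docs")
    best, results = tuner.auto_tune(
        validation_set="experiments/validation_qa.json",  # illustrative path
        metric="faithfulness",
        trials=2,
        search_type="random",
    )
    print(best["retriever"], best.get("faithfulness"))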
ragmint/tests/test_langchain_prebuilder.py
ADDED
@@ -0,0 +1,82 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from ragmint.integrations.langchain_prebuilder import LangchainPrebuilder
+
+
+@pytest.fixture
+def sample_docs():
+    """Small sample corpus for testing."""
+    return ["AI is transforming the world.", "RAG pipelines improve retrieval."]
+
+
+@pytest.fixture
+def sample_config():
+    """Default configuration for tests."""
+    return {
+        "retriever": "faiss",
+        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+        "chunk_size": 200,
+        "overlap": 50,
+    }
+
+
+@patch("ragmint.integrations.langchain_prebuilder.HuggingFaceEmbeddings", autospec=True)
+@patch("ragmint.integrations.langchain_prebuilder.RecursiveCharacterTextSplitter", autospec=True)
+def test_prepare_creates_components(mock_splitter, mock_embedder, sample_config, sample_docs):
+    """Ensure prepare() builds retriever and embedding components properly."""
+    mock_splitter.return_value.create_documents.return_value = ["doc1", "doc2"]
+    mock_embedder.return_value = MagicMock()
+
+    # Patch FAISS to avoid building a real index
+    with patch("ragmint.integrations.langchain_prebuilder.FAISS", autospec=True) as mock_faiss:
+        mock_db = MagicMock()
+        mock_faiss.from_documents.return_value = mock_db
+        mock_db.as_retriever.return_value = "mock_retriever"
+
+        builder = LangchainPrebuilder(sample_config)
+        retriever, embeddings = builder.prepare(sample_docs)
+
+        assert retriever == "mock_retriever"
+        assert embeddings == mock_embedder.return_value
+
+        mock_splitter.assert_called_once()
+        mock_embedder.assert_called_once_with(model_name="sentence-transformers/all-MiniLM-L6-v2")
+        mock_faiss.from_documents.assert_called_once()
+
+
+@pytest.mark.parametrize("backend", ["faiss", "chroma", "bm25"])
+def test_build_retriever_backends(sample_config, sample_docs, backend):
+    """Check retriever creation for each backend."""
+    cfg = dict(sample_config)
+    cfg["retriever"] = backend
+
+    builder = LangchainPrebuilder(cfg)
+
+    # Mock embeddings + docs
+    fake_embeddings = MagicMock()
+    fake_docs = ["d1", "d2"]
+
+    with patch("ragmint.integrations.langchain_prebuilder.FAISS.from_documents", return_value=MagicMock()) as mock_faiss, \
+         patch("ragmint.integrations.langchain_prebuilder.Chroma.from_documents", return_value=MagicMock()) as mock_chroma, \
+         patch("ragmint.integrations.langchain_prebuilder.BM25Retriever.from_texts", return_value=MagicMock()) as mock_bm25:
+        retriever = builder._build_retriever(fake_docs, fake_embeddings)
+
+        # Validate retriever creation per backend
+        if backend == "faiss":
+            mock_faiss.assert_called_once()
+        elif backend == "chroma":
+            mock_chroma.assert_called_once()
+        elif backend == "bm25":
+            mock_bm25.assert_called_once()
+
+        assert retriever is not None
+
+
+def test_invalid_backend_raises(sample_config):
+    """Ensure ValueError is raised for unsupported retriever."""
+    cfg = dict(sample_config)
+    cfg["retriever"] = "invalid"
+
+    builder = LangchainPrebuilder(cfg)
+    with pytest.raises(ValueError, match="Unsupported retriever backend"):
+        builder._build_retriever(["doc"], MagicMock())
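Note: from the mocked calls above, LangchainPrebuilder.prepare() splits raw texts, embeds them, builds the configured vector store (faiss, chroma, or bm25), and returns a (retriever, embeddings) pair. A sketch grounded in those tests; run unmocked it would download the embedding model and build a real FAISS index.

    from ragmint.integrations.langchain_prebuilder import LangchainPrebuilder

    config = {
        "retriever": "faiss",
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "chunk_size": 200,
        "overlap": 50,
    }
    docs = ["AI is transforming the world.", "RAG pipelines improve retrieval."]

    retriever, embeddings = LangchainPrebuilder(config).prepare(docs)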