ragmint 0.2.1__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ragmint/app.py +512 -0
  2. ragmint/autotuner.py +201 -17
  3. ragmint/core/chunking.py +68 -4
  4. ragmint/core/embeddings.py +46 -10
  5. ragmint/core/evaluation.py +33 -14
  6. ragmint/core/pipeline.py +34 -10
  7. ragmint/core/retriever.py +152 -20
  8. ragmint/experiments/validation_qa.json +1 -14
  9. ragmint/explainer.py +47 -20
  10. ragmint/integrations/__init__.py +0 -0
  11. ragmint/integrations/config_adapter.py +96 -0
  12. ragmint/integrations/langchain_prebuilder.py +99 -0
  13. ragmint/leaderboard.py +41 -35
  14. ragmint/qa_generator.py +190 -0
  15. ragmint/tests/test_autotuner.py +52 -30
  16. ragmint/tests/test_config_adapter.py +39 -0
  17. ragmint/tests/test_embeddings.py +46 -0
  18. ragmint/tests/test_explainer.py +28 -12
  19. ragmint/tests/test_integration_autotuner_ragmint.py +39 -52
  20. ragmint/tests/test_langchain_prebuilder.py +82 -0
  21. ragmint/tests/test_leaderboard.py +78 -25
  22. ragmint/tests/test_pipeline.py +3 -2
  23. ragmint/tests/test_qa_generator.py +66 -0
  24. ragmint/tests/test_retriever.py +3 -2
  25. ragmint/tests/test_tuner.py +1 -1
  26. ragmint/tuner.py +109 -22
  27. ragmint-0.4.6.data/data/README.md +485 -0
  28. ragmint-0.4.6.dist-info/METADATA +530 -0
  29. ragmint-0.4.6.dist-info/RECORD +48 -0
  30. ragmint-0.4.6.dist-info/licenses/LICENSE +19 -0
  31. ragmint/tests/test_explainer_integration.py +0 -18
  32. ragmint-0.2.1.dist-info/METADATA +0 -27
  33. ragmint-0.2.1.dist-info/RECORD +0 -38
  34. {ragmint-0.2.1.dist-info/licenses → ragmint-0.4.6.data/data}/LICENSE +0 -0
  35. {ragmint-0.2.1.dist-info → ragmint-0.4.6.dist-info}/WHEEL +0 -0
  36. {ragmint-0.2.1.dist-info → ragmint-0.4.6.dist-info}/top_level.txt +0 -0
ragmint/qa_generator.py
@@ -0,0 +1,190 @@
+ """
+ Batched Validation QA Generator for Ragmint (Functional Version)
+
+ Generates a JSON QA dataset from a large corpus using an LLM.
+ Processes documents in batches to avoid token limits and API errors.
+ Uses topic-aware dynamic question count estimation.
+ """
+
+ import os
+ import re
+ import json
+ import math
+ import time
+ import argparse
+ from pathlib import Path
+ from dotenv import load_dotenv
+ import numpy as np
+ from sklearn.cluster import KMeans
+ from sentence_transformers import SentenceTransformer
+
+ # --- Load .env if available ---
+ load_dotenv()
+
+
+ # ---------- Utility functions ----------
+
+ def extract_json_from_markdown(text: str):
+     """Extract JSON from a markdown-style code block."""
+     match = re.search(r"```(?:json)?\s*(\[\s*[\s\S]*?\s*\])\s*```", text, re.MULTILINE)
+     if match:
+         json_str = match.group(1)
+         return json.loads(json_str)
+     else:
+         return json.loads(text.strip())
+
+
+ def read_corpus(docs_path: str):
+     """Load all text documents from a folder."""
+     docs = []
+     for file in Path(docs_path).glob("**/*.txt"):
+         with open(file, "r", encoding="utf-8") as f:
+             text = f.read().strip()
+             if text:
+                 docs.append({"filename": file.name, "text": text})
+     return docs
+
+
+ def determine_question_count(text: str, embedder, min_q=3, max_q=25):
+     """Estimate number of questions dynamically based on text length and topic diversity."""
+     sentences = [s.strip() for s in text.split('.') if len(s.strip().split()) > 3]
+     word_count = len(text.split())
+
+     if word_count == 0:
+         return min_q
+
+     base_q = math.log1p(word_count / 150)
+
+     # Topic diversity via clustering
+     n_sent = len(sentences)
+     if n_sent < 5:
+         topic_factor = 1.0
+     else:
+         try:
+             emb = embedder.encode(sentences, normalize_embeddings=True)
+             n_clusters = min(max(2, n_sent // 10), 8)
+             km = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
+             labels = km.fit_predict(emb)
+             topic_factor = len(set(labels)) / n_clusters
+         except Exception as e:
+             print(f"[WARN] Clustering failed ({type(e).__name__}): {e}")
+             topic_factor = 1.0
+
+     score = base_q * (1 + 0.8 * topic_factor)
+     question_count = round(min_q + score)
+     return int(max(min_q, min(question_count, max_q)))
+
+
+ def setup_llm(llm_model="gemini-2.5-flash-lite"):
+     """Configure Gemini or Claude based on available environment keys."""
+     google_key = os.getenv("GOOGLE_API_KEY")
+     anthropic_key = os.getenv("ANTHROPIC_API_KEY")
+
+     if google_key:
+         import google.generativeai as genai
+         genai.configure(api_key=google_key)
+         llm = genai.GenerativeModel(llm_model)
+         return llm, "gemini"
+
+     elif anthropic_key:
+         from anthropic import Anthropic
+         llm = Anthropic(api_key=anthropic_key)
+         return llm, "claude"
+
+     else:
+         raise ValueError("Set ANTHROPIC_API_KEY or GOOGLE_API_KEY in your environment.")
+
+
+ def generate_qa_for_batch(batch, llm, backend, embedder, min_q=3, max_q=25):
+     """Send one LLM call for a batch of documents."""
+     prompt_texts = []
+     for doc in batch:
+         n_questions = determine_question_count(doc["text"], embedder, min_q, max_q)
+         prompt_texts.append(
+             f"Document: {doc['text'][:1000]}\n"
+             f"Generate {n_questions} factual question-answer pairs in JSON format."
+         )
+
+     prompt = "\n\n".join(prompt_texts)
+     prompt += "\n\nReturn a single JSON array of objects like:\n" \
+               '[{"query": "string", "expected_answer": "string"}]'
+
+     try:
+         if backend == "gemini":
+             response = llm.generate_content(prompt)
+             text_out = getattr(response, "text", None)
+             if not text_out and hasattr(response, "candidates"):
+                 text_out = response.candidates[0].content.parts[0].text
+             return extract_json_from_markdown(text_out)
+
+         elif backend == "claude":
+             response = llm.messages.create(
+                 model="claude-3-opus-20240229",
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=2000,
+             )
+             return json.loads(response.content[0].text)
+
+     except Exception as e:
+         print(f"[WARN] Failed to parse batch: {e}")
+         return []
+
+
+ def save_json(output_path, data):
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+     with open(output_path, "w", encoding="utf-8") as f:
+         json.dump(data, f, indent=2, ensure_ascii=False)
+     print(f"[INFO] Saved {len(data)} QAs → {output_path}")
+
+
+ def generate_validation_qa(
+     docs_path="data/docs",
+     output_path="experiments/validation_qa.json",
+     llm_model="gemini-2.5-flash-lite",
+     batch_size=5,
+     sleep_between_batches=2,
+     min_q=3,
+     max_q=25,
+ ):
+     """Main pipeline to generate QAs."""
+     embedder = SentenceTransformer("all-MiniLM-L6-v2")
+     llm, backend = setup_llm(llm_model)
+     all_qa = []
+
+     corpus = read_corpus(docs_path)
+     print(f"[INFO] Loaded {len(corpus)} documents from {docs_path}")
+
+     for i in range(0, len(corpus), batch_size):
+         batch = corpus[i: i + batch_size]
+         batch_qa = generate_qa_for_batch(batch, llm, backend, embedder, min_q, max_q)
+         all_qa.extend(batch_qa)
+         print(f"[INFO] Batch {i // batch_size + 1}: {len(batch_qa)} QAs (Total: {len(all_qa)})")
+         time.sleep(sleep_between_batches)
+
+     save_json(output_path, all_qa)
+
+
+ # ---------- CLI entry point ----------
+
+ def main():
+     parser = argparse.ArgumentParser(description="Generate validation QA dataset for Ragmint.")
+     parser.add_argument("--docs_path", type=str, default="data/docs")
+     parser.add_argument("--output", type=str, default="experiments/validation_qa.json")
+     parser.add_argument("--batch_size", type=int, default=5)
+     parser.add_argument("--sleep", type=int, default=2)
+     parser.add_argument("--min_q", type=int, default=3)
+     parser.add_argument("--max_q", type=int, default=25)
+     args = parser.parse_args()
+
+     generate_validation_qa(
+         docs_path=args.docs_path,
+         output_path=args.output,
+         batch_size=args.batch_size,
+         sleep_between_batches=args.sleep,
+         min_q=args.min_q,
+         max_q=args.max_q,
+     )
+
+
+ if __name__ == "__main__":
+     main()
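The new qa_generator module above is self-contained: generate_validation_qa() reads .txt files, batches them through Gemini or Claude, and writes a JSON QA set. A minimal usage sketch based only on the signatures shown in this diff (paths are the module defaults; a GOOGLE_API_KEY or ANTHROPIC_API_KEY is assumed to be set):

from ragmint.qa_generator import generate_validation_qa

# Build a validation set from local .txt documents in small batches
generate_validation_qa(
    docs_path="data/docs",                         # folder of .txt files
    output_path="experiments/validation_qa.json",  # destination JSON
    batch_size=5,                                  # documents per LLM call
    sleep_between_batches=2,                       # seconds between calls
)

The same flow is exposed on the command line through main(), e.g. python -m ragmint.qa_generator --docs_path data/docs --output experiments/validation_qa.json.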
ragmint/tests/test_autotuner.py
@@ -1,42 +1,64 @@
+ import os
+ import json
  import pytest
  from ragmint.autotuner import AutoRAGTuner


- def test_autorag_recommend_small():
-     """Small corpus should trigger BM25 + OpenAI."""
-     tuner = AutoRAGTuner({"size": 500, "avg_len": 150})
-     rec = tuner.recommend()
-     assert rec["retriever"] == "BM25"
-     assert rec["embedding_model"] == "OpenAI"
+ def setup_docs(tmp_path):
+     """Create a temporary corpus with multiple text files for testing."""
+     corpus = tmp_path / "corpus"
+     corpus.mkdir()
+     (corpus / "short_doc.txt").write_text("AI is changing the world.")
+     (corpus / "long_doc.txt").write_text("Machine learning enables RAG pipelines to optimize retrievals. " * 50)
+     return str(corpus)


- def test_autorag_recommend_medium():
-     """Medium corpus should trigger Chroma + SentenceTransformers."""
-     tuner = AutoRAGTuner({"size": 5000, "avg_len": 200})
-     rec = tuner.recommend()
-     assert rec["retriever"] == "Chroma"
-     assert rec["embedding_model"] == "SentenceTransformers"
+ def test_analyze_corpus(tmp_path):
+     """Ensure AutoRAGTuner analyzes corpus correctly."""
+     docs_path = setup_docs(tmp_path)
+     tuner = AutoRAGTuner(docs_path)
+     stats = tuner.corpus_stats

+     assert stats["num_docs"] == 2, "Should detect all documents"
+     assert stats["size"] > 0, "Corpus size should be positive"
+     assert stats["avg_len"] > 0, "Average document length should be computed"

- def test_autorag_recommend_large():
-     """Large corpus should trigger FAISS + InstructorXL."""
-     tuner = AutoRAGTuner({"size": 50000, "avg_len": 300})
-     rec = tuner.recommend()
-     assert rec["retriever"] == "FAISS"
-     assert rec["embedding_model"] == "InstructorXL"

+ @pytest.mark.parametrize("size,expected_retriever", [
+     (10_000, "Chroma"),
+     (500_000, "FAISS"),
+     (1_000, "BM25"),
+ ])
+ def test_recommendation_logic(tmp_path, monkeypatch, size, expected_retriever):
+     """Validate retriever recommendation based on corpus size and chunk suggestion."""
+     docs_path = setup_docs(tmp_path)
+     tuner = AutoRAGTuner(docs_path)

- def test_autorag_auto_tune(monkeypatch):
-     """Test auto_tune with a mock validation dataset."""
-     tuner = AutoRAGTuner({"size": 12000, "avg_len": 250})
+     # Mock corpus stats manually
+     tuner.corpus_stats = {"size": size, "avg_len": 300, "num_docs": 10}

-     # Monkeypatch evaluate_config inside autotuner
-     import ragmint.autotuner as autotuner
-     def mock_eval(config, data):
-         return {"faithfulness": 0.9, "latency": 0.01}
-     monkeypatch.setattr(autotuner, "evaluate_config", mock_eval)
+     # Provide mandatory num_chunk_pairs
+     rec = tuner.recommend(num_chunk_pairs=3)

-     result = tuner.auto_tune([{"question": "What is AI?", "answer": "Artificial Intelligence"}])
-     assert "recommended" in result
-     assert "results" in result
-     assert isinstance(result["results"], dict)
+     assert "retriever" in rec and "embedding_model" in rec
+     assert rec["retriever"] == expected_retriever, f"Expected {expected_retriever}"
+     assert rec["chunk_size"] > 0 and rec["overlap"] >= 0
+     assert "chunk_candidates" in rec, "Should include suggested chunk pairs"
+     assert len(rec["chunk_candidates"]) == 3, "Should generate correct number of chunk pairs"
+
+
+ def test_invalid_corpus_path(tmp_path):
+     """Should handle missing directories gracefully."""
+     missing_path = tmp_path / "nonexistent"
+     tuner = AutoRAGTuner(str(missing_path))
+     assert tuner.corpus_stats["size"] == 0
+     assert tuner.corpus_stats["num_docs"] == 0
+
+
+ def test_suggest_chunk_sizes_requires_num_pairs(tmp_path):
+     """Ensure suggest_chunk_sizes raises error if num_pairs is not provided."""
+     docs_path = setup_docs(tmp_path)
+     tuner = AutoRAGTuner(docs_path)
+
+     with pytest.raises(ValueError):
+         tuner.suggest_chunk_sizes()
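These tests reflect the reworked AutoRAGTuner API: the constructor now takes a corpus path and computes corpus_stats itself, and recommend() requires num_chunk_pairs. A sketch of the call pattern implied by the assertions above (key names are taken from the tests, not verified against the implementation):

from ragmint.autotuner import AutoRAGTuner

tuner = AutoRAGTuner("data/docs")           # corpus is analyzed on construction
print(tuner.corpus_stats)                   # {"num_docs": ..., "size": ..., "avg_len": ...}

rec = tuner.recommend(num_chunk_pairs=3)    # num_chunk_pairs is mandatory
print(rec["retriever"], rec["chunk_size"], rec["chunk_candidates"])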
ragmint/tests/test_config_adapter.py
@@ -0,0 +1,39 @@
+ import pytest
+ from ragmint.integrations.config_adapter import LangchainConfigAdapter
+
+ def test_default_conversion():
+     """Test that default config values are applied correctly."""
+     cfg = {
+         "retriever": "FAISS",
+         "embedding_model": "all-MiniLM-L6-v2",
+         "chunk_size": 500,
+         "overlap": 100
+     }
+
+     adapter = LangchainConfigAdapter(cfg)
+     result = adapter.to_standard_config()
+
+     assert result["retriever"].lower() == "faiss"
+     assert result["embedding_model"] == "sentence-transformers/all-MiniLM-L6-v2"
+     assert result["chunk_size"] == 500
+     assert result["overlap"] == 100
+
+
+ def test_missing_fields_are_defaulted():
+     """Ensure missing optional fields (e.g. chunk params) are filled in."""
+     cfg = {"retriever": "BM25", "embedding_model": "all-MiniLM-L6-v2"}
+     adapter = LangchainConfigAdapter(cfg)
+     result = adapter.to_standard_config()
+
+     assert "chunk_size" in result
+     assert "overlap" in result
+     assert result["chunk_size"] > 0
+     assert result["overlap"] >= 0
+
+
+ def test_validation_of_invalid_retriever():
+     """Ensure invalid retriever names raise an informative error."""
+     cfg = {"retriever": "InvalidBackend", "embedding_model": "all-MiniLM-L6-v2"}
+
+     with pytest.raises(ValueError, match="Unsupported retriever backend"):
+         LangchainConfigAdapter(cfg).to_standard_config()
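The adapter tests above imply a thin mapping from a Ragmint result config to a LangChain-ready one: retriever names are normalized, bare SentenceTransformer names gain the sentence-transformers/ prefix, and missing chunk parameters get defaults. A hedged call sketch mirroring the tests:

from ragmint.integrations.config_adapter import LangchainConfigAdapter

cfg = {"retriever": "FAISS", "embedding_model": "all-MiniLM-L6-v2"}
standard = LangchainConfigAdapter(cfg).to_standard_config()
# standard now carries normalized retriever/embedding names plus default chunk_size and overlap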
ragmint/tests/test_embeddings.py
@@ -0,0 +1,46 @@
+ import numpy as np
+ import pytest
+ from ragmint.core.embeddings import Embeddings
+
+
+ def test_dummy_backend_output_shape():
+     model = Embeddings(backend="dummy")
+     texts = ["hello", "world"]
+     embeddings = model.encode(texts)
+
+     # Expect 2x768 array
+     assert isinstance(embeddings, np.ndarray)
+     assert embeddings.shape == (2, 768)
+     assert embeddings.dtype == np.float32
+
+
+ def test_dummy_backend_single_string():
+     model = Embeddings(backend="dummy")
+     text = "test"
+     embeddings = model.encode(text)
+
+     assert embeddings.shape == (1, 768)
+     assert isinstance(embeddings, np.ndarray)
+
+
+ '''@pytest.mark.skipif(
+     not hasattr(__import__('importlib').util.find_spec("sentence_transformers"), "loader"),
+     reason="sentence-transformers not installed"
+ )
+ def test_huggingface_backend_output_shape():
+     model = Embeddings(backend="huggingface", model_name="all-MiniLM-L6-v2")
+     texts = ["This is a test.", "Another sentence."]
+     embeddings = model.encode(texts)
+
+     # Expect 2x384 for MiniLM-L6-v2
+     assert isinstance(embeddings, np.ndarray)
+     assert embeddings.ndim == 2
+     assert embeddings.shape[0] == len(texts)
+     assert embeddings.dtype == np.float32
+ '''
+
+ def test_invalid_backend():
+     try:
+         Embeddings(backend="unknown")
+     except ValueError as e:
+         assert "Unsupported embedding backend" in str(e)
ragmint/tests/test_explainer.py
@@ -1,20 +1,36 @@
  import pytest
+ import sys
+ import types
  from ragmint.explainer import explain_results


- def test_explain_results_gemini():
-     """Gemini explanation should contain model-specific phrasing."""
-     config_a = {"retriever": "FAISS", "embedding_model": "OpenAI"}
-     config_b = {"retriever": "Chroma", "embedding_model": "SentenceTransformers"}
-     result = explain_results(config_a, config_b, model="gemini")
-     assert isinstance(result, str)
-     assert "Gemini" in result or "gemini" in result
+ def test_explain_results_with_claude(monkeypatch):
+     """Claude explanation should use Anthropic API path when ANTHROPIC_API_KEY is set."""
+     monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key")
+     monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
+
+     # Create a fake anthropic module with the required interface
+     mock_anthropic = types.ModuleType("anthropic")
+
+     class MockContent:
+         text = "Claude: The best configuration performs well due to optimized chunk size."
+
+     class MockMessages:
+         def create(self, *args, **kwargs):
+             return type("MockResponse", (), {"content": [MockContent()]})()
+
+     class MockClient:
+         def __init__(self, api_key):
+             self.messages = MockMessages()
+
+     mock_anthropic.Anthropic = MockClient
+     sys.modules["anthropic"] = mock_anthropic # Inject fake module
+
+     best = {"retriever": "Chroma", "metric": 0.9}
+     all_results = [{"retriever": "FAISS", "metric": 0.85}]
+     corpus_stats = {"size": 10000, "avg_len": 400, "num_docs": 20}

+     result = explain_results(best, all_results, corpus_stats, model="claude-3-opus-20240229")

- def test_explain_results_claude():
-     """Claude explanation should contain model-specific phrasing."""
-     config_a = {"retriever": "FAISS"}
-     config_b = {"retriever": "Chroma"}
-     result = explain_results(config_a, config_b, model="claude")
      assert isinstance(result, str)
      assert "Claude" in result or "claude" in result
ragmint/tests/test_integration_autotuner_ragmint.py
@@ -1,60 +1,47 @@
+ import os
+ import json
  import pytest
- from ragmint.tuner import RAGMint
  from ragmint.autotuner import AutoRAGTuner
+ from ragmint.tuner import RAGMint


- def test_integration_ragmint_autotune(monkeypatch, tmp_path):
-     """
-     Smoke test for integration between AutoRAGTuner and RAGMint.
-     Ensures end-to-end flow runs without real retrievers or embeddings.
-     """
-
-     # --- Mock corpus and validation data ---
+ def setup_docs(tmp_path):
+     """Create a temporary corpus for integration testing."""
      corpus = tmp_path / "docs"
      corpus.mkdir()
-     (corpus / "doc1.txt").write_text("This is an AI document.")
-     validation_data = [{"question": "What is AI?", "answer": "Artificial Intelligence"}]
-
-     # --- Mock RAGMint.optimize() to avoid real model work ---
-     def mock_optimize(self, validation_set=None, metric="faithfulness", trials=2):
-         return (
-             {"retriever": "FAISS", "embedding_model": "OpenAI", "score": 0.88},
-             [{"trial": 1, "score": 0.88}],
-         )
-
-     monkeypatch.setattr(RAGMint, "optimize", mock_optimize)
-
-     # --- Mock evaluation used by AutoRAGTuner ---
-     def mock_evaluate_config(config, data):
-         return {"faithfulness": 0.9, "latency": 0.01}
-
-     import ragmint.autotuner as autotuner
-     monkeypatch.setattr(autotuner, "evaluate_config", mock_evaluate_config)
-
-     # --- Create AutoRAGTuner and RAGMint instances ---
-     ragmint = RAGMint(
-         docs_path=str(corpus),
-         retrievers=["faiss", "chroma"],
-         embeddings=["text-embedding-3-small"],
-         rerankers=["mmr"],
+     (corpus / "doc1.txt").write_text("This document discusses Artificial Intelligence and Machine Learning.")
+     (corpus / "doc2.txt").write_text("Retrieval-Augmented Generation combines retrievers and LLMs effectively.")
+     return str(corpus)
+
+
+ def setup_validation_file(tmp_path):
+     """Create a temporary validation QA dataset."""
+     data = [
+         {"question": "What is AI?", "answer": "Artificial Intelligence"},
+         {"question": "Define RAG", "answer": "Retrieval-Augmented Generation"},
+     ]
+     file = tmp_path / "validation_qa.json"
+     with open(file, "w", encoding="utf-8") as f:
+         json.dump(data, f)
+     return str(file)
+
+
+ def test_autotune_integration(tmp_path):
+     """Test that AutoRAGTuner can fully run a RAGMint optimization."""
+     docs_path = setup_docs(tmp_path)
+     val_file = setup_validation_file(tmp_path)
+
+     tuner = AutoRAGTuner(docs_path)
+     best, results = tuner.auto_tune(
+         validation_set=val_file,
+         metric="faithfulness",
+         trials=2,
+         search_type="random",
      )

-     tuner = AutoRAGTuner({"size": 2000, "avg_len": 150})
-
-     # --- Run Auto-Tune and RAG Optimization ---
-     recommendation = tuner.recommend()
-     assert "retriever" in recommendation
-     assert "embedding_model" in recommendation
-
-     tuning_results = tuner.auto_tune(validation_data)
-     assert "results" in tuning_results
-     assert isinstance(tuning_results["results"], dict)
-
-     # --- Run RAGMint optimization flow (mocked) ---
-     best_config, results = ragmint.optimize(validation_set=validation_data, trials=2)
-     assert isinstance(best_config, dict)
-     assert "score" in best_config
-     assert isinstance(results, list)
-
-     # --- Integration Success ---
-     print(f"Integration OK: AutoRAG recommended {recommendation}, RAGMint best {best_config}")
+     # Assertions on the results
+     assert isinstance(best, dict), "Best configuration should be a dict"
+     assert isinstance(results, list), "Results should be a list"
+     assert len(results) > 0, "Optimization should produce results"
+     assert "retriever" in best and "embedding_model" in best
+     assert best.get("faithfulness", 0.0) >= 0.0, "Metric value should be non-negative"
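Where the old integration test mocked everything out, the new one drives a real run: auto_tune() takes a path to a validation JSON file plus metric, trials, and search_type, and returns a (best, results) tuple. A sketch of that flow with the argument names used in the test (file paths are placeholders):

from ragmint.autotuner import AutoRAGTuner

tuner = AutoRAGTuner("data/docs")
best, results = tuner.auto_tune(
    validation_set="experiments/validation_qa.json",  # QA pairs, e.g. from qa_generator
    metric="faithfulness",
    trials=2,
    search_type="random",
)
print(best.get("retriever"), best.get("embedding_model"))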
ragmint/tests/test_langchain_prebuilder.py
@@ -0,0 +1,82 @@
+ import pytest
+ from unittest.mock import MagicMock, patch
+ from ragmint.integrations.langchain_prebuilder import LangchainPrebuilder
+
+
+ @pytest.fixture
+ def sample_docs():
+     """Small sample corpus for testing."""
+     return ["AI is transforming the world.", "RAG pipelines improve retrieval."]
+
+
+ @pytest.fixture
+ def sample_config():
+     """Default configuration for tests."""
+     return {
+         "retriever": "faiss",
+         "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
+         "chunk_size": 200,
+         "overlap": 50,
+     }
+
+
+ @patch("ragmint.integrations.langchain_prebuilder.HuggingFaceEmbeddings", autospec=True)
+ @patch("ragmint.integrations.langchain_prebuilder.RecursiveCharacterTextSplitter", autospec=True)
+ def test_prepare_creates_components(mock_splitter, mock_embedder, sample_config, sample_docs):
+     """Ensure prepare() builds retriever and embedding components properly."""
+     mock_splitter.return_value.create_documents.return_value = ["doc1", "doc2"]
+     mock_embedder.return_value = MagicMock()
+
+     # Patch FAISS to avoid building a real index
+     with patch("ragmint.integrations.langchain_prebuilder.FAISS", autospec=True) as mock_faiss:
+         mock_db = MagicMock()
+         mock_faiss.from_documents.return_value = mock_db
+         mock_db.as_retriever.return_value = "mock_retriever"
+
+         builder = LangchainPrebuilder(sample_config)
+         retriever, embeddings = builder.prepare(sample_docs)
+
+         assert retriever == "mock_retriever"
+         assert embeddings == mock_embedder.return_value
+
+         mock_splitter.assert_called_once()
+         mock_embedder.assert_called_once_with(model_name="sentence-transformers/all-MiniLM-L6-v2")
+         mock_faiss.from_documents.assert_called_once()
+
+
+ @pytest.mark.parametrize("backend", ["faiss", "chroma", "bm25"])
+ def test_build_retriever_backends(sample_config, sample_docs, backend):
+     """Check retriever creation for each backend."""
+     cfg = dict(sample_config)
+     cfg["retriever"] = backend
+
+     builder = LangchainPrebuilder(cfg)
+
+     # Mock embeddings + docs
+     fake_embeddings = MagicMock()
+     fake_docs = ["d1", "d2"]
+
+     with patch("ragmint.integrations.langchain_prebuilder.FAISS.from_documents", return_value=MagicMock()) as mock_faiss, \
+          patch("ragmint.integrations.langchain_prebuilder.Chroma.from_documents", return_value=MagicMock()) as mock_chroma, \
+          patch("ragmint.integrations.langchain_prebuilder.BM25Retriever.from_texts", return_value=MagicMock()) as mock_bm25:
+         retriever = builder._build_retriever(fake_docs, fake_embeddings)
+
+     # Validate retriever creation per backend
+     if backend == "faiss":
+         mock_faiss.assert_called_once()
+     elif backend == "chroma":
+         mock_chroma.assert_called_once()
+     elif backend == "bm25":
+         mock_bm25.assert_called_once()
+
+     assert retriever is not None
+
+
+ def test_invalid_backend_raises(sample_config):
+     """Ensure ValueError is raised for unsupported retriever."""
+     cfg = dict(sample_config)
+     cfg["retriever"] = "invalid"
+
+     builder = LangchainPrebuilder(cfg)
+     with pytest.raises(ValueError, match="Unsupported retriever backend"):
+         builder._build_retriever(["doc"], MagicMock())
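The prebuilder tests outline the LangChain pre-build path: prepare() splits the documents, builds embeddings, and returns a (retriever, embeddings) pair for the configured backend, while _build_retriever() rejects unknown backends. A usage sketch using the same config shape as the fixtures above (a real run builds an actual index, so the relevant LangChain extras must be installed):

from ragmint.integrations.langchain_prebuilder import LangchainPrebuilder

config = {
    "retriever": "faiss",
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "chunk_size": 200,
    "overlap": 50,
}
retriever, embeddings = LangchainPrebuilder(config).prepare(
    ["AI is transforming the world.", "RAG pipelines improve retrieval."]
)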