ragmint 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ragmint might be problematic. Click here for more details.
- ragmint/__init__.py +0 -0
- ragmint/__main__.py +28 -0
- ragmint/autotuner.py +138 -0
- ragmint/core/__init__.py +0 -0
- ragmint/core/chunking.py +86 -0
- ragmint/core/embeddings.py +55 -0
- ragmint/core/evaluation.py +38 -0
- ragmint/core/pipeline.py +62 -0
- ragmint/core/reranker.py +62 -0
- ragmint/core/retriever.py +165 -0
- ragmint/experiments/__init__.py +0 -0
- ragmint/experiments/validation_qa.json +14 -0
- ragmint/explainer.py +63 -0
- ragmint/integrations/__init__.py +0 -0
- ragmint/integrations/config_adapter.py +96 -0
- ragmint/integrations/langchain_prebuilder.py +99 -0
- ragmint/leaderboard.py +45 -0
- ragmint/optimization/__init__.py +0 -0
- ragmint/optimization/search.py +48 -0
- ragmint/tests/__init__.py +0 -0
- ragmint/tests/conftest.py +16 -0
- ragmint/tests/test_autotuner.py +51 -0
- ragmint/tests/test_config_adapter.py +39 -0
- ragmint/tests/test_embeddings.py +46 -0
- ragmint/tests/test_explainer.py +20 -0
- ragmint/tests/test_explainer_integration.py +18 -0
- ragmint/tests/test_integration_autotuner_ragmint.py +47 -0
- ragmint/tests/test_langchain_prebuilder.py +82 -0
- ragmint/tests/test_leaderboard.py +39 -0
- ragmint/tests/test_pipeline.py +20 -0
- ragmint/tests/test_retriever.py +15 -0
- ragmint/tests/test_search.py +17 -0
- ragmint/tests/test_tuner.py +71 -0
- ragmint/tuner.py +189 -0
- ragmint/utils/__init__.py +0 -0
- ragmint/utils/caching.py +37 -0
- ragmint/utils/data_loader.py +65 -0
- ragmint/utils/logger.py +36 -0
- ragmint/utils/metrics.py +27 -0
- ragmint-0.3.1.data/data/LICENSE +19 -0
- ragmint-0.3.1.data/data/README.md +397 -0
- ragmint-0.3.1.dist-info/METADATA +441 -0
- ragmint-0.3.1.dist-info/RECORD +46 -0
- ragmint-0.3.1.dist-info/WHEEL +5 -0
- ragmint-0.3.1.dist-info/licenses/LICENSE +19 -0
- ragmint-0.3.1.dist-info/top_level.txt +1 -0
ragmint/explainer.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Interpretability Layer
|
|
3
|
+
----------------------
|
|
4
|
+
Uses Gemini or Anthropic Claude to explain why one RAG configuration
|
|
5
|
+
outperforms another. Falls back gracefully if no API key is provided.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import json
|
|
10
|
+
from dotenv import load_dotenv
|
|
11
|
+
|
|
12
|
+
# Load environment variables from .env file if available
|
|
13
|
+
# Runs at import time so the API keys that explain_results() reads with
# os.getenv() can be supplied through a local .env file.
load_dotenv()
|
|
14
|
+
|
|
15
|
+
def explain_results(results_a: dict, results_b: dict, model: str = "gemini-2.5-flash-lite") -> str:
    """
    Generate a natural-language explanation comparing two RAG experiment results.

    Providers are tried in priority order. If a configured provider fails, the
    next configured provider is tried before giving up:

        1. Anthropic Claude (if ANTHROPIC_API_KEY is set)
        2. Google Gemini (if GOOGLE_API_KEY, or its alias GEMINI_API_KEY, is set)
        3. Fallback text message

    Args:
        results_a: Metrics/config of the first experiment (must be JSON-serializable).
        results_b: Metrics/config of the second experiment (must be JSON-serializable).
        model: Gemini model name. The Claude model is fixed.

    Returns:
        The LLM explanation, or a bracketed status message ("[Claude unavailable] ...",
        "[Gemini unavailable] ...", "[No LLM available] ...") explaining why no
        explanation could be produced. Provider errors are reported, not raised.

    Raises:
        TypeError: if either results dict is not JSON-serializable (from json.dumps).
    """
    prompt = f"""
    You are an AI evaluation expert.
    Compare these two RAG experiment results and explain why one performs better.
    Metrics A: {json.dumps(results_a, indent=2)}
    Metrics B: {json.dumps(results_b, indent=2)}
    Provide a concise, human-friendly explanation and practical improvement tips.
    """

    anthropic_key = os.getenv("ANTHROPIC_API_KEY")
    # GEMINI_API_KEY is accepted as an alias: some environments (including this
    # project's test setup) provide the Gemini key under that name.
    google_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

    # Provider failures are collected and only reported if no provider succeeds.
    failures = []

    # 1️⃣ Try Anthropic Claude first
    if anthropic_key:
        try:
            from anthropic import Anthropic  # optional dependency, imported lazily

            client = Anthropic(api_key=anthropic_key)
            response = client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=300,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.content[0].text
        except Exception as e:  # best-effort: any SDK/network error falls through
            failures.append(f"[Claude unavailable] {e}")

    # 2️⃣ Fall back to Google Gemini (also reached when Claude failed)
    if google_key:
        try:
            import google.generativeai as genai  # optional dependency, imported lazily

            genai.configure(api_key=google_key)
            response = genai.GenerativeModel(model).generate_content(prompt)
            return response.text
        except Exception as e:  # best-effort: any SDK/network error falls through
            failures.append(f"[Gemini unavailable] {e}")

    if failures:
        return " ".join(failures)

    # 3️⃣ Fallback if neither key is available
    return (
        "[No LLM available] Please set ANTHROPIC_API_KEY or GOOGLE_API_KEY "
        "to enable interpretability via Claude or Gemini."
    )
|
|
File without changes
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RAGMint → LangChain Config Adapter
|
|
3
|
+
----------------------------------
|
|
4
|
+
Takes RAGMint or AutoRAGTuner recommendations and converts them into
|
|
5
|
+
a normalized, pickle-safe configuration that can be used to build
|
|
6
|
+
a LangChain RAG pipeline later.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import pickle
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LangchainConfigAdapter:
|
|
16
|
+
"""
|
|
17
|
+
Converts RAGMint recommendations into LangChain-compatible configs.
|
|
18
|
+
|
|
19
|
+
Example:
|
|
20
|
+
adapter = LangChainConfigAdapter()
|
|
21
|
+
cfg = adapter.prepare(recommendation)
|
|
22
|
+
adapter.save(cfg, "best_config.pkl")
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
DEFAULT_EMBEDDINGS = {
|
|
26
|
+
"OpenAI": "sentence-transformers/all-MiniLM-L6-v2",
|
|
27
|
+
"SentenceTransformers": "sentence-transformers/all-MiniLM-L6-v2",
|
|
28
|
+
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
|
29
|
+
"InstructorXL": "hkunlp/instructor-xl"
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
SUPPORTED_RETRIEVERS = {"faiss", "chroma", "bm25", "numpy", "sklearn"}
|
|
33
|
+
|
|
34
|
+
def __init__(self, recommendation: Dict[str, Any] | None = None):
|
|
35
|
+
self.recommendation = recommendation
|
|
36
|
+
|
|
37
|
+
def prepare(self, recommendation: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
|
38
|
+
recommendation = recommendation or self.recommendation or {}
|
|
39
|
+
"""
|
|
40
|
+
Normalize and validate configuration for LangChain use.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
dict with clean retriever, embedding, and chunking settings.
|
|
44
|
+
"""
|
|
45
|
+
retriever = recommendation.get("retriever", "faiss").lower()
|
|
46
|
+
embedding_model = recommendation.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
|
|
47
|
+
chunk_size = recommendation.get("chunk_size", 400)
|
|
48
|
+
overlap = recommendation.get("overlap", 100)
|
|
49
|
+
|
|
50
|
+
# Normalize embedding model names
|
|
51
|
+
embedding_model = self.DEFAULT_EMBEDDINGS.get(embedding_model, embedding_model)
|
|
52
|
+
|
|
53
|
+
# Validate retriever backend
|
|
54
|
+
if retriever not in self.SUPPORTED_RETRIEVERS:
|
|
55
|
+
raise ValueError(f"Unsupported retriever backend: {retriever}")
|
|
56
|
+
|
|
57
|
+
config = {
|
|
58
|
+
"retriever": retriever,
|
|
59
|
+
"embedding_model": embedding_model,
|
|
60
|
+
"chunk_size": int(chunk_size),
|
|
61
|
+
"overlap": int(overlap),
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
return config
|
|
65
|
+
|
|
66
|
+
def save(self, config: Dict[str, Any], path: str):
|
|
67
|
+
"""
|
|
68
|
+
Save configuration to a pickle file.
|
|
69
|
+
"""
|
|
70
|
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
|
71
|
+
with open(path, "wb") as f:
|
|
72
|
+
pickle.dump(config, f)
|
|
73
|
+
print(f"💾 Saved LangChain config → {path}")
|
|
74
|
+
|
|
75
|
+
def load(self, path: str) -> Dict[str, Any]:
|
|
76
|
+
"""
|
|
77
|
+
Load configuration from a pickle file.
|
|
78
|
+
"""
|
|
79
|
+
with open(path, "rb") as f:
|
|
80
|
+
cfg = pickle.load(f)
|
|
81
|
+
print(f"✅ Loaded LangChain config ← {path}")
|
|
82
|
+
return cfg
|
|
83
|
+
|
|
84
|
+
def to_json(self, config: Dict[str, Any], path: str):
|
|
85
|
+
"""
|
|
86
|
+
Save configuration as JSON (for human readability).
|
|
87
|
+
"""
|
|
88
|
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
|
89
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
90
|
+
json.dump(config, f, indent=2)
|
|
91
|
+
print(f"📝 Exported LangChain config → {path}")
|
|
92
|
+
|
|
93
|
+
# Alias for backward compatibility
|
|
94
|
+
def to_standard_config(self, recommendation: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
|
95
|
+
"""Alias for backward compatibility with older test suites."""
|
|
96
|
+
return self.prepare(recommendation)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LangChain Pre-Build Integration
|
|
3
|
+
-------------------------------
|
|
4
|
+
This module bridges RAGMint's auto-tuning system with LangChain,
|
|
5
|
+
returning retriever and embedding components that can plug directly
|
|
6
|
+
into any LangChain RAG pipeline.
|
|
7
|
+
|
|
8
|
+
Example:
|
|
9
|
+
from ragmint.integrations.langchain_prebuilder import LangChainPrebuilder
|
|
10
|
+
from langchain.chains import RetrievalQA
|
|
11
|
+
from langchain_openai import ChatOpenAI
|
|
12
|
+
|
|
13
|
+
prebuilder = LangChainPrebuilder(best_cfg)
|
|
14
|
+
retriever, embeddings = prebuilder.prepare(documents)
|
|
15
|
+
|
|
16
|
+
llm = ChatOpenAI(model="gpt-4o-mini")
|
|
17
|
+
qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from typing import List, Tuple, Dict, Any
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    # Older LangChain releases, from before the splitters moved into the separate
    # ``langchain-text-splitters`` package, exposed the class here. The previous
    # fallback repeated the same failing import and could never succeed.
    from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
27
|
+
|
|
28
|
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
29
|
+
from langchain_community.vectorstores import FAISS, Chroma
|
|
30
|
+
from langchain_community.retrievers import BM25Retriever
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LangchainPrebuilder:
    """
    Dynamically builds LangChain retriever and embedding objects
    based on a RAGMint configuration dictionary.
    """

    # Backends understood by _build_retriever(). prepare() checks this up front
    # so an invalid config fails before documents are split and the embedding
    # model is loaded.
    SUPPORTED_BACKENDS = ("faiss", "chroma", "bm25")

    def __init__(self, cfg: Dict[str, Any]):
        """
        Args:
            cfg (dict): RAGMint configuration with keys:
                - retriever: "faiss" | "chroma" | "bm25" (case-insensitive, default "faiss")
                - embedding_model: HuggingFace model name
                - chunk_size: int (default=500)
                - overlap: int (default=100)
                - top_k: int, documents returned per query (default=5)
        """
        self.cfg = cfg
        self.retriever_backend = cfg.get("retriever", "faiss").lower()
        self.embedding_model = cfg.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
        self.chunk_size = int(cfg.get("chunk_size", 500))
        self.overlap = int(cfg.get("overlap", 100))
        self.top_k = int(cfg.get("top_k", 5))

    def prepare(self, documents: List[str]) -> Tuple[Any, Any]:
        """
        Prepares LangChain-compatible retriever and embeddings.

        Args:
            documents (list[str]): Corpus texts

        Returns:
            (retriever, embeddings): Tuple of initialized LangChain retriever and embedding model

        Raises:
            ValueError: if the configured retriever backend is not supported.
        """
        # 0️⃣ Fail fast, before any expensive work
        if self.retriever_backend not in self.SUPPORTED_BACKENDS:
            raise ValueError(f"Unsupported retriever backend: {self.retriever_backend}")

        # 1️⃣ Split into chunks
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.overlap,
        )
        docs = splitter.create_documents(documents)

        # 2️⃣ Create embeddings
        embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

        # 3️⃣ Build retriever
        retriever = self._build_retriever(docs, embeddings)
        return retriever, embeddings

    def _build_retriever(self, docs, embeddings):
        """Internal helper for building retriever backend (returns top_k results)."""
        backend = self.retriever_backend
        k = self.top_k

        if backend == "faiss":
            db = FAISS.from_documents(docs, embeddings)
            return db.as_retriever(search_kwargs={"k": k})

        elif backend == "chroma":
            db = Chroma.from_documents(docs, embeddings, collection_name="ragmint_docs")
            return db.as_retriever(search_kwargs={"k": k})

        elif backend == "bm25":
            # Support both Document objects and raw text strings
            texts = [getattr(d, "page_content", d) for d in docs]
            retriever = BM25Retriever.from_texts(texts)
            retriever.k = k
            return retriever

        else:
            raise ValueError(f"Unsupported retriever backend: {backend}")


# Alias matching the CamelCase spelling used in this module's usage example.
LangChainPrebuilder = LangchainPrebuilder
|
ragmint/leaderboard.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Dict, Any, Optional
|
|
5
|
+
from supabase import create_client
|
|
6
|
+
|
|
7
|
+
class Leaderboard:
    """
    Stores experiment scores in a Supabase ``experiments`` table or, when
    SUPABASE_URL/SUPABASE_KEY are not both set, in a local JSON-Lines file.
    """

    def __init__(self, storage_path: Optional[str] = None):
        """
        Args:
            storage_path: Local JSON-Lines file used when Supabase is not configured.

        Raises:
            EnvironmentError: if Supabase is not configured and no storage_path is given.
        """
        self.storage_path = storage_path
        url = os.getenv("SUPABASE_URL")
        key = os.getenv("SUPABASE_KEY")
        self.client = None
        if url and key:
            # Supabase takes precedence over storage_path when both are available.
            self.client = create_client(url, key)
        elif not storage_path:
            raise EnvironmentError("Set SUPABASE_URL/SUPABASE_KEY or pass storage_path")

    def upload(self, run_id: str, config: Dict[str, Any], score: float):
        """
        Record one experiment result.

        Returns:
            The Supabase ``execute()`` response when using Supabase; otherwise
            the record dict that was appended to the local file.
        """
        from datetime import timezone  # local import keeps the module header untouched

        data = {
            "run_id": run_id,
            "config": config,
            "score": score,
            # Naive UTC ISO-8601 timestamp, same format as the deprecated
            # datetime.utcnow().isoformat().
            "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        }
        if self.client:
            return self.client.table("experiments").insert(data).execute()

        directory = os.path.dirname(self.storage_path)
        if directory:  # a bare filename has no directory to create (makedirs("") raises)
            os.makedirs(directory, exist_ok=True)
        with open(self.storage_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(data) + "\n")
        return data

    def top_results(self, limit: int = 10):
        """
        Return the ``limit`` highest-scoring runs, best first.

        Returns:
            The Supabase ``execute()`` response when using Supabase; otherwise a
            list of record dicts (empty if nothing has been uploaded yet).
        """
        if self.client:
            return (
                self.client.table("experiments")
                .select("*")
                .order("score", desc=True)
                .limit(limit)
                .execute()
            )

        if not os.path.exists(self.storage_path):
            return []
        with open(self.storage_path, "r", encoding="utf-8") as f:
            # Skip blank lines so a stray newline does not break json.loads.
            records = [json.loads(line) for line in f if line.strip()]
        return sorted(records, key=lambda x: x["score"], reverse=True)[:limit]
|
|
File without changes
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import random
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Dict, List, Iterator, Any
|
|
5
|
+
|
|
6
|
+
# NOTE(review): this configures the root logger as a side effect of importing
# the module, which also affects any application that imports ragmint. Consider
# moving it to the CLI entry point.
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GridSearch:
    """Exhaustive search: every combination of the candidate values, in key order."""

    def __init__(self, search_space: Dict[str, List[Any]]):
        param_names = list(search_space)
        candidate_lists = [search_space[name] for name in param_names]
        # Materialized eagerly so the grid can be iterated more than once.
        self.combinations = [
            dict(zip(param_names, combo))
            for combo in itertools.product(*candidate_lists)
        ]

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        yield from self.combinations
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RandomSearch:
    """Draw ``n_trials`` configurations, sampling each parameter independently."""

    def __init__(self, search_space: Dict[str, List[Any]], n_trials: int = 10):
        self.search_space = search_space
        self.n_trials = n_trials

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        space = self.search_space
        param_names = [name for name in space]
        for _trial in range(self.n_trials):
            sample = {}
            for name in param_names:
                sample[name] = random.choice(space[name])
            yield sample
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BayesianSearch:
    """
    Search strategy reserved for Optuna-backed Bayesian optimization.

    NOTE: despite the name, iteration currently yields ``n_trials`` uniformly
    random configurations (the same sampling as RandomSearch). Optuna only has
    to be importable. A true Bayesian loop needs each trial's score fed back
    to the optimizer, which a plain iterator of configs cannot do.
    """

    def __init__(self, search_space: Dict[str, List[Any]], n_trials: int = 5):
        """
        Args:
            search_space: Mapping of parameter name -> list of candidate values.
            n_trials: Number of configurations to yield (default 5, as before).

        Raises:
            RuntimeError: if Optuna is not installed.
        """
        try:
            import optuna
        except ImportError as err:
            raise RuntimeError("Optuna not installed; use GridSearch or RandomSearch instead.") from err
        self.optuna = optuna
        self.search_space = search_space
        self.n_trials = n_trials

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        keys = list(self.search_space.keys())
        for _ in range(self.n_trials):
            yield {k: random.choice(self.search_space[k]) for k in keys}
|
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# src/ragmint/tests/conftest.py
|
|
2
|
+
import os
|
|
3
|
+
from dotenv import load_dotenv
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
# Load .env from project root
|
|
7
|
+
# NOTE(review): the path is relative to this file. Three levels up reaches the
# repository root in a src/ragmint/tests checkout; from an installed wheel it
# points outside the package — confirm this is only meant for local runs.
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), "../../../.env"))
|
|
8
|
+
|
|
9
|
+
def pytest_configure(config):
    """Register custom markers and print which API keys are loaded (debug)."""
    # Register the marker used by the real-API tests so pytest does not emit
    # PytestUnknownMarkWarning for @pytest.mark.integration.
    config.addinivalue_line("markers", "integration: tests that call real external LLM APIs")

    # The explainer reads GOOGLE_API_KEY; GEMINI_API_KEY is also checked because
    # some environments use that name. Report the variable actually found.
    if os.getenv("GOOGLE_API_KEY"):
        print("✅ GOOGLE_API_KEY loaded")
    elif os.getenv("GEMINI_API_KEY"):
        print("✅ GEMINI_API_KEY loaded")
    if os.getenv("ANTHROPIC_API_KEY"):
        print("✅ ANTHROPIC_API_KEY loaded")
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import pytest
|
|
4
|
+
from ragmint.autotuner import AutoRAGTuner
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def setup_docs(tmp_path):
    """Build a two-file corpus (one tiny, one long) under ``tmp_path/corpus``.

    Returns the corpus directory as a string path.
    """
    corpus_dir = tmp_path / "corpus"
    corpus_dir.mkdir()
    files = {
        "short_doc.txt": "AI is changing the world.",
        "long_doc.txt": "Machine learning enables RAG pipelines to optimize retrievals. " * 50,
    }
    for filename, text in files.items():
        (corpus_dir / filename).write_text(text)
    return str(corpus_dir)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_analyze_corpus(tmp_path):
    """AutoRAGTuner should report basic statistics for a small on-disk corpus."""
    stats = AutoRAGTuner(setup_docs(tmp_path)).corpus_stats

    assert stats["num_docs"] == 2, "Should detect all documents"
    assert stats["size"] > 0, "Corpus size should be positive"
    assert stats["avg_len"] > 0, "Average document length should be computed"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.mark.parametrize(
    "size,expected_retriever",
    [
        (10_000, "Chroma"),
        (500_000, "FAISS"),
        (1_000, "BM25"),
    ],
)
def test_recommendation_logic(tmp_path, monkeypatch, size, expected_retriever):
    """The recommended retriever should follow the (overridden) corpus size."""
    tuner = AutoRAGTuner(setup_docs(tmp_path))
    # Replace the analyzed statistics with controlled values.
    tuner.corpus_stats = {"size": size, "avg_len": 300, "num_docs": 10}

    recommendation = tuner.recommend()

    assert "retriever" in recommendation and "embedding_model" in recommendation
    assert recommendation["retriever"] == expected_retriever, f"Expected {expected_retriever}"
    assert recommendation["chunk_size"] > 0 and recommendation["overlap"] >= 0
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_invalid_corpus_path(tmp_path):
    """A non-existent corpus directory yields empty stats rather than an error."""
    tuner = AutoRAGTuner(str(tmp_path / "nonexistent"))

    assert tuner.corpus_stats["size"] == 0
    assert tuner.corpus_stats["num_docs"] == 0
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from ragmint.integrations.config_adapter import LangchainConfigAdapter
|
|
3
|
+
|
|
4
|
+
def test_default_conversion():
    """Explicit values survive conversion and names are normalized."""
    adapter = LangchainConfigAdapter(
        {
            "retriever": "FAISS",
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size": 500,
            "overlap": 100,
        }
    )
    converted = adapter.to_standard_config()

    assert converted["retriever"].lower() == "faiss"
    expected = {
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "chunk_size": 500,
        "overlap": 100,
    }
    for field, value in expected.items():
        assert converted[field] == value
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_missing_fields_are_defaulted():
    """Chunking parameters absent from the input are filled with defaults."""
    converted = LangchainConfigAdapter(
        {"retriever": "BM25", "embedding_model": "all-MiniLM-L6-v2"}
    ).to_standard_config()

    for field in ("chunk_size", "overlap"):
        assert field in converted
    assert converted["chunk_size"] > 0
    assert converted["overlap"] >= 0
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_validation_of_invalid_retriever():
    """An unknown retriever name is rejected with an informative ValueError."""
    bad_config = {"retriever": "InvalidBackend", "embedding_model": "all-MiniLM-L6-v2"}

    with pytest.raises(ValueError, match="Unsupported retriever backend"):
        LangchainConfigAdapter(bad_config).to_standard_config()
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pytest
|
|
3
|
+
from ragmint.core.embeddings import Embeddings
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_dummy_backend_output_shape():
    """The dummy backend embeds a list of texts into a float32 (n, 768) array."""
    vectors = Embeddings(backend="dummy").encode(["hello", "world"])

    assert isinstance(vectors, np.ndarray)
    assert vectors.shape == (2, 768)
    assert vectors.dtype == np.float32
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def test_dummy_backend_single_string():
    """A bare string is encoded as a batch containing one embedding."""
    vectors = Embeddings(backend="dummy").encode("test")

    assert vectors.shape == (1, 768)
    assert isinstance(vectors, np.ndarray)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.mark.skip(reason="Disabled upstream; remove this marker to run (requires sentence-transformers).")
def test_huggingface_backend_output_shape():
    """The HuggingFace backend should return a 2-D float32 array, one row per text."""
    # Skips cleanly instead of failing when the optional dependency is absent.
    pytest.importorskip("sentence_transformers")
    model = Embeddings(backend="huggingface", model_name="all-MiniLM-L6-v2")
    texts = ["This is a test.", "Another sentence."]
    embeddings = model.encode(texts)

    # Expect 2x384 for MiniLM-L6-v2
    assert isinstance(embeddings, np.ndarray)
    assert embeddings.ndim == 2
    assert embeddings.shape[0] == len(texts)
    assert embeddings.dtype == np.float32
|
|
41
|
+
|
|
42
|
+
def test_invalid_backend():
    """Unknown backends must be rejected with a descriptive ValueError."""
    # pytest.raises fails the test when nothing is raised; the previous
    # try/except form silently passed if Embeddings accepted the backend.
    with pytest.raises(ValueError, match="Unsupported embedding backend"):
        Embeddings(backend="unknown")
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from ragmint.explainer import explain_results
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_explain_results_gemini(monkeypatch):
    """Without API keys, the fallback message should mention Gemini."""
    # Remove keys (possibly loaded from .env) so the result is deterministic and
    # no network call is attempted.
    for var in ("ANTHROPIC_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"):
        monkeypatch.delenv(var, raising=False)

    config_a = {"retriever": "FAISS", "embedding_model": "OpenAI"}
    config_b = {"retriever": "Chroma", "embedding_model": "SentenceTransformers"}
    result = explain_results(config_a, config_b, model="gemini")
    assert isinstance(result, str)
    assert "Gemini" in result
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_explain_results_claude(monkeypatch):
    """Without API keys, the fallback message should mention Claude."""
    # Remove keys (possibly loaded from .env) so the result is deterministic and
    # no network call is attempted.
    for var in ("ANTHROPIC_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"):
        monkeypatch.delenv(var, raising=False)

    config_a = {"retriever": "FAISS"}
    config_b = {"retriever": "Chroma"}
    result = explain_results(config_a, config_b, model="claude")
    assert isinstance(result, str)
    assert "Claude" in result
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pytest
|
|
3
|
+
from ragmint.explainer import explain_results
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@pytest.mark.integration
def test_real_gemini_explanation(monkeypatch):
    """Run a real Gemini call if GOOGLE_API_KEY (or GEMINI_API_KEY) is set."""
    key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
    if not key:
        pytest.skip("GOOGLE_API_KEY / GEMINI_API_KEY not set")

    # explain_results() reads GOOGLE_API_KEY and prefers Claude whenever
    # ANTHROPIC_API_KEY is set, so force the Gemini path explicitly.
    monkeypatch.setenv("GOOGLE_API_KEY", key)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    config_a = {"retriever": "FAISS", "embedding_model": "OpenAI"}
    config_b = {"retriever": "Chroma", "embedding_model": "SentenceTransformers"}

    # Use explain_results' default Gemini model rather than pinning a specific,
    # possibly retired, model version.
    result = explain_results(config_a, config_b)
    assert isinstance(result, str)
    assert len(result) > 0
    # A bracketed status message means Gemini was never reached or failed.
    assert not result.startswith(("[Gemini unavailable]", "[Claude unavailable]", "[No LLM available]")), result
    print("\n[Gemini explanation]:", result[:200], "...")
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import pytest
|
|
4
|
+
from ragmint.autotuner import AutoRAGTuner
|
|
5
|
+
from ragmint.tuner import RAGMint
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def setup_docs(tmp_path):
    """Write two short text documents into ``tmp_path/docs``; return that directory as a str."""
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    contents = [
        ("doc1.txt", "This document discusses Artificial Intelligence and Machine Learning."),
        ("doc2.txt", "Retrieval-Augmented Generation combines retrievers and LLMs effectively."),
    ]
    for filename, text in contents:
        (docs_dir / filename).write_text(text)
    return str(docs_dir)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def setup_validation_file(tmp_path):
    """Dump a two-item QA validation set to ``tmp_path/validation_qa.json``; return its path as a str."""
    qa_pairs = [
        {"question": "What is AI?", "answer": "Artificial Intelligence"},
        {"question": "Define RAG", "answer": "Retrieval-Augmented Generation"},
    ]
    target = tmp_path / "validation_qa.json"
    with target.open("w", encoding="utf-8") as handle:
        json.dump(qa_pairs, handle)
    return str(target)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_autotune_integration(tmp_path):
    """AutoRAGTuner.auto_tune should drive a complete (tiny) RAGMint optimization run."""
    corpus_dir = setup_docs(tmp_path)
    qa_file = setup_validation_file(tmp_path)

    best_config, all_results = AutoRAGTuner(corpus_dir).auto_tune(
        validation_set=qa_file,
        metric="faithfulness",
        trials=2,
        search_type="random",
    )

    # Shape of the outcome
    assert isinstance(best_config, dict), "Best configuration should be a dict"
    assert isinstance(all_results, list), "Results should be a list"
    assert len(all_results) > 0, "Optimization should produce results"
    # Content of the winning configuration
    assert "retriever" in best_config and "embedding_model" in best_config
    assert best_config.get("faithfulness", 0.0) >= 0.0, "Metric value should be non-negative"
|