ragmint 0.3.1 (ragmint-0.3.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ragmint might be problematic.
- ragmint/__init__.py +0 -0
- ragmint/__main__.py +28 -0
- ragmint/autotuner.py +138 -0
- ragmint/core/__init__.py +0 -0
- ragmint/core/chunking.py +86 -0
- ragmint/core/embeddings.py +55 -0
- ragmint/core/evaluation.py +38 -0
- ragmint/core/pipeline.py +62 -0
- ragmint/core/reranker.py +62 -0
- ragmint/core/retriever.py +165 -0
- ragmint/experiments/__init__.py +0 -0
- ragmint/experiments/validation_qa.json +14 -0
- ragmint/explainer.py +63 -0
- ragmint/integrations/__init__.py +0 -0
- ragmint/integrations/config_adapter.py +96 -0
- ragmint/integrations/langchain_prebuilder.py +99 -0
- ragmint/leaderboard.py +45 -0
- ragmint/optimization/__init__.py +0 -0
- ragmint/optimization/search.py +48 -0
- ragmint/tests/__init__.py +0 -0
- ragmint/tests/conftest.py +16 -0
- ragmint/tests/test_autotuner.py +51 -0
- ragmint/tests/test_config_adapter.py +39 -0
- ragmint/tests/test_embeddings.py +46 -0
- ragmint/tests/test_explainer.py +20 -0
- ragmint/tests/test_explainer_integration.py +18 -0
- ragmint/tests/test_integration_autotuner_ragmint.py +47 -0
- ragmint/tests/test_langchain_prebuilder.py +82 -0
- ragmint/tests/test_leaderboard.py +39 -0
- ragmint/tests/test_pipeline.py +20 -0
- ragmint/tests/test_retriever.py +15 -0
- ragmint/tests/test_search.py +17 -0
- ragmint/tests/test_tuner.py +71 -0
- ragmint/tuner.py +189 -0
- ragmint/utils/__init__.py +0 -0
- ragmint/utils/caching.py +37 -0
- ragmint/utils/data_loader.py +65 -0
- ragmint/utils/logger.py +36 -0
- ragmint/utils/metrics.py +27 -0
- ragmint-0.3.1.data/data/LICENSE +19 -0
- ragmint-0.3.1.data/data/README.md +397 -0
- ragmint-0.3.1.dist-info/METADATA +441 -0
- ragmint-0.3.1.dist-info/RECORD +46 -0
- ragmint-0.3.1.dist-info/WHEEL +5 -0
- ragmint-0.3.1.dist-info/licenses/LICENSE +19 -0
- ragmint-0.3.1.dist-info/top_level.txt +1 -0
ragmint/tests/test_langchain_prebuilder.py
ADDED
@@ -0,0 +1,82 @@
import pytest
from unittest.mock import MagicMock, patch
from ragmint.integrations.langchain_prebuilder import LangchainPrebuilder


@pytest.fixture
def sample_docs():
    """Small sample corpus for testing."""
    return ["AI is transforming the world.", "RAG pipelines improve retrieval."]


@pytest.fixture
def sample_config():
    """Default configuration for tests."""
    return {
        "retriever": "faiss",
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "chunk_size": 200,
        "overlap": 50,
    }


@patch("ragmint.integrations.langchain_prebuilder.HuggingFaceEmbeddings", autospec=True)
@patch("ragmint.integrations.langchain_prebuilder.RecursiveCharacterTextSplitter", autospec=True)
def test_prepare_creates_components(mock_splitter, mock_embedder, sample_config, sample_docs):
    """Ensure prepare() builds retriever and embedding components properly."""
    mock_splitter.return_value.create_documents.return_value = ["doc1", "doc2"]
    mock_embedder.return_value = MagicMock()

    # Patch FAISS to avoid building a real index
    with patch("ragmint.integrations.langchain_prebuilder.FAISS", autospec=True) as mock_faiss:
        mock_db = MagicMock()
        mock_faiss.from_documents.return_value = mock_db
        mock_db.as_retriever.return_value = "mock_retriever"

        builder = LangchainPrebuilder(sample_config)
        retriever, embeddings = builder.prepare(sample_docs)

        assert retriever == "mock_retriever"
        assert embeddings == mock_embedder.return_value

        mock_splitter.assert_called_once()
        mock_embedder.assert_called_once_with(model_name="sentence-transformers/all-MiniLM-L6-v2")
        mock_faiss.from_documents.assert_called_once()


@pytest.mark.parametrize("backend", ["faiss", "chroma", "bm25"])
def test_build_retriever_backends(sample_config, sample_docs, backend):
    """Check retriever creation for each backend."""
    cfg = dict(sample_config)
    cfg["retriever"] = backend

    builder = LangchainPrebuilder(cfg)

    # Mock embeddings + docs
    fake_embeddings = MagicMock()
    fake_docs = ["d1", "d2"]

    with patch("ragmint.integrations.langchain_prebuilder.FAISS.from_documents", return_value=MagicMock()) as mock_faiss, \
         patch("ragmint.integrations.langchain_prebuilder.Chroma.from_documents", return_value=MagicMock()) as mock_chroma, \
         patch("ragmint.integrations.langchain_prebuilder.BM25Retriever.from_texts", return_value=MagicMock()) as mock_bm25:
        retriever = builder._build_retriever(fake_docs, fake_embeddings)

        # Validate retriever creation per backend
        if backend == "faiss":
            mock_faiss.assert_called_once()
        elif backend == "chroma":
            mock_chroma.assert_called_once()
        elif backend == "bm25":
            mock_bm25.assert_called_once()

        assert retriever is not None


def test_invalid_backend_raises(sample_config):
    """Ensure ValueError is raised for unsupported retriever."""
    cfg = dict(sample_config)
    cfg["retriever"] = "invalid"

    builder = LangchainPrebuilder(cfg)
    with pytest.raises(ValueError, match="Unsupported retriever backend"):
        builder._build_retriever(["doc"], MagicMock())
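The mocked tests above exercise LangchainPrebuilder end to end; for orientation, a minimal non-mocked sketch (illustrative only, not part of the package, and assuming the LangChain, sentence-transformers, and FAISS dependencies are installed) would look like:

from ragmint.integrations.langchain_prebuilder import LangchainPrebuilder

# Same configuration shape as the sample_config fixture above
config = {
    "retriever": "faiss",
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "chunk_size": 200,
    "overlap": 50,
}
builder = LangchainPrebuilder(config)
retriever, embeddings = builder.prepare(["AI is transforming the world."])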
ragmint/tests/test_leaderboard.py
ADDED
@@ -0,0 +1,39 @@
import json
import tempfile
from pathlib import Path
from ragmint.leaderboard import Leaderboard


def test_leaderboard_add_and_top(tmp_path):
    """Ensure local leaderboard persistence works without Supabase."""
    file_path = tmp_path / "leaderboard.jsonl"
    lb = Leaderboard(storage_path=str(file_path))

    # Add two runs
    lb.upload("run1", {"retriever": "FAISS"}, 0.91)
    lb.upload("run2", {"retriever": "Chroma"}, 0.85)

    # Verify file content
    assert file_path.exists()
    with open(file_path, "r", encoding="utf-8") as f:
        lines = [json.loads(line) for line in f]
    assert len(lines) == 2

    # Get top results
    top = lb.top_results(limit=1)
    assert isinstance(top, list)
    assert len(top) == 1
    assert "score" in top[0]


def test_leaderboard_append_existing(tmp_path):
    """Ensure multiple uploads append properly."""
    file_path = tmp_path / "leaderboard.jsonl"
    lb = Leaderboard(storage_path=str(file_path))

    for i in range(3):
        lb.upload(f"run{i}", {"retriever": "BM25"}, 0.8 + i * 0.05)

    top = lb.top_results(limit=2)
    assert len(top) == 2
    assert top[0]["score"] >= top[1]["score"]
ragmint/tests/test_pipeline.py
ADDED
@@ -0,0 +1,20 @@
import numpy as np
from ragmint.core.pipeline import RAGPipeline
from ragmint.core.retriever import Retriever
from ragmint.core.embeddings import Embeddings
from ragmint.core.reranker import Reranker
from ragmint.core.evaluation import Evaluator


def test_pipeline_run():
    docs = ["doc1 text", "doc2 text"]
    embedder = Embeddings(backend="dummy")
    retriever = Retriever(embedder=embedder, documents=docs)
    reranker = Reranker("mmr")
    evaluator = Evaluator()
    pipeline = RAGPipeline(retriever, reranker, evaluator)

    result = pipeline.run("what is doc1?")
    assert "query" in result
    assert "answer" in result
    assert "metrics" in result
ragmint/tests/test_retriever.py
ADDED
@@ -0,0 +1,15 @@
import numpy as np
from ragmint.core.retriever import Retriever
from ragmint.core.embeddings import Embeddings


def test_retrieve_basic():
    docs = ["doc A", "doc B", "doc C"]
    embedder = Embeddings(backend="dummy")
    retriever = Retriever(embedder=embedder, documents=docs)

    results = retriever.retrieve("sample query", top_k=2)
    assert isinstance(results, list)
    assert len(results) == 2
    assert "text" in results[0]
    assert "score" in results[0]
ragmint/tests/test_search.py
ADDED
@@ -0,0 +1,17 @@
from ragmint.optimization.search import GridSearch, RandomSearch


def test_grid_search_iterates():
    space = {"retriever": ["faiss"], "embedding_model": ["openai"], "reranker": ["mmr"]}
    search = GridSearch(space)
    combos = list(search)
    assert len(combos) == 1
    assert "retriever" in combos[0]


def test_random_search_n_trials():
    space = {"retriever": ["faiss", "bm25"], "embedding_model": ["openai", "st"], "reranker": ["mmr"]}
    search = RandomSearch(space, n_trials=5)
    combos = list(search)
    assert len(combos) == 5
    assert all("retriever" in c for c in combos)
ragmint/tests/test_tuner.py
ADDED
@@ -0,0 +1,71 @@
import os
import json
import pytest
from ragmint.tuner import RAGMint


def setup_validation_file(tmp_path):
    """Create a temporary validation QA dataset."""
    data = [
        {"question": "What is AI?", "answer": "Artificial Intelligence"},
        {"question": "Define ML", "answer": "Machine Learning"}
    ]
    file = tmp_path / "validation_qa.json"
    with open(file, "w", encoding="utf-8") as f:
        json.dump(data, f)
    return str(file)


def setup_docs(tmp_path):
    """Create a small document corpus for testing."""
    corpus = tmp_path / "corpus"
    corpus.mkdir()
    (corpus / "doc1.txt").write_text("This is about Artificial Intelligence.")
    (corpus / "doc2.txt").write_text("This text explains Machine Learning.")
    return str(corpus)


@pytest.mark.parametrize("validation_mode", [
    None,  # Built-in dataset
    "data/custom_eval.json",  # Custom dataset path (mocked below)
])
def test_optimize_ragmint(tmp_path, validation_mode, monkeypatch):
    """Test RAGMint.optimize() with different dataset modes."""
    docs_path = setup_docs(tmp_path)
    val_file = setup_validation_file(tmp_path)

    # If using custom dataset, mock the path
    if validation_mode and "custom_eval" in validation_mode:
        custom_path = tmp_path / "custom_eval.json"
        os.rename(val_file, custom_path)
        validation_mode = str(custom_path)

    metric = "faithfulness"

    # Initialize RAGMint
    rag = RAGMint(
        docs_path=docs_path,
        retrievers=["faiss"],
        embeddings=["all-MiniLM-L6-v2"],
        rerankers=["mmr"]
    )

    # Run optimization
    best, results = rag.optimize(
        validation_set=validation_mode,
        metric=metric,
        trials=2
    )

    # Validate results
    assert isinstance(best, dict), "Best config should be a dict"
    assert isinstance(results, list), "Results should be a list of trials"
    assert len(results) > 0, "Optimization should produce results"

    # The best result can expose either 'score' or the metric name (e.g. 'faithfulness')
    assert any(k in best for k in ("score", metric)), \
        f"Best config should include either 'score' or '{metric}'"

    # Ensure the metric value is valid
    assert best.get(metric, best.get("score")) >= 0, \
        f"{metric} score should be non-negative"
ragmint/tuner.py
ADDED
@@ -0,0 +1,189 @@
import os
import json
import logging
from typing import Any, Dict, List, Tuple
from time import perf_counter

from .core.pipeline import RAGPipeline
from .core.embeddings import Embeddings
from .core.retriever import Retriever
from .core.reranker import Reranker
from .core.evaluation import Evaluator
from .optimization.search import GridSearch, RandomSearch, BayesianSearch
from .utils.data_loader import load_validation_set

logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


class RAGMint:
    """
    Main RAG pipeline optimizer and evaluator.
    Runs combinations of retrievers, embeddings, rerankers, and chunking parameters
    to find the best performing RAG configuration.
    """

    def __init__(
        self,
        docs_path: str,
        retrievers: List[str],
        embeddings: List[str],
        rerankers: List[str],
        chunk_sizes: List[int] = [400, 600],
        overlaps: List[int] = [50, 100],
        strategies: List[str] = ["fixed"],
    ):
        self.docs_path = docs_path
        self.retrievers = retrievers
        self.embeddings = embeddings
        self.rerankers = rerankers
        self.chunk_sizes = chunk_sizes
        self.overlaps = overlaps
        self.strategies = strategies

        self.documents: List[str] = self._load_docs()
        self.embeddings_cache: Dict[str, Any] = {}

    # -------------------------
    # Document Loading
    # -------------------------
    def _load_docs(self) -> List[str]:
        if not os.path.exists(self.docs_path):
            logging.warning(f"Corpus path not found: {self.docs_path}")
            return []

        docs = []
        for file in os.listdir(self.docs_path):
            if file.endswith((".txt", ".md", ".rst")):
                with open(os.path.join(self.docs_path, file), "r", encoding="utf-8") as f:
                    docs.append(f.read())

        logging.info(f"📚 Loaded {len(docs)} documents from {self.docs_path}")
        return docs

    # -------------------------
    # Embedding Cache
    # -------------------------
    def _embed_docs(self, model_name: str) -> Any:
        """Compute and cache document embeddings."""
        if model_name in self.embeddings_cache:
            return self.embeddings_cache[model_name]

        model = Embeddings(backend="huggingface", model_name=model_name)
        embeddings = model.encode(self.documents)
        self.embeddings_cache[model_name] = embeddings
        return embeddings

    # -------------------------
    # Build Pipeline
    # -------------------------
    def _build_pipeline(self, config: Dict[str, str]) -> RAGPipeline:
        """Builds a pipeline from one configuration."""
        retriever_backend = config["retriever"]
        model_name = config["embedding_model"]
        reranker_name = config["reranker"]

        # Chunking params (use defaults if missing)
        chunk_size = int(config.get("chunk_size", 500))
        overlap = int(config.get("overlap", 100))
        strategy = config.get("strategy", "fixed")

        # Load embeddings (cached)
        embeddings = self._embed_docs(model_name)
        embedder = Embeddings(backend="huggingface", model_name=model_name)

        # Initialize retriever with backend
        logging.info(f"⚙️ Initializing retriever backend: {retriever_backend}")
        retriever = Retriever(
            embedder=embedder,
            documents=self.documents,
            embeddings=embeddings,
            backend=retriever_backend,
        )

        reranker = Reranker(reranker_name)
        evaluator = Evaluator()

        # ✅ Pass chunking parameters into RAGPipeline
        return RAGPipeline(
            retriever,
            reranker,
            evaluator,
            chunk_size=chunk_size,
            overlap=overlap,
            chunking_strategy=strategy,
        )

    # -------------------------
    # Evaluate Configuration
    # -------------------------
    def _evaluate_config(
        self, config: Dict[str, Any], validation: List[Dict[str, str]], metric: str
    ) -> Dict[str, float]:
        """Evaluates a single configuration."""
        pipeline = self._build_pipeline(config)
        scores = []
        start = perf_counter()

        for sample in validation:
            query = sample.get("question") or sample.get("query") or ""
            result = pipeline.run(query)
            score = result["metrics"].get(metric, 0.0)
            scores.append(score)

        elapsed = perf_counter() - start
        avg_score = sum(scores) / len(scores) if scores else 0.0

        return {
            metric: avg_score,
            "latency": elapsed / max(1, len(validation)),
        }

    # -------------------------
    # Optimize
    # -------------------------
    def optimize(
        self,
        validation_set: str,
        metric: str = "faithfulness",
        search_type: str = "random",
        trials: int = 10,
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """Run optimization search over retrievers, embeddings, rerankers, and chunking."""
        validation = load_validation_set(validation_set or "default")

        # ✅ Add chunking parameters to the search space
        search_space = {
            "retriever": self.retrievers,
            "embedding_model": self.embeddings,
            "reranker": self.rerankers,
            "chunk_size": self.chunk_sizes,
            "overlap": self.overlaps,
            "strategy": self.strategies,
        }

        logging.info(f"🚀 Starting {search_type} optimization with {trials} trials")

        # Select search strategy
        try:
            if search_type == "grid":
                searcher = GridSearch(search_space)
            elif search_type == "bayesian":
                searcher = BayesianSearch(search_space)
            else:
                searcher = RandomSearch(search_space, n_trials=trials)
        except Exception as e:
            logging.warning(f"⚠️ Fallback to RandomSearch due to missing deps: {e}")
            searcher = RandomSearch(search_space, n_trials=trials)

        # Run trials
        results = []
        for config in searcher:
            metrics = self._evaluate_config(config, validation, metric)
            result = {**config, **metrics}
            results.append(result)
            logging.info(f"🔹 Tested config: {config} -> {metrics}")

        best = max(results, key=lambda r: r.get(metric, 0.0)) if results else {}
        logging.info(f"🏆 Best configuration: {best}")

        return best, results
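For orientation, a minimal usage sketch of the RAGMint class above (illustrative only, not part of the package; the corpus path is a placeholder and the optional embedding/search dependencies are assumed to be installed):

from ragmint.tuner import RAGMint

rag = RAGMint(
    docs_path="data/corpus/",          # directory of .txt/.md/.rst files
    retrievers=["faiss"],
    embeddings=["all-MiniLM-L6-v2"],
    rerankers=["mmr"],
)

# Random search over retriever/embedding/reranker/chunking combinations,
# scored against the built-in validation set with the chosen metric
best, results = rag.optimize(validation_set=None, metric="faithfulness", trials=2)
print(best)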
File without changes
ragmint/utils/caching.py
ADDED
@@ -0,0 +1,37 @@
import os
import json
import hashlib
import pickle
from typing import Any


class Cache:
    """
    Simple file-based cache for embeddings or retrievals.
    """

    def __init__(self, cache_dir: str = ".ragmint_cache"):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _hash_key(self, key: str) -> str:
        return hashlib.md5(key.encode()).hexdigest()

    def exists(self, key: str) -> bool:
        return os.path.exists(os.path.join(self.cache_dir, self._hash_key(key)))

    def get(self, key: str) -> Any:
        path = os.path.join(self.cache_dir, self._hash_key(key))
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return pickle.load(f)

    def set(self, key: str, value: Any):
        path = os.path.join(self.cache_dir, self._hash_key(key))
        with open(path, "wb") as f:
            pickle.dump(value, f)

    def clear(self):
        for file in os.listdir(self.cache_dir):
            os.remove(os.path.join(self.cache_dir, file))
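A short illustrative snippet (not part of the package) showing how the Cache class above could memoize expensive values such as embeddings; the key string is arbitrary:

from ragmint.utils.caching import Cache

cache = Cache(cache_dir=".ragmint_cache")
key = "all-MiniLM-L6-v2::corpus-v1"  # any stable string; it is hashed with MD5 internally

if not cache.exists(key):
    cache.set(key, [[0.1, 0.2], [0.3, 0.4]])  # e.g. precomputed embeddings
embeddings = cache.get(key)  # returns None if the key is missing
cache.clear()                # removes every cached file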
ragmint/utils/data_loader.py
ADDED
@@ -0,0 +1,65 @@
import json
import csv
from typing import List, Dict
from pathlib import Path
import os

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None  # optional dependency

DEFAULT_VALIDATION_PATH = Path(__file__).parent.parent / "experiments" / "validation_qa.json"


def load_json(path: str) -> List[Dict]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_csv(path: str) -> List[Dict]:
    with open(path, newline="", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)
        return list(reader)


def save_json(path: str, data: Dict):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def load_validation_set(path: str | None = None) -> List[Dict]:
    """
    Loads a validation dataset (QA pairs) from:
    - Built-in default JSON file
    - User-provided JSON or CSV
    - Hugging Face dataset by name
    """
    # Default behavior
    if path is None or path == "default":
        if not DEFAULT_VALIDATION_PATH.exists():
            raise FileNotFoundError(f"Default validation set not found at {DEFAULT_VALIDATION_PATH}")
        return load_json(DEFAULT_VALIDATION_PATH)

    # Hugging Face dataset
    if not os.path.exists(path) and load_dataset:
        try:
            dataset = load_dataset(path, split="validation")
            data = [
                {"question": q, "answer": a}
                for q, a in zip(dataset["question"], dataset["answers"])
            ]
            return data
        except Exception:
            pass  # fall through to file loading

    # Local file
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Validation file not found: {path}")

    if p.suffix.lower() == ".json":
        return load_json(path)
    elif p.suffix.lower() in [".csv", ".tsv"]:
        return load_csv(path)
    else:
        raise ValueError("Unsupported validation set format. Use JSON, CSV, or a Hugging Face dataset name.")
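Illustrative calls for load_validation_set (not part of the package; the file path and dataset name are placeholders):

from ragmint.utils.data_loader import load_validation_set

qa_default = load_validation_set()                     # built-in experiments/validation_qa.json
qa_custom = load_validation_set("data/my_eval.json")   # local JSON (CSV/TSV also supported)
# qa_hf = load_validation_set("squad")                 # Hugging Face dataset name, if `datasets` is installed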
ragmint/utils/logger.py
ADDED
@@ -0,0 +1,36 @@
import logging
from tqdm import tqdm


class Logger:
    """
    Centralized logger with optional tqdm integration and color formatting.
    """

    def __init__(self, name: str = "ragmint", level: int = logging.INFO):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)

        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "\033[96m[%(asctime)s]\033[0m \033[93m%(levelname)s\033[0m: %(message)s",
                "%H:%M:%S",
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def info(self, msg: str):
        self.logger.info(msg)

    def warning(self, msg: str):
        self.logger.warning(msg)

    def error(self, msg: str):
        self.logger.error(msg)

    def progress(self, iterable, desc="Processing", total=None):
        return tqdm(iterable, desc=desc, total=total)

def get_logger(name: str = "ragmint") -> Logger:
    return Logger(name)
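A brief usage sketch for the Logger helper above (illustrative only):

from ragmint.utils.logger import get_logger

log = get_logger("ragmint")
log.info("Starting evaluation")
for item in log.progress(range(3), desc="Scoring"):  # tqdm-wrapped iteration
    log.info(f"step {item}")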
ragmint/utils/metrics.py
ADDED
@@ -0,0 +1,27 @@
from typing import List
import numpy as np
from difflib import SequenceMatcher


def bleu_score(reference: str, candidate: str) -> float:
    """
    Simple BLEU-like precision approximation.
    """
    ref_tokens = reference.split()
    cand_tokens = candidate.split()
    if not cand_tokens:
        return 0.0

    matches = sum(1 for token in cand_tokens if token in ref_tokens)
    return matches / len(cand_tokens)


def rouge_l(reference: str, candidate: str) -> float:
    """
    Approximation of ROUGE-L using sequence matcher ratio.
    """
    return SequenceMatcher(None, reference, candidate).ratio()


def mean_score(scores: List[float]) -> float:
    return float(np.mean(scores)) if scores else 0.0
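Illustrative use of the metric helpers above; the bleu_score value follows directly from the token-overlap definition in the code:

from ragmint.utils.metrics import bleu_score, rouge_l, mean_score

ref = "retrieval augmented generation"
cand = "retrieval generation"

precision = bleu_score(ref, cand)  # both candidate tokens appear in the reference: 2/2 = 1.0
overlap = rouge_l(ref, cand)       # SequenceMatcher ratio over the raw strings
avg = mean_score([precision, overlap])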
@@ -0,0 +1,19 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

Copyright 2025 André Oliveira

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.