ragmint 0.2.3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- ragmint/app.py +512 -0
- ragmint/autotuner.py +201 -17
- ragmint/core/chunking.py +68 -4
- ragmint/core/embeddings.py +46 -10
- ragmint/core/evaluation.py +33 -14
- ragmint/core/pipeline.py +34 -10
- ragmint/core/retriever.py +152 -20
- ragmint/experiments/validation_qa.json +1 -14
- ragmint/explainer.py +47 -20
- ragmint/integrations/__init__.py +0 -0
- ragmint/integrations/config_adapter.py +96 -0
- ragmint/integrations/langchain_prebuilder.py +99 -0
- ragmint/leaderboard.py +41 -35
- ragmint/qa_generator.py +190 -0
- ragmint/tests/test_autotuner.py +52 -30
- ragmint/tests/test_config_adapter.py +39 -0
- ragmint/tests/test_embeddings.py +46 -0
- ragmint/tests/test_explainer.py +28 -12
- ragmint/tests/test_integration_autotuner_ragmint.py +39 -52
- ragmint/tests/test_langchain_prebuilder.py +82 -0
- ragmint/tests/test_leaderboard.py +78 -25
- ragmint/tests/test_pipeline.py +3 -2
- ragmint/tests/test_qa_generator.py +66 -0
- ragmint/tests/test_retriever.py +3 -2
- ragmint/tests/test_tuner.py +1 -1
- ragmint/tuner.py +109 -22
- ragmint-0.4.6.data/data/README.md +485 -0
- ragmint-0.4.6.dist-info/METADATA +530 -0
- ragmint-0.4.6.dist-info/RECORD +48 -0
- ragmint/tests/test_explainer_integration.py +0 -18
- ragmint-0.2.3.data/data/README.md +0 -284
- ragmint-0.2.3.dist-info/METADATA +0 -312
- ragmint-0.2.3.dist-info/RECORD +0 -40
- {ragmint-0.2.3.data → ragmint-0.4.6.data}/data/LICENSE +0 -0
- {ragmint-0.2.3.dist-info → ragmint-0.4.6.dist-info}/WHEEL +0 -0
- {ragmint-0.2.3.dist-info → ragmint-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {ragmint-0.2.3.dist-info → ragmint-0.4.6.dist-info}/top_level.txt +0 -0
ragmint/app.py
ADDED
@@ -0,0 +1,512 @@
+"""
+RAGMint Dashboard
+-----------------
+Gradio UI for AutoRAG / RAGMint:
+- Upload corpus files
+- Run recommend() (autotuner quick suggestion)
+- Run full optimize() using grid/random/bayesian
+- View leaderboard entries (local JSONL)
+- Request LLM explanation for the best run
+- Simple analytics: score histogram, latency summary, runs over time
+
+"""
+
+import os
+import json
+import time
+from typing import List, Dict, Any
+import pandas as pd
+import matplotlib.pyplot as plt
+import gradio as gr
+from dotenv import set_key, load_dotenv
+from ragmint.autotuner import AutoRAGTuner
+from ragmint.tuner import RAGMint
+from ragmint.leaderboard import Leaderboard
+from ragmint.explainer import explain_results
+from matplotlib.ticker import MultipleLocator
+import yaml
+
+# ----------------------------------------------------------------------
+# CONFIGURATION
+# ----------------------------------------------------------------------
+DATA_DIR = "data/docs"
+LEADERBOARD_PATH = "data/leaderboard.jsonl"
+LOGO_PATH = "src/ragmint/assets/img/ragmint_logo.png"
+ENV_PATH = ".env"
+
+BG_COLOR = "#F7F4ED"       # soft beige background
+PRIMARY_GREEN = "#1D5C39"  # brand green
+
+os.makedirs(DATA_DIR, exist_ok=True)
+os.makedirs(os.path.dirname(LEADERBOARD_PATH) or ".", exist_ok=True)
+leaderboard = Leaderboard(storage_path=LEADERBOARD_PATH)
+
+load_dotenv(ENV_PATH)
+
+# ----------------------------------------------------------------------
+# UTILITY FUNCTIONS
+# ----------------------------------------------------------------------
+def save_uploaded_files(files):
+    saved_files = []
+    for f in files:
+        # f is a gradio.files.NamedString
+        src_path = f.name                      # full path to the temp file
+        filename = os.path.basename(src_path)  # extract just the name
+        dest_path = os.path.join(DATA_DIR, filename)
+        with open(src_path, "rb") as src, open(dest_path, "wb") as dst:
+            dst.write(src.read())
+        saved_files.append(filename)
+    return saved_files
+
+def handle_validation_upload(file):
+    if not file:
+        return "⚠️ Please select a file first."
+    dest_path = os.path.join(DATA_DIR, "validation_qa.json")
+    with open(file.name, "rb") as src, open(dest_path, "wb") as dst:
+        dst.write(src.read())
+    return "✅ Validation file saved as validation_qa.json"
+
+# --- Validation toggle ---
+def toggle_validation_inputs(choice):
+    return (
+        gr.update(visible=(choice == "Upload JSON")),
+        gr.update(visible=(choice == "HuggingFace Dataset"), interactive=True),
+        f"Selected: {choice}"
+    )
+
+# --- API key save function with type ---
+def save_api_key(api_key: str, provider: str):
+    if not api_key:
+        return "⚠️ No API key provided."
+    if provider == "Google":
+        set_key(ENV_PATH, "GOOGLE_API_KEY", api_key)
+    elif provider == "Anthropic":
+        set_key(ENV_PATH, "ANTHROPIC_API_KEY", api_key)
+    return f"✅ {provider} API key saved to .env."
+
+def read_leaderboard_df():
+    if not os.path.exists(LEADERBOARD_PATH) or os.path.getsize(LEADERBOARD_PATH) == 0:
+        return pd.DataFrame()
+    return pd.read_json(LEADERBOARD_PATH, lines=True)
+
+
+def plot_score_scatter(results: List[Dict[str, Any]], best_index: int = None):
+    if not results:
+        fig, ax = plt.subplots()
+        ax.text(0.5, 0.5, "No results yet", ha="center", va="center")
+        return fig
+
+    scores = [r.get("faithfulness", r.get("score", 0)) for r in results]
+    fig, ax = plt.subplots()
+    ax.scatter(range(len(scores)), scores, color="#1D5C39", label="Trials", alpha=0.7)
+
+    if best_index is not None and 0 <= best_index < len(scores):
+        ax.scatter(best_index, scores[best_index], color="gold", s=120, edgecolor="black", label="Best Run")
+
+    ax.set_title("Trial Scores", color="#1D5C39")
+    ax.set_xlabel("Trial #")
+    ax.set_ylabel("Faithfulness")
+    ax.legend()
+
+    # Force integer steps on the X axis
+    ax.xaxis.set_major_locator(MultipleLocator(1))
+    ax.set_xlim(-0.5, len(scores) - 0.5)  # better spacing on the ends
+
+    plt.tight_layout()
+    return fig
+
+def export_best_config(best_json_str: str):
+    """
+    Export best configuration as config.yaml with only selected fields.
+    """
+    try:
+        best = json.loads(best_json_str)
+        allowed_fields = ["retriever", "embedding_model", "reranker", "chunk_size", "overlap", "strategy"]
+        config = {k: best[k] for k in allowed_fields if k in best}
+
+        if not config:
+            return "⚠️ No valid configuration fields found."
+
+        config_path = os.path.join(DATA_DIR, "config.yaml")
+        with open(config_path, "w") as f:
+            yaml.dump(config, f, default_flow_style=False, sort_keys=False)
+
+        return f"✅ Exported configuration to {config_path}"
+
+    except Exception as e:
+        return f"⚠️ Error exporting config: {str(e)}"
+
+
+# ----------------------------------------------------------------------
+# ACTION HANDLERS
+# ----------------------------------------------------------------------
+def handle_upload(files):
+    if not files:
+        return "No files provided."
+    saved = save_uploaded_files(files)
+    return f"✅ Saved {len(saved)} files to {DATA_DIR}."
+
+def do_auto_tune(
+    embedding_model: str,
+    num_chunk_pairs: int,
+    search_type: str,
+    trials: int,
+    validation_choice: str,
+    hf_dataset: str = None
+):
+    tuner = AutoRAGTuner(docs_path=DATA_DIR)
+    rec = tuner.recommend(embedding_model=embedding_model, num_chunk_pairs=num_chunk_pairs)
+    num_chunk_pairs = int(num_chunk_pairs)
+
+    # FIX: use num_chunk_pairs instead of None
+    chunk_candidates = tuner.suggest_chunk_sizes(
+        model_name=rec["embedding_model"],
+        num_pairs=num_chunk_pairs,
+        step=20
+    )
+    chunk_sizes = sorted({c for c, _ in chunk_candidates})
+    overlaps = sorted({o for _, o in chunk_candidates})
+
+    rag = RAGMint(
+        docs_path=DATA_DIR,
+        retrievers=[rec["retriever"]],
+        embeddings=[rec["embedding_model"]],
+        rerankers=["mmr"],
+        chunk_sizes=chunk_sizes,
+        overlaps=overlaps,
+        strategies=[rec["strategy"]],
+    )
+
+    start_time = time.time()
+
+    validation_set = None
+    if validation_choice == "Upload JSON":
+        validation_path = os.path.join(DATA_DIR, "validation_qa.json")
+        if os.path.exists(validation_path):
+            validation_set = validation_path
+    elif validation_choice == "HuggingFace Dataset" and hf_dataset:
+        validation_set = hf_dataset.strip()
+
+    try:
+        best, results = rag.optimize(
+            validation_set=validation_set,
+            metric="faithfulness",
+            search_type=search_type,
+            trials=trials,
+        )
+
+        elapsed = time.time() - start_time
+
+        run_id = f"run_{int(time.time())}"
+        corpus_stats = {
+            "num_docs": len(rag.documents),
+            "avg_len": sum(len(d.split()) for d in rag.documents) / max(1, len(rag.documents)),
+            "corpus_size": sum(len(d) for d in rag.documents),
+        }
+
+        leaderboard.upload(
+            run_id=run_id,
+            best_config=best,
+            best_score=best.get("faithfulness", best.get("score", 0.0)),
+            all_results=results,
+            documents=os.listdir(DATA_DIR),
+            model=best.get("embedding_model", rec["embedding_model"]),
+            corpus_stats=corpus_stats,
+        )
+
+        # --- Plot scatter ---
+        best_index = next((i for i, r in enumerate(results)
+                           if r.get("faithfulness") == best.get("faithfulness")), None)
+        fig = plot_score_scatter(results, best_index)
+
+        # --- Inline explanation (string/markdown) ---
+        explanation = explain_results(best, results, corpus_stats=corpus_stats)
+
+        # Return exactly three outputs matching the Gradio components:
+        # 1) best configuration JSON (string)
+        # 2) Matplotlib figure
+        # 3) explanation markdown/string
+        return (
+            json.dumps(best, indent=2),
+            fig,
+            explanation
+        )
+
+    except Exception as e:
+        # Return a placeholder figure if there's an error and match output types
+        fig, ax = plt.subplots()
+        ax.text(0.5, 0.5, str(e), ha="center", va="center", color="red")
+        ax.axis("off")
+        # 1) error string for best_json  2) placeholder fig  3) explanation/error markdown
+        return "Error during tuning", fig, f"⚠️ {str(e)}"
+
+def show_leaderboard_table():
+    df = read_leaderboard_df()
+    if df.empty:
+        return "No runs yet.", ""
+    table = df[["run_id", "timestamp", "best_score", "model", "best_config"]].sort_values(
+        "best_score", ascending=False
+    )
+    return table, df.to_json(orient="records", indent=2)
+
+
+def do_explain(run_id: str, llm_model: str = "gemini-2.5-flash-lite"):
+    entry = leaderboard.all_results()
+    matched = [r for r in entry if r["run_id"] == run_id]
+    if not matched:
+        return f"Run {run_id} not found."
+    record = matched[0]
+    best = record["best_config"]
+    all_results = record["all_results"]
+    corpus_stats = record.get("corpus_stats", {})
+    return explain_results(best, all_results, corpus_stats=corpus_stats, model=llm_model)
+
+
+def analytics_overview():
+    df = read_leaderboard_df()
+    if df.empty:
+        return "No data yet."
+    top_score = df["best_score"].max()
+    runs = len(df)
+    latencies = []
+    for row in df["all_results"]:
+        for r in row:
+            if isinstance(r, dict) and "latency" in r:
+                latencies.append(r["latency"])
+    avg_latency = sum(latencies) / len(latencies) if latencies else None
+    summary = {
+        "num_runs": runs,
+        "top_score": float(top_score),
+        "avg_trial_latency": float(avg_latency) if avg_latency else None,
+    }
+    return json.dumps(summary, indent=2)
+
+
+# ----------------------------------------------------------------------
+# CUSTOM STYLING
+# ----------------------------------------------------------------------
+
+custom_css = f"""
+
+#logo {{
+    display: flex;
+    align-items: center;
+    justify-content: center;  /* center horizontally */
+    padding: 0;               /* remove padding */
+    margin: 0;                /* remove margin */
+    box-shadow: none;         /* remove shadow */
+    border: none;             /* remove any border */
+}}
+
+#logo img {{
+    height: 80px;
+    width: auto;
+}}
+
+#logo button {{
+    background-color: rgba(29, 92, 57, 0.1) !important;
+    color: white !important;
+    border-radius: 12px !important;
+    font-weight: 600 !important;
+}}
+
+.custom-summary {{
+    background-color: #f4f4f4;                  /* clean light grey background */
+    border: 1px solid rgba(0, 0, 0, 0.08);      /* subtle border */
+    border-radius: 16px;
+    padding: 16px;
+    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);  /* soft depth */
+    transition: all 0.3s ease-in-out;
+}}
+
+"""
+
+
+# ----------------------------------------------------------------------
+# BUILD GRADIO APP
+# ----------------------------------------------------------------------
+
+with gr.Blocks(css=custom_css, theme=gr.themes.Ocean()) as demo:
+    with gr.Row(elem_id="logo"):
+        gr.Image(value=LOGO_PATH, show_label=False, interactive=False, elem_id="logo_img")
+
+    gr.Markdown("# Ragmint - RAG Automated Tuning")
+    gr.Markdown("Auto-tune your RAG pipeline, benchmark performance, visualize results, and get AI-driven insights.")
+
+    # --- Corpus & API Key Upload ---
+    with gr.Tab("📂 Configuration"):
+        with gr.Row():
+            # --- LEFT SIDE: Main Steps ---
+            with gr.Column(scale=3):
+                gr.Markdown("### 1️⃣ Upload Corpus Files")
+                uploader = gr.File(label="Upload corpus files", file_count="multiple")
+                upload_btn = gr.Button("Upload Files", variant="primary")
+
+                gr.Markdown("### 2️⃣ Add LLM Key")
+                api_provider = gr.Radio(
+                    label="Provider",
+                    choices=["Google", "Anthropic"],
+                    value="Google"
+                )
+
+                api_key_input = gr.Textbox(
+                    label="API Key",
+                    placeholder="Paste your API key here",
+                    type="password"
+                )
+                save_api_btn = gr.Button("Save API Key", variant="primary")
+
+                gr.Markdown("### 3️⃣ Validation Dataset (Optional)")
+                validation_source = gr.Radio(
+                    label="Validation Source",
+                    choices=["Default File", "Upload JSON", "HuggingFace Dataset"],
+                    value="Default File"
+                )
+
+                with gr.Row(visible=False) as validation_upload_row:
+                    validation_file = gr.File(
+                        label="Upload validation_qa.json",
+                        file_count="single",
+                        interactive=True
+                    )
+                    upload_validation_btn = gr.Button("Upload Validation File", variant="primary")
+
+                validation_hf_dataset = gr.Textbox(
+                    label="HuggingFace Dataset Name",
+                    placeholder="e.g. squad, hotpot_qa, or your own dataset",
+                    interactive=True,
+                    visible=False
+                )
+
+            # --- RIGHT SIDE: Status Summary ---
+            with gr.Column(scale=1, elem_classes=["custom-summary"]):
+                gr.Markdown("### ⚙️ Configuration")
+                gr.Markdown("Monitor your current setup below:")
+                upload_status = gr.Textbox(label="File Upload Status", interactive=False)
+                save_status = gr.Textbox(label="API Key Status", interactive=False)
+                validation_status = gr.Textbox(label="Validation Selection", interactive=False)
+
+        # --- Event bindings ---
+        upload_btn.click(fn=handle_upload, inputs=[uploader], outputs=[upload_status])
+        save_api_btn.click(
+            fn=save_api_key,
+            inputs=[api_key_input, api_provider],
+            outputs=[save_status]
+        )
+        upload_validation_btn.click(fn=handle_validation_upload, inputs=[validation_file], outputs=[validation_status])
+        validation_source.change(
+            fn=toggle_validation_inputs,
+            inputs=[validation_source],
+            outputs=[validation_upload_row, validation_hf_dataset, validation_status]
+        )
+
+    # --- Unified AutoTune ---
+    with gr.Tab("🤖 AutoTune"):
+        # 🌿 Custom CSS for dashboard-like cards
+        gr.HTML("""
+        <style>
+        .card {
+            background-color: #f4f4f4;                  /* clean light grey background */
+            border: 1px solid rgba(0, 0, 0, 0.08);      /* subtle border */
+            border-radius: 16px;
+            padding: 16px;
+            box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);  /* soft depth */
+            transition: all 0.3s ease-in-out;
+        }
+        .card:hover {
+            box-shadow: 0 4px 12px rgba(0,0,0,0.08);
+        }
+        .section-title {
+            font-size: 1.1rem;
+            font-weight: 600;
+            color: #1D5C39;
+            margin-bottom: 8px;
+        }
+        </style>
+        """)
+
+        # ⚙️ Settings
+        with gr.Column(elem_classes=["card"]):
+            gr.Markdown("⚙️ <span class='section-title'>AutoTuner Settings</span>")
+            with gr.Accordion("Advanced Settings", open=False):
+                embed_model = gr.Textbox(
+                    value="sentence-transformers/all-MiniLM-L6-v2",
+                    label="Embedding Model",
+                    info="Model used to generate text embeddings for retrieval."
+                )
+                num_pairs = gr.Number(
+                    value=5,
+                    label="Chunk Candidates",
+                    info="Number of chunk size-overlap pairs to test."
+                )
+                search_type = gr.Dropdown(
+                    choices=["random", "grid", "bayesian"],
+                    value="grid",
+                    label="Search Type",
+                    info="Method used for optimization search over hyperparameters."
+                )
+                trials = gr.Slider(
+                    minimum=1, maximum=50, step=1, value=5,
+                    label="Trials",
+                    info="Number of trials to run during optimization."
+                )
+
+            autotune_btn = gr.Button("🚀 Run AutoTune", elem_id="autotune_btn", variant="primary")
+
+        # 🏆 Best Configuration
+        with gr.Row(elem_classes=["card"]):
+            with gr.Column():
+                gr.Markdown("🏆 <span class='section-title'>Best Configuration</span>")
+                best_json = gr.Textbox(label="", interactive=False, lines=10)
+
+        # 📊 Trial Scores
+        with gr.Row(elem_classes=["card"]):
+            with gr.Column():
+                gr.Markdown("📊 <span class='section-title'>Trial Scores</span>")
+                score_plot = gr.Plot(label="")
+
+        # 💡 Explanation
+        with gr.Row(elem_classes=["card"]):
+            with gr.Column():
+                gr.Markdown("💡 <span class='section-title'>Explanation</span>")
+                explanation_md = gr.Markdown(label="")
+
+        export_btn = gr.Button("Export Best Configuration", variant="primary", visible=False)
+        export_status = gr.Textbox(label="Export Status", interactive=False, visible=False)
+
+        # 🔗 Connect button logic
+        autotune_btn.click(
+            fn=do_auto_tune,
+            inputs=[embed_model, num_pairs, search_type, trials, validation_source, validation_hf_dataset],
+            outputs=[best_json, score_plot, explanation_md],
+            show_progress=True
+        ).then(
+            fn=lambda: (gr.update(visible=True), gr.update(visible=True)),
+            inputs=None,
+            outputs=[export_btn, export_status]
+        )
+
+        export_btn.click(
+            fn=export_best_config,
+            inputs=[best_json],
+            outputs=[export_status]
+        )
+
+    with gr.Tab("🏆 Leaderboard"):
+        show_btn = gr.Button("Refresh", variant="primary")
+        lb_table = gr.Dataframe(label="Leaderboard", interactive=False)
+        lb_json = gr.Textbox(label="Raw JSON", interactive=False)
+        show_btn.click(fn=show_leaderboard_table, outputs=[lb_table, lb_json])
+
+    gr.Markdown(
+        "<center><p style='font-size:0.9em;'>"
+        "Built with ❤️ using RAGMint · © 2025 andyolivers.com</p></center>"
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
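
For reference, a minimal sketch of how the new dashboard module might be launched after installing the 0.4.6 wheel. It assumes `ragmint.app` is importable and exposes the `demo` Blocks object exactly as shown in the diff above, and that the module's runtime dependencies (gradio, pandas, matplotlib, python-dotenv, pyyaml) are installed alongside ragmint; the launch arguments mirror the module's own `__main__` guard.

```python
# Minimal sketch: launch the Gradio dashboard added in ragmint 0.4.6.
# Assumes ragmint.app exposes `demo` as shown above and that its
# dependencies (gradio, pandas, matplotlib, python-dotenv, pyyaml) are installed.
from ragmint.app import demo

if __name__ == "__main__":
    # Same arguments as the module's own entry point.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
```

Running `python -m ragmint.app` should behave the same way, since the module only calls `demo.launch(...)` under its `if __name__ == "__main__":` guard.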