visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_dataset_auto
|
|
12
|
+
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
|
|
13
|
+
from visual_rag import VisualEmbedder
|
|
14
|
+
from visual_rag.retrieval import MultiVectorRetriever
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _maybe_load_dotenv() -> None:
|
|
20
|
+
try:
|
|
21
|
+
from dotenv import load_dotenv
|
|
22
|
+
except ImportError:
|
|
23
|
+
return
|
|
24
|
+
if Path(".env").exists():
|
|
25
|
+
load_dotenv(".env")
|
|
26
|
+
|
|
27
|
+
def _torch_dtype_to_str(dtype) -> str:
|
|
28
|
+
if dtype is None:
|
|
29
|
+
return "auto"
|
|
30
|
+
s = str(dtype)
|
|
31
|
+
return s.replace("torch.", "")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _parse_torch_dtype(dtype_str: str):
|
|
35
|
+
if dtype_str == "auto":
|
|
36
|
+
return None
|
|
37
|
+
import torch
|
|
38
|
+
|
|
39
|
+
mapping = {
|
|
40
|
+
"float32": torch.float32,
|
|
41
|
+
"float16": torch.float16,
|
|
42
|
+
"bfloat16": torch.bfloat16,
|
|
43
|
+
}
|
|
44
|
+
return mapping[dtype_str]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _infer_collection_vector_dtype(*, client, collection_name: str) -> Optional[str]:
|
|
48
|
+
try:
|
|
49
|
+
info = client.get_collection(collection_name)
|
|
50
|
+
except Exception:
|
|
51
|
+
return None
|
|
52
|
+
vectors = getattr(getattr(getattr(info, "config", None), "params", None), "vectors", None)
|
|
53
|
+
if not vectors:
|
|
54
|
+
return None
|
|
55
|
+
|
|
56
|
+
initial = None
|
|
57
|
+
if isinstance(vectors, dict):
|
|
58
|
+
initial = vectors.get("initial")
|
|
59
|
+
else:
|
|
60
|
+
try:
|
|
61
|
+
initial = vectors.get("initial")
|
|
62
|
+
except Exception:
|
|
63
|
+
initial = None
|
|
64
|
+
|
|
65
|
+
dt = getattr(initial, "datatype", None) if initial is not None else None
|
|
66
|
+
if dt is None:
|
|
67
|
+
return None
|
|
68
|
+
|
|
69
|
+
s = str(dt).lower()
|
|
70
|
+
if "float16" in s:
|
|
71
|
+
return "float16"
|
|
72
|
+
if "float32" in s:
|
|
73
|
+
return "float32"
|
|
74
|
+
return None
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _evaluate(
    *,
    retriever: MultiVectorRetriever,
    queries: List,
    qrels: Dict[str, Dict[str, int]],
    top_k: int,
    mode: str,
    stage1_mode: str,
    prefetch_k: int,
    max_queries: int = 0,
    precomputed_query_embeddings: Optional[List[np.ndarray]] = None,
) -> Dict[str, float]:
    """Run retrieval for every query and aggregate ranking metrics.

    When ``precomputed_query_embeddings`` is given, per-query embedding is
    skipped and the stage-specific retrievers are driven directly; only
    "single_full" and "two_stage" support that path.

    Returns a dict with nDCG@10, MRR@10, recall@5/10 and latency stats
    (mean and p95, in milliseconds) averaged over the evaluated queries.
    """
    per_query_ndcg10: List[float] = []
    per_query_mrr10: List[float] = []
    per_query_recall10: List[float] = []
    per_query_recall5: List[float] = []
    query_times_ms: List[float] = []

    if max_queries and max_queries > 0:
        queries = queries[:max_queries]

    # tqdm is optional; fall back to a plain iterator when unavailable.
    try:
        from tqdm import tqdm
    except ImportError:
        progress = queries
    else:
        progress = tqdm(queries, desc=f"🔎 Evaluating ({mode})", unit="q")

    for idx, query in enumerate(progress):
        t0 = time.time()
        if precomputed_query_embeddings is None:
            results = retriever.search(
                query=query.text,
                top_k=top_k,
                mode=mode,
                prefetch_k=prefetch_k,
                stage1_mode=stage1_mode,
            )
        else:
            embedding = precomputed_query_embeddings[idx]
            if mode == "single_full":
                results = retriever._single_stage.search(
                    query_embedding=embedding,
                    top_k=top_k,
                    strategy="multi_vector",
                )
            elif mode == "two_stage":
                results = retriever._two_stage.search(
                    query_embedding=embedding,
                    top_k=top_k,
                    prefetch_k=prefetch_k,
                    stage1_mode=stage1_mode,
                )
            else:
                raise ValueError(f"Unsupported mode for precomputed embeddings: {mode}")
        query_times_ms.append((time.time() - t0) * 1000.0)

        ranking = [str(hit["id"]) for hit in results]
        relevant = qrels.get(query.query_id, {})

        per_query_ndcg10.append(ndcg_at_k(ranking, relevant, k=10))
        per_query_mrr10.append(mrr_at_k(ranking, relevant, k=10))
        per_query_recall5.append(recall_at_k(ranking, relevant, k=5))
        per_query_recall10.append(recall_at_k(ranking, relevant, k=10))

    return {
        "ndcg@10": float(np.mean(per_query_ndcg10)),
        "mrr@10": float(np.mean(per_query_mrr10)),
        "recall@5": float(np.mean(per_query_recall5)),
        "recall@10": float(np.mean(per_query_recall10)),
        "avg_latency_ms": float(np.mean(query_times_ms)),
        "p95_latency_ms": float(np.percentile(query_times_ms, 95)),
    }
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _write_sweep_result(out_path: Path, payload: Dict) -> None:
    """Write one sweep result as pretty JSON and echo path + metrics to stdout."""
    with open(out_path, "w") as f:
        json.dump(payload, f, indent=2)
    print(out_path)
    print(json.dumps(payload["metrics"], indent=2))


def main() -> None:
    """CLI entry point: sweep retrieval configurations against a Qdrant
    collection and dump per-configuration metrics as JSON files.

    Requires ``QDRANT_URL`` (and optionally ``QDRANT_API_KEY``) in the
    environment or a local ``.env`` file.

    Raises:
        ValueError: If ``QDRANT_URL`` is not set.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="vidore/tatdqa_test")
    parser.add_argument("--collection", type=str, default="vidore_tatdqa_test")
    parser.add_argument("--model", type=str, default="vidore/colSmol-500M")
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; inferred from collection vector dtype when possible).",
    )
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--mode", type=str, default="two_stage", choices=["single_full", "two_stage"])
    parser.add_argument(
        "--stage1-mode",
        type=str,
        default="tokens_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
    )
    parser.add_argument(
        "--prefetch-ks",
        type=str,
        default="20,50,100,200,400",
        help="Comma-separated list of prefetch_k values (only used for mode=two_stage).",
    )
    parser.add_argument("--prefer-grpc", action="store_true")
    parser.add_argument("--out-dir", type=str, default="results/sweeps")
    parser.add_argument(
        "--max-queries",
        type=int,
        default=0,
        help="Limit number of queries for a quick smoke test (0 = all).",
    )
    parser.add_argument(
        "--sample-queries",
        type=int,
        default=0,
        help="Sample N queries for sweeps (0 = disable).",
    )
    parser.add_argument(
        "--sample-strategy",
        type=str,
        default="head",
        choices=["head", "random"],
        help="How to sample queries when --sample-queries is set.",
    )
    parser.add_argument(
        "--sample-seed",
        type=int,
        default=42,
        help="Random seed for --sample-strategy random.",
    )
    parser.add_argument(
        "--query-batch-size",
        type=int,
        default=32,
        help="Batch size for embedding queries. Set 0 to disable pre-embedding and embed per query.",
    )
    args = parser.parse_args()

    _maybe_load_dotenv()

    if not os.getenv("QDRANT_URL"):
        raise ValueError("QDRANT_URL not set. Add it to .env or export it in your shell.")

    logger.info(f"Dataset: {args.dataset}")
    logger.info(f"Collection: {args.collection}")
    logger.info(f"Model: {args.model}")
    logger.info(f"Mode: {args.mode}")
    if args.mode == "two_stage":
        logger.info(f"Stage1 mode: {args.stage1_mode}")
        logger.info(f"Prefetch ks: {args.prefetch_ks}")
    if args.max_queries:
        logger.info(f"Max queries (smoke test): {args.max_queries}")

    # Imported lazily so the CLI help works without qdrant-client installed.
    from qdrant_client import QdrantClient

    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        prefer_grpc=args.prefer_grpc,
        check_compatibility=False,
        timeout=120,
    )

    # Resolve "auto" from the collection's stored vector datatype so the
    # query-side embeddings match what was indexed.
    requested_torch_dtype = args.torch_dtype
    if requested_torch_dtype == "auto":
        vdt = _infer_collection_vector_dtype(client=client, collection_name=args.collection)
        if vdt in ("float16", "float32"):
            requested_torch_dtype = vdt

    torch_dtype = _parse_torch_dtype(requested_torch_dtype)

    corpus, queries, qrels, protocol = load_vidore_dataset_auto(args.dataset)
    del corpus  # only queries/qrels are needed here; free the images early
    logger.info(f"Loaded protocol={protocol}, queries={len(queries)}")

    if args.max_queries and args.max_queries > 0:
        queries = queries[: args.max_queries]
    if args.sample_queries and args.sample_queries > 0:
        if args.sample_strategy == "head":
            queries = queries[: args.sample_queries]
        else:
            rng = np.random.default_rng(int(args.sample_seed))
            n = min(int(args.sample_queries), len(queries))
            idxs = rng.choice(len(queries), size=n, replace=False).tolist()
            queries = [queries[i] for i in idxs]
    logger.info(f"Eval queries: {len(queries)}")

    embedder = VisualEmbedder(model_name=args.model, torch_dtype=torch_dtype)
    logger.info(f"Effective torch dtype: {_torch_dtype_to_str(embedder.torch_dtype)}")

    retriever = MultiVectorRetriever(
        collection_name=args.collection,
        embedder=embedder,
        prefer_grpc=args.prefer_grpc,
        qdrant_client=client,
    )

    # Optionally embed all query texts in batches up front; this amortizes
    # model overhead and keeps the per-query timing pure search latency.
    precomputed_query_embeddings: Optional[List[np.ndarray]] = None
    if args.query_batch_size and args.query_batch_size > 0:
        texts = [q.text for q in queries]
        logger.info(f"Pre-embedding {len(texts)} queries (batch={args.query_batch_size})...")
        q_tensors = embedder.embed_queries(texts, batch_size=args.query_batch_size, show_progress=True)
        precomputed_query_embeddings = [t.detach().cpu().float().numpy() for t in q_tensors]
        try:
            import torch

            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            pass  # best-effort cache release; never fail the sweep for this

    out_dir = Path(args.out_dir) / args.collection
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fields shared by every result file; mode-specific keys are merged in below.
    base_payload = {
        "dataset": args.dataset,
        "protocol": protocol,
        "collection": args.collection,
        "model": args.model,
        "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
        "top_k": args.top_k,
        "max_queries": args.max_queries,
        "sample_queries": args.sample_queries,
        "sample_strategy": args.sample_strategy if args.sample_queries else None,
        "sample_seed": args.sample_seed if args.sample_queries and args.sample_strategy == "random" else None,
    }

    if args.mode == "single_full":
        metrics = _evaluate(
            retriever=retriever,
            queries=queries,
            qrels=qrels,
            top_k=args.top_k,
            mode="single_full",
            stage1_mode=args.stage1_mode,
            prefetch_k=0,
            max_queries=args.max_queries,
            precomputed_query_embeddings=precomputed_query_embeddings,
        )
        _write_sweep_result(
            out_dir / f"{protocol}__single_full__top{args.top_k}.json",
            {**base_payload, "mode": "single_full", "metrics": metrics},
        )
        return

    prefetch_ks = [int(x.strip()) for x in args.prefetch_ks.split(",") if x.strip()]
    for k in prefetch_ks:
        metrics = _evaluate(
            retriever=retriever,
            queries=queries,
            qrels=qrels,
            top_k=args.top_k,
            mode="two_stage",
            stage1_mode=args.stage1_mode,
            prefetch_k=k,
            max_queries=args.max_queries,
            precomputed_query_embeddings=precomputed_query_embeddings,
        )
        _write_sweep_result(
            out_dir / f"{protocol}__two_stage__{args.stage1_mode}__prefetch{k}__top{args.top_k}.json",
            {
                **base_payload,
                "mode": "two_stage",
                "stage1_mode": args.stage1_mode,
                "prefetch_k": k,
                "metrics": metrics,
            },
        )
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
# Allow running the sweep directly (e.g. `python -m benchmarks.vidore_tatdqa_test.sweep_eval`).
if __name__ == "__main__":
    main()
|
|
371
|
+
|
|
372
|
+
|
demo/__init__.py
ADDED
demo/app.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Main entry point for the Visual RAG Toolkit demo application."""

import sys
from pathlib import Path

# Make the repository root importable so the `demo.*` imports below resolve
# when the script is run from inside the demo/ directory.
ROOT_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT_DIR))

# Load environment variables before importing modules that may read them
# at import time.
from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

import streamlit as st

# Configure the page immediately after importing streamlit, before the UI
# modules below are imported, so this is the first Streamlit command executed.
st.set_page_config(
    page_title="Visual RAG Toolkit",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)

from demo.ui.header import render_header
from demo.ui.sidebar import render_sidebar
from demo.ui.upload import render_upload_tab
from demo.ui.playground import render_playground_tab
from demo.ui.benchmark import render_benchmark_tab
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main():
    """Render the demo UI: header, sidebar, and the three content tabs."""
    render_header()
    render_sidebar()

    tabs = st.tabs(["📤 Upload", "🎮 Playground", "📊 Benchmarking"])
    renderers = (render_upload_tab, render_playground_tab, render_benchmark_tab)
    for tab, render in zip(tabs, renderers):
        with tab:
            render()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Entry point when executed directly (e.g. `streamlit run demo/app.py`).
if __name__ == "__main__":
    main()
|