isage-rag-benchmark 0.1.0.1 (isage_rag_benchmark-0.1.0.1-cp311-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. isage_rag_benchmark-0.1.0.1.dist-info/METADATA +63 -0
  2. isage_rag_benchmark-0.1.0.1.dist-info/RECORD +48 -0
  3. isage_rag_benchmark-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_rag_benchmark-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  5. isage_rag_benchmark-0.1.0.1.dist-info/top_level.txt +1 -0
  6. sage/__init__.py +0 -0
  7. sage/benchmark_rag/__init__.py +16 -0
  8. sage/benchmark_rag/_version.py +4 -0
  9. sage/benchmark_rag/config/config_bm25s.yaml +51 -0
  10. sage/benchmark_rag/config/config_dense_milvus.yaml +61 -0
  11. sage/benchmark_rag/config/config_hf.yaml +43 -0
  12. sage/benchmark_rag/config/config_mixed.yaml +53 -0
  13. sage/benchmark_rag/config/config_monitoring_demo.yaml +59 -0
  14. sage/benchmark_rag/config/config_multiplex.yaml +79 -0
  15. sage/benchmark_rag/config/config_qa_chroma.yaml +51 -0
  16. sage/benchmark_rag/config/config_ray.yaml +57 -0
  17. sage/benchmark_rag/config/config_refiner.yaml +75 -0
  18. sage/benchmark_rag/config/config_rerank.yaml +56 -0
  19. sage/benchmark_rag/config/config_selfrag.yaml +24 -0
  20. sage/benchmark_rag/config/config_source.yaml +30 -0
  21. sage/benchmark_rag/config/config_source_local.yaml +21 -0
  22. sage/benchmark_rag/config/config_sparse_milvus.yaml +49 -0
  23. sage/benchmark_rag/evaluation/__init__.py +10 -0
  24. sage/benchmark_rag/evaluation/benchmark_runner.py +337 -0
  25. sage/benchmark_rag/evaluation/config/benchmark_config.yaml +35 -0
  26. sage/benchmark_rag/evaluation/evaluate_results.py +389 -0
  27. sage/benchmark_rag/implementations/__init__.py +31 -0
  28. sage/benchmark_rag/implementations/pipelines/__init__.py +24 -0
  29. sage/benchmark_rag/implementations/pipelines/qa_bm25_retrieval.py +55 -0
  30. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval.py +56 -0
  31. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_chroma.py +71 -0
  32. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_milvus.py +78 -0
  33. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_mixed.py +58 -0
  34. sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_ray.py +174 -0
  35. sage/benchmark_rag/implementations/pipelines/qa_hf_model.py +57 -0
  36. sage/benchmark_rag/implementations/pipelines/qa_monitoring_demo.py +139 -0
  37. sage/benchmark_rag/implementations/pipelines/qa_multimodal_fusion.py +318 -0
  38. sage/benchmark_rag/implementations/pipelines/qa_multiplex.py +91 -0
  39. sage/benchmark_rag/implementations/pipelines/qa_refiner.py +91 -0
  40. sage/benchmark_rag/implementations/pipelines/qa_rerank.py +76 -0
  41. sage/benchmark_rag/implementations/pipelines/qa_sparse_retrieval_milvus.py +76 -0
  42. sage/benchmark_rag/implementations/pipelines/selfrag.py +226 -0
  43. sage/benchmark_rag/implementations/tools/__init__.py +17 -0
  44. sage/benchmark_rag/implementations/tools/build_chroma_index.py +261 -0
  45. sage/benchmark_rag/implementations/tools/build_milvus_dense_index.py +86 -0
  46. sage/benchmark_rag/implementations/tools/build_milvus_index.py +59 -0
  47. sage/benchmark_rag/implementations/tools/build_milvus_sparse_index.py +85 -0
  48. sage/benchmark_rag/implementations/tools/loaders/document_loaders.py +42 -0
sage/benchmark_rag/config/config_refiner.yaml
@@ -0,0 +1,75 @@
+ pipeline:
+   name: "sage-api-operator-operator_test"
+   description: "Test pipeline for Sage API Operator"
+   version: "1.0.0"
+
+ source:
+   # Data source type: 'local' (local JSONL) or 'hf' (HuggingFace Dataset)
+   type: "hf"
+   # Local JSONL file path (takes effect when type=local)
+   data_path: "packages/sage-benchmark/src/sage/data/qa/sample/evaluate.json"
+
+   # HuggingFace Dataset parameters (take effect when type=hf)
+   hf_dataset_name: "RUC-NLPIR/FlashRAG_datasets"
+   hf_dataset_config: "asqa"
+   hf_split: "dev"
+
+ retriever:
+   # Retriever type: 'chroma' or 'wiki18_faiss'
+   type: "wiki18_faiss"  # use the Wiki18 FAISS retriever
+
+   # Common settings
+   dimension: 1024  # embedding dimension of the BGE-M3 model
+   top_k: 8
+
+   # Wiki18 FAISS-specific settings (take effect when type=wiki18_faiss)
+   faiss:
+     index_path: ""
+     documents_path: ""
+
+   # Embedding model settings
+   embedding:
+     method: "hf"
+     model: "BAAI/bge-m3"  # use the BGE-M3 model for Wiki18 FAISS retrieval
+     gpu_device: 0  # explicitly pin BGE-M3 to GPU 0
+
+ generator:
+   local:
+     method: "hf"
+     model_name: "meta-llama/Llama-2-13b-chat-hf"
+     seed: 42
+
+   vllm:
+     api_key: "token-abc123"
+     method: "openai"
+     model_name: "meta-llama/Llama-3.1-8B-Instruct"
+     base_url: "http://sage2:8000/v1"
+     seed: 42
+
+   remote:
+     api_key: ""
+     method: "openai"
+     model_name: "qwen-turbo-0919"
+     base_url: "http://127.0.0.1:8889/v1"
+     seed: 42
+
+ promptor:
+   platform: "local"
+
+ sink:
+   platform: "local"
+
+ refiner:
+   base_model_path: "Qwen/Qwen2.5-3B-Instruct"
+   query_analysis_module_lora_path: "jinjiajie/Query-Analysis-Qwen2.5-3B-Instruct"
+   doc_structuring_module_lora_path: "jinjiajie/Doc-Structuring-Qwen2.5-3B-Instruct"
+   global_selection_module_lora_path: "jinjiajie/Global-Selection-Qwen2.5-3B-Instruct"
+   score_model_name: "bge-reranker-v2-m3"
+   score_model_path: "BAAI/bge-reranker-v2-m3"
+   max_model_len: 25000
+   budget: 2048
+   gpu_device: 0  # GPU device used by vLLM
+   gpu_memory_utilization: 0.5
+
+ evaluate:
+   platform: "local"
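The generator block above defines three interchangeable backends (local, vllm, remote). A minimal sketch of how a pipeline script might read this file and pick one backend, assuming only PyYAML; the key layout mirrors the config, and the backend choice here is hypothetical:

    import yaml

    # Load the refiner pipeline configuration (path is illustrative).
    with open("config_refiner.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Select one of the three generator profiles: "local", "vllm", or "remote".
    backend = "vllm"  # hypothetical choice
    gen_cfg = config["generator"][backend]
    print(gen_cfg["model_name"], gen_cfg.get("base_url"))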
sage/benchmark_rag/config/config_rerank.yaml
@@ -0,0 +1,56 @@
+ # ChromaDB-specific QA configuration file
+ # Simplified configuration for qa_openai.py
+
+
+ source:
+   data_path: "packages/sage-benchmark/src/sage/data/qa/queries.jsonl"  # relative to the SAGE root directory
+   platform: "local"
+
+ retriever:
+   # ChromaDB-specific settings
+   dimension: 384  # set to match HuggingFace all-MiniLM-L6-v2
+   top_k: 2
+
+   chroma:
+     persistence_path: "./data/chroma_qa_database"
+     collection_name: "qa_knowledge_base"
+     use_embedding_query: true
+
+     # The knowledge base is preloaded, so automatic loading is commented out
+     # knowledge_file: "../../data/qa_knowledge_base.txt"
+
+     # ChromaDB metadata settings (simplified format)
+     metadata:
+       hnsw:space: "cosine"  # distance metric
+
+   # Embedding model settings
+   embedding:
+     method: "hf"
+     model: "sentence-transformers/all-MiniLM-L6-v2"
+
+ promptor:
+   template: |
+     Answer the user's question based on the following retrieved documents:
+
+     Retrieved documents:
+     {retrieved_documents}
+
+     User question: {query}
+
+     Please provide an accurate and helpful answer:
+
+ generator:
+   vllm:
+     api_key: ""
+     method: "openai"
+     model_name: "meta-llama/Llama-2-7b-chat-hf"
+     base_url: "http://sage3:8000/v1"
+     seed: 42
+
+ sink:
+   enable_log: true
+
+ reranker:
+   platform: "local"
+   model_name: "BAAI/bge-reranker-v2-m3"
+   topk: 1
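The promptor template above uses two placeholders, {retrieved_documents} and {query}. A sketch of how a promptor could fill them with Python's str.format; this is illustrative, not the package's implementation:

    # English rendering of the template from the config above.
    template = (
        "Answer the user's question based on the following retrieved documents:\n\n"
        "Retrieved documents:\n{retrieved_documents}\n\n"
        "User question: {query}\n\n"
        "Please provide an accurate and helpful answer:"
    )

    docs = ["First retrieved passage ...", "Second retrieved passage ..."]
    prompt = template.format(retrieved_documents="\n".join(docs), query="What is RAG?")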
sage/benchmark_rag/config/config_selfrag.yaml
@@ -0,0 +1,24 @@
+ # Self-RAG Configuration
+ # Based on the Self-RAG paper: https://arxiv.org/abs/2310.11511
+
+ # Data path (Self-RAG dataset format with pre-retrieved documents)
+ data_path: "/path/to/selfrag_dataset.jsonl"
+
+ # Retriever configuration
+ retriever:
+   top_k: 5  # number of retrieved documents to use
+
+ # Promptor configuration
+ promptor:
+   model_name: "mistralai/Mistral-7B-Instruct-v0.1"
+   use_context: true  # whether to include retrieved context in the prompt
+
+ # Generator configuration
+ generator:
+   model_name: "mistralai/Mistral-7B-Instruct-v0.1"
+   gpu_memory_utilization: 0.8
+   temperature: 0
+   max_tokens: 100
+
+ # Output path
+ output_path: "./output/selfrag_results.json"
sage/benchmark_rag/config/config_source.yaml
@@ -0,0 +1,30 @@
+ pipeline:
+   name: "qa_source_interactive"
+   description: "Interactive terminal QA pipeline with LLM"
+   version: "1.0.0"
+
+ generator:
+   local:
+     method: "hf"
+     model_name: "meta-llama/Llama-2-13b-chat-hf"
+     seed: 42
+
+   vllm:
+     api_key: ""
+     method: "openai"
+     model_name: "qwen-turbo-2025-02-11"
+     base_url: "http://127.0.0.1:8889/v1"
+     seed: 42
+
+   remote:
+     api_key: ""
+     method: "openai"
+     model_name: "qwen-turbo-0919"
+     base_url: "http://127.0.0.1:8889/v1"
+     seed: 42
+
+ promptor:
+   platform: "local"
+
+ sink:
+   platform: "local"
sage/benchmark_rag/config/config_source_local.yaml
@@ -0,0 +1,21 @@
+ pipeline:
+   name: "qa_source_interactive_local"
+   description: "Interactive terminal QA pipeline with local LLM"
+   version: "1.0.0"
+
+ generator:
+   local:
+     method: "hf"
+     model_name: "microsoft/DialoGPT-medium"
+     seed: 42
+     max_length: 512
+     temperature: 0.7
+
+ promptor:
+   platform: "local"
+   template: |
+     User: {user_query}
+     Assistant:
+
+ sink:
+   platform: "local"
sage/benchmark_rag/config/config_sparse_milvus.yaml
@@ -0,0 +1,49 @@
+ # Milvus sparse vector retrieval configuration (for the Milvus sparse retriever)
+
+ source:
+   data_path: "./packages/sage-benchmark/src/sage/data/qa/queries.jsonl"
+   platform: "local"
+
+ retriever:
+   preload_knowledge_file: "./packages/sage-benchmark/src/sage/data/qa/qa_knowledge_base.txt"
+
+   # Common parameters
+   top_k: 3  # number of documents to return
+
+   # Milvus backend (sparse retrieval)
+   milvus_sparse:
+     # Local Milvus Lite (recommended for quick trials)
+     persistence_path: "./data/milvus_qa_sparse.db"
+
+     # Remote Milvus (to go remote, comment out persistence_path above and use the settings below)
+     # host: "127.0.0.1"
+     # port: 19530
+     # force_http: true
+
+     collection_name: "qa_sparse_collection"
+     search_type: "sparse"  # sparse retrieval
+
+     # # Knowledge file (optional): if provided, it is read paragraph by paragraph and ingested automatically
+     # knowledge_file: "./packages/sage-benchmark/src/sage/data/qa/qa_knowledge_base.txt"
+
+ promptor:
+   template: |
+     Answer the user's question based on the following retrieved documents:
+
+     Retrieved documents:
+     {retrieved_documents}
+
+     User question: {query}
+
+     Please provide an accurate and helpful answer:
+
+ generator:
+   vllm:
+     api_key: ""
+     method: "openai"
+     model_name: "meta-llama/Llama-2-7b-chat-hf"
+     base_url: "http://sage3:8000/v1"
+     seed: 42
+
+ sink:
+   enable_log: true
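The milvus_sparse block supports either an embedded Milvus Lite database file or a remote server. A sketch of both connection modes using pymilvus's MilvusClient; collection creation and sparse search are omitted, and this is not the package's own retriever code:

    from pymilvus import MilvusClient

    # Local Milvus Lite: a .db file path opens an embedded database.
    client = MilvusClient("./data/milvus_qa_sparse.db")

    # Remote Milvus: pass a server URI instead (mirrors the commented-out host/port).
    # client = MilvusClient(uri="http://127.0.0.1:19530")

    print(client.list_collections())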
sage/benchmark_rag/evaluation/__init__.py
@@ -0,0 +1,10 @@
+ """Evaluation framework for RAG benchmarking.
+
+ This module provides tools for running RAG experiments and evaluating results:
+ - Pipeline experiment framework
+ - Multiple evaluation metrics (Accuracy, F1, Exact Match)
+ - Batch processing support
+ - Result analysis tools
+ """
+
+ __all__: list[str] = []
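The docstring lists Accuracy, F1, and Exact Match. For reference, a sketch of the standard SQuAD-style definitions these metric names usually denote; evaluate_results.py may normalize answers differently, so treat this as illustrative rather than the package's implementation:

    from collections import Counter

    def exact_match(prediction: str, truth: str) -> float:
        # 1.0 if the normalized strings are identical, else 0.0.
        return float(prediction.strip().lower() == truth.strip().lower())

    def f1_score(prediction: str, truth: str) -> float:
        # Token-level F1: harmonic mean of precision and recall over shared tokens.
        pred_tokens = prediction.lower().split()
        truth_tokens = truth.lower().split()
        common = Counter(pred_tokens) & Counter(truth_tokens)
        overlap = sum(common.values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_tokens)
        recall = overlap / len(truth_tokens)
        return 2 * precision * recall / (precision + recall)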
sage/benchmark_rag/evaluation/benchmark_runner.py
@@ -0,0 +1,337 @@
+ """
+ Universal Benchmark Runner for RAG Pipelines
+
+ This module provides a generic framework for benchmarking any RAG pipeline implementation.
+ It handles:
+ - Batch processing of large datasets
+ - Running any pipeline from implementations/pipelines/
+ - Collecting and saving results
+ - Performance metrics tracking
+ """
+
+ import importlib
+ import json
+ import time
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any
+
+ import yaml  # type: ignore[import-untyped]
+ from dotenv import load_dotenv
+ from sage.common.config.output_paths import get_output_file
+ from sage.common.core import BatchFunction, MapFunction
+ from sage.common.utils.logging.custom_logger import CustomLogger
+ from sage.kernel.api.local_environment import LocalEnvironment
+
+
+ def load_config(path: str) -> dict:
+     """Load YAML configuration file."""
+     with open(path, encoding="utf-8") as f:
+         return yaml.safe_load(f)
+
+
+ class BatchDataLoader(BatchFunction):
+     """
+     Batch data loader for benchmark datasets.
+
+     Supports loading large datasets in batches for efficient processing.
+     Compatible with Self-RAG dataset format and other QA datasets.
+     """
+
+     def __init__(self, config: dict):
+         self.config = config
+         data_path = config.get("data_path")
+         if not data_path:
+             raise ValueError("data_path is required in config")
+         self.data_path: str = data_path
+         self.max_samples = config.get("max_samples", None)
+
+         # Load data
+         data = self._load_data()
+         if self.max_samples:
+             data = data[: self.max_samples]
+
+         self.batch_size = config.get("batch_size", len(data))
+         self.current_batch = 0
+         self.total_batches = (len(data) + self.batch_size - 1) // self.batch_size
+         self._data = data
+
+     def _load_data(self) -> list[dict[str, Any]]:
+         """Load dataset from JSONL file."""
+         data = []
+         with open(self.data_path, encoding="utf-8") as f:
+             for line in f:
+                 if line.strip():
+                     data.append(json.loads(line))
+
+         # Limit samples if max_samples is set
+         if self.max_samples and self.max_samples > 0:
+             print(
+                 f"📊 Limiting dataset to {self.max_samples} samples (total available: {len(data)})"
+             )
+             data = data[: self.max_samples]
+
+         return data
+
+     def execute(self) -> dict[str, Any] | None:
+         """Return next batch of data."""
+         if self.current_batch >= self.total_batches:
+             return None
+
+         start_idx = self.current_batch * self.batch_size
+         end_idx = min(start_idx + self.batch_size, len(self._data))
+         batch = self._data[start_idx:end_idx]
+
+         result = {
+             "batch_data": batch,
+             "batch_id": self.current_batch,
+             "total_batches": self.total_batches,
+         }
+
+         self.current_batch += 1
+         return result
+
+
+ class PipelineRunner(MapFunction):
+     """
+     Generic pipeline runner that can execute any RAG pipeline.
+
+     Dynamically loads and runs pipeline implementations from the
+     implementations/pipelines/ directory.
+
+     Expected pipeline interface:
+     - process_item(item: Dict, config: Dict) -> Dict
+     """
+
+     def __init__(self, config: dict):
+         self.config = config
+         self.pipeline_name = config.get("pipeline_name")
+         self.pipeline_config = config.get("pipeline_config", {})
+
+         # Dynamically import pipeline module
+         self.pipeline_module = self._load_pipeline()
+
+     def _load_pipeline(self):
+         """Dynamically load pipeline implementation."""
+         module_path = f"sage.benchmark_rag.implementations.pipelines.{self.pipeline_name}"
+
+         try:
+             module = importlib.import_module(module_path)
+
+             # Verify the module has a process_item function
+             if not hasattr(module, "process_item"):
+                 raise ImportError(
+                     f"Pipeline module {module_path} must have a 'process_item' function "
+                     f"with signature: process_item(item: Dict, config: Dict) -> Dict"
+                 )
+
+             return module
+         except ImportError as e:
+             raise ImportError(
+                 f"Failed to load pipeline '{self.pipeline_name}': {e}\n"
+                 f"Make sure the pipeline exists in implementations/pipelines/ "
+                 f"and has a 'process_item' function"
+             ) from e
+
+     def execute(self, batch_data: dict[str, Any]) -> dict[str, Any]:
+         """Execute pipeline on a batch of data."""
+         batch = batch_data["batch_data"]
+
+         # Process each item in the batch
+         results = []
+         for i, item in enumerate(batch):
+             try:
+                 # Run the pipeline on this item
+                 result = self.pipeline_module.process_item(item, self.pipeline_config)
+                 results.append(result)
+
+                 # Progress indicator
+                 if (i + 1) % 10 == 0:
+                     print(f" Processed {i + 1}/{len(batch)} items in batch...")
+
+             except Exception as e:
+                 # Log error but continue processing
+                 print(f"⚠️ Error processing item {item.get('id', 'unknown')}: {e}")
+                 results.append(
+                     {
+                         "id": item.get("id", "unknown"),
+                         "question": item.get("question", ""),
+                         "error": str(e),
+                         "ground_truth": item.get("answers", []),
+                     }
+                 )
+
+         return {
+             "results": results,
+             "batch_id": batch_data["batch_id"],
+             "total_batches": batch_data["total_batches"],
+         }
+
+
+ class ResultsCollector(MapFunction):
+     """
+     Collects and saves benchmark results.
+
+     Handles incremental or final saving of results with metadata.
+     """
+
+     def __init__(self, config: dict):
+         self.config = config
+         default_output = get_output_file("benchmark_results.json", "benchmarks")
+         self.output_path = config.get("output_path", str(default_output))
+         self.save_mode = config.get("save_mode", "incremental")
+         self.all_results: list[dict[str, Any]] = []
+         self.start_time = time.time()
+
+     def execute(self, batch_result: dict[str, Any]) -> dict[str, Any]:
+         """Collect results from a batch."""
+         results = batch_result["results"]
+         batch_id = batch_result["batch_id"]
+         total_batches = batch_result["total_batches"]
+
+         self.all_results.extend(results)
+
+         print(
+             f"✅ Processed batch {batch_id + 1}/{total_batches}, "
+             f"Total results: {len(self.all_results)}"
+         )
+
+         # Save based on mode
+         if self.save_mode == "incremental":
+             self._save_results(batch_id + 1, total_batches)
+         elif self.save_mode == "final" and batch_id + 1 == total_batches:
+             self._save_results(batch_id + 1, total_batches)
+
+         return batch_result
+
+     def _save_results(self, current_batch: int, total_batches: int):
+         """Save results to file with metadata."""
+         elapsed_time = time.time() - self.start_time
+
+         output = {
+             "metadata": {
+                 "pipeline_name": self.config.get("pipeline_name", "unknown"),
+                 "timestamp": datetime.now().isoformat(),
+                 "total_samples": len(self.all_results),
+                 "completed_batches": f"{current_batch}/{total_batches}",
+                 "elapsed_time_seconds": round(elapsed_time, 2),
+                 "config": self.config,
+             },
+             "results": self.all_results,
+         }
+
+         # Ensure output directory exists
+         output_path = Path(self.output_path)
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         with open(output_path, "w", encoding="utf-8") as f:
+             json.dump(output, f, indent=2, ensure_ascii=False)
+
+         print(f"💾 Results saved to: {self.output_path}")
+
+
+ def run_benchmark(config: dict) -> None:
+     """
+     Run benchmark with the specified configuration.
+
+     Args:
+         config: Configuration dictionary with:
+             - data: Data loading configuration
+             - pipeline: Pipeline configuration
+             - output: Output configuration
+     """
+     print("=" * 60)
+     print("🚀 Starting RAG Benchmark")
+     print("=" * 60)
+
+     pipeline_name = config.get("pipeline", {}).get("pipeline_name", "unknown")
+     data_path = config.get("data", {}).get("data_path", "unknown")
+
+     print(f"📊 Pipeline: {pipeline_name}")
+     print(f"📁 Dataset: {data_path}")
+     print(f"💾 Output: {config.get('output', {}).get('output_path', 'default')}")
+     print("=" * 60)
+
+     env = LocalEnvironment("benchmark_pipeline")
+
+     # Build benchmark pipeline
+     (
+         env.from_source(BatchDataLoader, config["data"])
+         .map(PipelineRunner, config["pipeline"])
+         .sink(ResultsCollector, {**config["output"], **config["pipeline"]})
+     )
+
+     env.submit()
+     env.close()
+
+     print("=" * 60)
+     print("✅ Benchmark completed!")
+     print("=" * 60)
+
+
+ def main():
+     """Main entry point for benchmark runner."""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Run RAG pipeline benchmarks")
+     parser.add_argument(
+         "--config",
+         type=str,
+         default=None,
+         help="Path to benchmark configuration YAML file",
+     )
+     parser.add_argument(
+         "--pipeline",
+         type=str,
+         default=None,
+         help="Pipeline name (e.g., 'selfrag', 'qa_dense_retrieval_milvus')",
+     )
+     parser.add_argument(
+         "--data",
+         type=str,
+         default=None,
+         help="Path to dataset JSONL file",
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default=None,
+         help="Path to output results file",
+     )
+
+     args = parser.parse_args()
+
+     # Load configuration
+     if args.config:
+         config = load_config(args.config)
+     else:
+         # Use default config path
+         current_dir = Path(__file__).parent
+         default_config = current_dir / "config" / "benchmark_config.yaml"
+
+         if default_config.exists():
+             config = load_config(str(default_config))
+         else:
+             raise ValueError(
+                 "No configuration file specified and default config not found. "
+                 "Use --config to specify a configuration file."
+             )
+
+     # Override with command-line arguments
+     if args.pipeline:
+         config["pipeline"]["pipeline_name"] = args.pipeline
+     if args.data:
+         config["data"]["data_path"] = args.data
+     if args.output:
+         config["output"]["output_path"] = args.output
+
+     # Setup
+     CustomLogger.disable_global_console_debug()
+     load_dotenv(override=False)
+
+     # Run benchmark
+     run_benchmark(config)
+
+
+ if __name__ == "__main__":
+     main()
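PipelineRunner only requires that a pipeline module expose process_item(item, config) -> dict. A sketch of the smallest module satisfying that contract; the echo logic is a placeholder and the module name is hypothetical, while the field names follow the error-handling branch in PipelineRunner.execute:

    # e.g. saved as implementations/pipelines/echo_pipeline.py (hypothetical name)

    def process_item(item: dict, config: dict) -> dict:
        """Toy pipeline: echoes the question instead of generating an answer."""
        return {
            "id": item.get("id", "unknown"),
            "question": item.get("question", ""),
            "generated_answer": f"(echo) {item.get('question', '')}",
            "ground_truth": item.get("answers", []),
        }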
sage/benchmark_rag/evaluation/config/benchmark_config.yaml
@@ -0,0 +1,35 @@
+ # Universal Benchmark Configuration
+ # This is a template that can be used for any RAG pipeline benchmark
+
+ # Data configuration
+ data:
+   data_path: "data/selfrag/selfrag_dev_sample_100.jsonl"  # Path to dataset
+   batch_size: 100  # Number of samples per batch
+   max_samples: 100  # Limit to first N samples (null = use all data)
+   # Useful for quick testing or limiting dataset size
+
+ # Pipeline configuration
+ pipeline:
+   pipeline_name: "selfrag"  # Name of a pipeline in implementations/pipelines/
+   pipeline_config:  # Pipeline-specific configuration
+     # This section is passed to the pipeline implementation
+     # Structure depends on the specific pipeline
+     top_k: 5
+     model_name: "mistralai/Mistral-7B-Instruct-v0.1"
+     gpu_memory_utilization: 0.8
+     temperature: 0.0
+     max_tokens: 100
+
+ # Output configuration
+ output:
+   output_path: "artifacts/benchmarks/benchmark_results.json"
+   save_mode: "incremental"  # "incremental" or "final"
+
+ # Optional: Evaluation metrics configuration
+ # evaluation:
+ #   metrics:
+ #     - "exact_match"
+ #     - "f1_score"
+ #     - "bleu"
+ #   evaluator_config:
+ #     # Metric-specific configuration
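A sketch of driving the runner programmatically with this template instead of the CLI; load_config and run_benchmark are defined in benchmark_runner.py above, and the config path assumes the in-repo layout from the file list:

    from sage.benchmark_rag.evaluation.benchmark_runner import load_config, run_benchmark

    config = load_config("sage/benchmark_rag/evaluation/config/benchmark_config.yaml")
    config["data"]["max_samples"] = 10  # hypothetical override for a quick smoke test
    run_benchmark(config)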