isage-rag-benchmark 0.1.0.1__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_rag_benchmark-0.1.0.1.dist-info/METADATA +63 -0
- isage_rag_benchmark-0.1.0.1.dist-info/RECORD +48 -0
- isage_rag_benchmark-0.1.0.1.dist-info/WHEEL +5 -0
- isage_rag_benchmark-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_rag_benchmark-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark_rag/__init__.py +16 -0
- sage/benchmark_rag/_version.py +4 -0
- sage/benchmark_rag/config/config_bm25s.yaml +51 -0
- sage/benchmark_rag/config/config_dense_milvus.yaml +61 -0
- sage/benchmark_rag/config/config_hf.yaml +43 -0
- sage/benchmark_rag/config/config_mixed.yaml +53 -0
- sage/benchmark_rag/config/config_monitoring_demo.yaml +59 -0
- sage/benchmark_rag/config/config_multiplex.yaml +79 -0
- sage/benchmark_rag/config/config_qa_chroma.yaml +51 -0
- sage/benchmark_rag/config/config_ray.yaml +57 -0
- sage/benchmark_rag/config/config_refiner.yaml +75 -0
- sage/benchmark_rag/config/config_rerank.yaml +56 -0
- sage/benchmark_rag/config/config_selfrag.yaml +24 -0
- sage/benchmark_rag/config/config_source.yaml +30 -0
- sage/benchmark_rag/config/config_source_local.yaml +21 -0
- sage/benchmark_rag/config/config_sparse_milvus.yaml +49 -0
- sage/benchmark_rag/evaluation/__init__.py +10 -0
- sage/benchmark_rag/evaluation/benchmark_runner.py +337 -0
- sage/benchmark_rag/evaluation/config/benchmark_config.yaml +35 -0
- sage/benchmark_rag/evaluation/evaluate_results.py +389 -0
- sage/benchmark_rag/implementations/__init__.py +31 -0
- sage/benchmark_rag/implementations/pipelines/__init__.py +24 -0
- sage/benchmark_rag/implementations/pipelines/qa_bm25_retrieval.py +55 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval.py +56 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_chroma.py +71 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_milvus.py +78 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_mixed.py +58 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_ray.py +174 -0
- sage/benchmark_rag/implementations/pipelines/qa_hf_model.py +57 -0
- sage/benchmark_rag/implementations/pipelines/qa_monitoring_demo.py +139 -0
- sage/benchmark_rag/implementations/pipelines/qa_multimodal_fusion.py +318 -0
- sage/benchmark_rag/implementations/pipelines/qa_multiplex.py +91 -0
- sage/benchmark_rag/implementations/pipelines/qa_refiner.py +91 -0
- sage/benchmark_rag/implementations/pipelines/qa_rerank.py +76 -0
- sage/benchmark_rag/implementations/pipelines/qa_sparse_retrieval_milvus.py +76 -0
- sage/benchmark_rag/implementations/pipelines/selfrag.py +226 -0
- sage/benchmark_rag/implementations/tools/__init__.py +17 -0
- sage/benchmark_rag/implementations/tools/build_chroma_index.py +261 -0
- sage/benchmark_rag/implementations/tools/build_milvus_dense_index.py +86 -0
- sage/benchmark_rag/implementations/tools/build_milvus_index.py +59 -0
- sage/benchmark_rag/implementations/tools/build_milvus_sparse_index.py +85 -0
- sage/benchmark_rag/implementations/tools/loaders/document_loaders.py +42 -0
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
pipeline:
|
|
2
|
+
name: "sage-api-operator-operator_test"
|
|
3
|
+
description: "Test pipeline for Sage API Operator"
|
|
4
|
+
version: "1.0.0"
|
|
5
|
+
|
|
6
|
+
source:
|
|
7
|
+
# 数据源类型:'local'(本地 JSONL) 或 'hf'(HuggingFace Dataset)
|
|
8
|
+
type: "hf"
|
|
9
|
+
# 本地 JSONL 文件路径(type=local 时生效)
|
|
10
|
+
data_path: "packages/sage-benchmark/src/sage/data/qa/sample/evaluate.json"
|
|
11
|
+
|
|
12
|
+
# HuggingFace Dataset 参数(type=hf 时生效)
|
|
13
|
+
hf_dataset_name: "RUC-NLPIR/FlashRAG_datasets"
|
|
14
|
+
hf_dataset_config: "asqa"
|
|
15
|
+
hf_split: "dev"
|
|
16
|
+
|
|
17
|
+
retriever:
|
|
18
|
+
# 检索器类型选择: 'chroma' 或 'wiki18_faiss'
|
|
19
|
+
type: "wiki18_faiss" # 使用Wiki18 FAISS检索器
|
|
20
|
+
|
|
21
|
+
# 通用配置
|
|
22
|
+
dimension: 1024 # BGE-M3模型的维度
|
|
23
|
+
top_k: 8
|
|
24
|
+
|
|
25
|
+
# Wiki18 FAISS 专用配置 (type=wiki18_faiss 时生效)
|
|
26
|
+
faiss:
|
|
27
|
+
index_path: ""
|
|
28
|
+
documents_path: ""
|
|
29
|
+
|
|
30
|
+
# 嵌入模型配置
|
|
31
|
+
embedding:
|
|
32
|
+
method: "hf"
|
|
33
|
+
model: "BAAI/bge-m3" # 使用BGE-M3模型进行Wiki18 FAISS检索
|
|
34
|
+
gpu_device: 0 # 明确指定BGE-M3使用GPU 0
|
|
35
|
+
|
|
36
|
+
generator:
|
|
37
|
+
local:
|
|
38
|
+
method: "hf"
|
|
39
|
+
model_name: "meta-llama/Llama-2-13b-chat-hf"
|
|
40
|
+
seed: 42
|
|
41
|
+
|
|
42
|
+
vllm:
|
|
43
|
+
api_key: "token-abc123"
|
|
44
|
+
method: "openai"
|
|
45
|
+
model_name: "meta-llama/Llama-3.1-8B-Instruct"
|
|
46
|
+
base_url: "http://sage2:8000/v1"
|
|
47
|
+
seed: 42
|
|
48
|
+
|
|
49
|
+
remote:
|
|
50
|
+
api_key: ""
|
|
51
|
+
method: "openai"
|
|
52
|
+
model_name: "qwen-turbo-0919"
|
|
53
|
+
base_url: "http://127.0.0.1:8889/v1"
|
|
54
|
+
seed: 42
|
|
55
|
+
|
|
56
|
+
promptor:
|
|
57
|
+
platform: "local"
|
|
58
|
+
|
|
59
|
+
sink:
|
|
60
|
+
platform: "local"
|
|
61
|
+
|
|
62
|
+
refiner:
|
|
63
|
+
base_model_path: "Qwen/Qwen2.5-3B-Instruct"
|
|
64
|
+
query_analysis_module_lora_path: "jinjiajie/Query-Analysis-Qwen2.5-3B-Instruct"
|
|
65
|
+
doc_structuring_module_lora_path: "jinjiajie/Doc-Structuring-Qwen2.5-3B-Instruct"
|
|
66
|
+
global_selection_module_lora_path: "jinjiajie/Global-Selection-Qwen2.5-3B-Instruct"
|
|
67
|
+
score_model_name: "bge-reranker-v2-m3"
|
|
68
|
+
score_model_path: "BAAI/bge-reranker-v2-m3"
|
|
69
|
+
max_model_len: 25000
|
|
70
|
+
budget: 2048
|
|
71
|
+
gpu_device: 0 # vLLM使用GPU 1
|
|
72
|
+
gpu_memory_utilization: 0.5
|
|
73
|
+
|
|
74
|
+
evaluate:
|
|
75
|
+
platform: "local"
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# ChromaDB 专用 QA 配置文件
|
|
2
|
+
# 适用于 qa_openai.py 的简化配置
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
source:
|
|
6
|
+
data_path: "packages/sage-benchmark/src/sage/data/qa/queries.jsonl" # 相对于 SAGE 根目录
|
|
7
|
+
platform: "local"
|
|
8
|
+
|
|
9
|
+
retriever:
|
|
10
|
+
# ChromaDB 专用配置
|
|
11
|
+
dimension: 384 # 修改为与 HuggingFace all-MiniLM-L6-v2 匹配
|
|
12
|
+
top_k: 2
|
|
13
|
+
|
|
14
|
+
chroma:
|
|
15
|
+
persistence_path: "./data/chroma_qa_database"
|
|
16
|
+
collection_name: "qa_knowledge_base"
|
|
17
|
+
use_embedding_query: true
|
|
18
|
+
|
|
19
|
+
# 知识库已预加载,注释掉自动加载
|
|
20
|
+
# knowledge_file: "../../data/qa_knowledge_base.txt"
|
|
21
|
+
|
|
22
|
+
# ChromaDB 元数据配置 (简化格式)
|
|
23
|
+
metadata:
|
|
24
|
+
hnsw:space: "cosine" # 距离度量
|
|
25
|
+
|
|
26
|
+
# 嵌入模型配置
|
|
27
|
+
embedding:
|
|
28
|
+
method: "hf"
|
|
29
|
+
model: "sentence-transformers/all-MiniLM-L6-v2"
|
|
30
|
+
|
|
31
|
+
promptor:
|
|
32
|
+
template: |
|
|
33
|
+
基于以下检索到的相关文档,回答用户问题:
|
|
34
|
+
|
|
35
|
+
相关文档:
|
|
36
|
+
{retrieved_documents}
|
|
37
|
+
|
|
38
|
+
用户问题:{query}
|
|
39
|
+
|
|
40
|
+
请提供准确、有用的回答:
|
|
41
|
+
|
|
42
|
+
generator:
|
|
43
|
+
vllm:
|
|
44
|
+
api_key: ""
|
|
45
|
+
method: "openai"
|
|
46
|
+
model_name: "meta-llama/Llama-2-7b-chat-hf"
|
|
47
|
+
base_url: "http://sage3:8000/v1"
|
|
48
|
+
seed: 42
|
|
49
|
+
|
|
50
|
+
sink:
|
|
51
|
+
enable_log: true
|
|
52
|
+
|
|
53
|
+
reranker:
|
|
54
|
+
platform: "local"
|
|
55
|
+
model_name: "BAAI/bge-reranker-v2-m3"
|
|
56
|
+
topk: 1
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Self-RAG Configuration
|
|
2
|
+
# Based on Self-RAG paper: https://arxiv.org/abs/2310.11511
|
|
3
|
+
|
|
4
|
+
# Data path (Self-RAG dataset format with pre-retrieved documents)
|
|
5
|
+
data_path: "/path/to/selfrag_dataset.jsonl"
|
|
6
|
+
|
|
7
|
+
# Retriever configuration
|
|
8
|
+
retriever:
|
|
9
|
+
top_k: 5 # Number of retrieved documents to use
|
|
10
|
+
|
|
11
|
+
# Promptor configuration
|
|
12
|
+
promptor:
|
|
13
|
+
model_name: "mistralai/Mistral-7B-Instruct-v0.1"
|
|
14
|
+
use_context: true # Whether to include retrieved context in prompt
|
|
15
|
+
|
|
16
|
+
# Generator configuration
|
|
17
|
+
generator:
|
|
18
|
+
model_name: "mistralai/Mistral-7B-Instruct-v0.1"
|
|
19
|
+
gpu_memory_utilization: 0.8
|
|
20
|
+
temperature: 0
|
|
21
|
+
max_tokens: 100
|
|
22
|
+
|
|
23
|
+
# Output path
|
|
24
|
+
output_path: "./output/selfrag_results.json"
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
pipeline:
|
|
2
|
+
name: "qa_source_interactive"
|
|
3
|
+
description: "Interactive terminal QA pipeline with LLM"
|
|
4
|
+
version: "1.0.0"
|
|
5
|
+
|
|
6
|
+
generator:
|
|
7
|
+
local:
|
|
8
|
+
method: "hf"
|
|
9
|
+
model_name: "meta-llama/Llama-2-13b-chat-hf"
|
|
10
|
+
seed: 42
|
|
11
|
+
|
|
12
|
+
vllm:
|
|
13
|
+
api_key: ""
|
|
14
|
+
method: "openai"
|
|
15
|
+
model_name: "qwen-turbo-2025-02-11"
|
|
16
|
+
base_url: "http://127.0.0.1:8889/v1"
|
|
17
|
+
seed: 42
|
|
18
|
+
|
|
19
|
+
remote:
|
|
20
|
+
api_key: ""
|
|
21
|
+
method: "openai"
|
|
22
|
+
model_name: "qwen-turbo-0919"
|
|
23
|
+
base_url: "http://127.0.0.1:8889/v1"
|
|
24
|
+
seed: 42
|
|
25
|
+
|
|
26
|
+
promptor:
|
|
27
|
+
platform: "local"
|
|
28
|
+
|
|
29
|
+
sink:
|
|
30
|
+
platform: "local"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
pipeline:
|
|
2
|
+
name: "qa_source_interactive_local"
|
|
3
|
+
description: "Interactive terminal QA pipeline with local LLM"
|
|
4
|
+
version: "1.0.0"
|
|
5
|
+
|
|
6
|
+
generator:
|
|
7
|
+
local:
|
|
8
|
+
method: "hf"
|
|
9
|
+
model_name: "microsoft/DialoGPT-medium"
|
|
10
|
+
seed: 42
|
|
11
|
+
max_length: 512
|
|
12
|
+
temperature: 0.7
|
|
13
|
+
|
|
14
|
+
promptor:
|
|
15
|
+
platform: "local"
|
|
16
|
+
template: |
|
|
17
|
+
User: {user_query}
|
|
18
|
+
Assistant:
|
|
19
|
+
|
|
20
|
+
sink:
|
|
21
|
+
platform: "local"
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Milvus 稀疏向量检索配置(适配 Milvus 稀疏检索器)
|
|
2
|
+
|
|
3
|
+
source:
|
|
4
|
+
data_path: "./packages/sage-benchmark/src/sage/data/qa/queries.jsonl"
|
|
5
|
+
platform: "local"
|
|
6
|
+
|
|
7
|
+
retriever:
|
|
8
|
+
preload_knowledge_file: "./packages/sage-benchmark/src/sage/data/qa/qa_knowledge_base.txt"
|
|
9
|
+
|
|
10
|
+
# 通用参数
|
|
11
|
+
top_k: 3 # 返回文档数量
|
|
12
|
+
|
|
13
|
+
# Milvus 后端(稀疏检索)
|
|
14
|
+
milvus_sparse:
|
|
15
|
+
# 本地 Milvus Lite(推荐用于快速试用)
|
|
16
|
+
persistence_path: "./data/milvus_qa_sparse.db"
|
|
17
|
+
|
|
18
|
+
# 远程 Milvus(如需远程,请注释上面的 persistence_path,改为如下配置)
|
|
19
|
+
# host: "127.0.0.1"
|
|
20
|
+
# port: 19530
|
|
21
|
+
# force_http: true
|
|
22
|
+
|
|
23
|
+
collection_name: "qa_sparse_collection"
|
|
24
|
+
search_type: "sparse" # 稀疏检索
|
|
25
|
+
|
|
26
|
+
# # 知识文件(可选):提供后将自动按段落读取并入库
|
|
27
|
+
# knowledge_file: "./packages/sage-benchmark/src/sage/data/qa/qa_knowledge_base.txt"
|
|
28
|
+
|
|
29
|
+
promptor:
|
|
30
|
+
template: |
|
|
31
|
+
基于以下检索到的相关文档,回答用户问题:
|
|
32
|
+
|
|
33
|
+
相关文档:
|
|
34
|
+
{retrieved_documents}
|
|
35
|
+
|
|
36
|
+
用户问题:{query}
|
|
37
|
+
|
|
38
|
+
请提供准确、有用的回答:
|
|
39
|
+
|
|
40
|
+
generator:
|
|
41
|
+
vllm:
|
|
42
|
+
api_key: ""
|
|
43
|
+
method: "openai"
|
|
44
|
+
model_name: "meta-llama/Llama-2-7b-chat-hf"
|
|
45
|
+
base_url: "http://sage3:8000/v1"
|
|
46
|
+
seed: 42
|
|
47
|
+
|
|
48
|
+
sink:
|
|
49
|
+
enable_log: true
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Evaluation framework for RAG benchmarking.

This module provides tools for running RAG experiments and evaluating results:
- Pipeline experiment framework
- Multiple evaluation metrics (Accuracy, F1, Exact Match)
- Batch processing support
- Result analysis tools
"""

# No public names are re-exported yet; extend this list as the API stabilizes.
__all__: list[str] = []
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Universal Benchmark Runner for RAG Pipelines
|
|
3
|
+
|
|
4
|
+
This module provides a generic framework for benchmarking any RAG pipeline implementation.
|
|
5
|
+
It handles:
|
|
6
|
+
- Batch processing of large datasets
|
|
7
|
+
- Running any pipeline from implementations/pipelines/
|
|
8
|
+
- Collecting and saving results
|
|
9
|
+
- Performance metrics tracking
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import importlib
|
|
13
|
+
import json
|
|
14
|
+
import time
|
|
15
|
+
from datetime import datetime
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
import yaml # type: ignore[import-untyped]
|
|
20
|
+
from dotenv import load_dotenv
|
|
21
|
+
from sage.common.config.output_paths import get_output_file
|
|
22
|
+
from sage.common.core import BatchFunction, MapFunction
|
|
23
|
+
from sage.common.utils.logging.custom_logger import CustomLogger
|
|
24
|
+
from sage.kernel.api.local_environment import LocalEnvironment
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def load_config(path: str) -> dict:
    """Parse the YAML configuration file at *path* into a dictionary."""
    with open(path, encoding="utf-8") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BatchDataLoader(BatchFunction):
    """Batch data loader for benchmark datasets.

    Loads a JSONL dataset once and serves it in fixed-size batches for
    efficient processing of large datasets. Compatible with the Self-RAG
    dataset format and other QA datasets.
    """

    def __init__(self, config: dict):
        self.config = config
        data_path = config.get("data_path")
        if not data_path:
            raise ValueError("data_path is required in config")
        self.data_path: str = data_path
        self.max_samples = config.get("max_samples", None)

        # _load_data already applies the max_samples limit; the previous
        # second slice here was redundant and has been removed.
        data = self._load_data()

        # Guard against an empty dataset: without the "or 1" a default
        # batch_size of len(data) == 0 raises ZeroDivisionError below.
        self.batch_size = config.get("batch_size", len(data)) or 1
        self.current_batch = 0
        self.total_batches = (len(data) + self.batch_size - 1) // self.batch_size
        self._data = data

    def _load_data(self) -> list[dict[str, Any]]:
        """Load the dataset from a JSONL file.

        Returns:
            List of parsed JSON records, truncated to ``max_samples``
            when that limit is configured.
        """
        data = []
        with open(self.data_path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line))

        # Limit samples if max_samples is set
        if self.max_samples and self.max_samples > 0:
            print(
                f"📊 Limiting dataset to {self.max_samples} samples (total available: {len(data)})"
            )
            data = data[: self.max_samples]

        return data

    def execute(self) -> dict[str, Any] | None:
        """Return the next batch of data, or None when exhausted."""
        if self.current_batch >= self.total_batches:
            return None

        start_idx = self.current_batch * self.batch_size
        end_idx = min(start_idx + self.batch_size, len(self._data))
        batch = self._data[start_idx:end_idx]

        result = {
            "batch_data": batch,
            "batch_id": self.current_batch,
            "total_batches": self.total_batches,
        }

        self.current_batch += 1
        return result
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class PipelineRunner(MapFunction):
    """Generic pipeline runner that can execute any RAG pipeline.

    Dynamically loads and runs pipeline implementations from the
    implementations/pipelines/ directory.

    Expected pipeline interface:
        process_item(item: Dict, config: Dict) -> Dict
    """

    def __init__(self, config: dict):
        self.config = config
        self.pipeline_name = config.get("pipeline_name")
        self.pipeline_config = config.get("pipeline_config", {})

        # Fail fast on a missing name instead of attempting to import
        # "...pipelines.None" and producing a confusing ImportError later.
        if not self.pipeline_name:
            raise ValueError("pipeline_name is required in config")

        # Dynamically import pipeline module
        self.pipeline_module = self._load_pipeline()

    def _load_pipeline(self):
        """Dynamically load the pipeline implementation module.

        Returns:
            The imported module exposing a ``process_item`` function.

        Raises:
            ImportError: If the module cannot be imported or does not
                define a ``process_item`` function.
        """
        module_path = f"sage.benchmark_rag.implementations.pipelines.{self.pipeline_name}"

        try:
            module = importlib.import_module(module_path)

            # Verify the module has process_item function
            if not hasattr(module, "process_item"):
                raise ImportError(
                    f"Pipeline module {module_path} must have a 'process_item' function "
                    f"with signature: process_item(item: Dict, config: Dict) -> Dict"
                )

            return module
        except ImportError as e:
            # Chain the original exception (`from e`) so the root cause
            # is preserved in the traceback for debugging.
            raise ImportError(
                f"Failed to load pipeline '{self.pipeline_name}': {e}\n"
                f"Make sure the pipeline exists in implementations/pipelines/ "
                f"and has a 'process_item' function"
            ) from e

    def execute(self, batch_data: dict[str, Any]) -> dict[str, Any]:
        """Execute the pipeline on a batch of data.

        Args:
            batch_data: Dict with keys ``batch_data``, ``batch_id`` and
                ``total_batches`` as produced by :class:`BatchDataLoader`.

        Returns:
            Dict with per-item ``results`` plus batch bookkeeping fields.
        """
        batch = batch_data["batch_data"]

        # Process each item in the batch
        results = []
        for i, item in enumerate(batch):
            try:
                # Run the pipeline on this item
                result = self.pipeline_module.process_item(item, self.pipeline_config)
                results.append(result)

                # Progress indicator
                if (i + 1) % 10 == 0:
                    print(f" Processed {i + 1}/{len(batch)} items in batch...")

            except Exception as e:
                # Log the failure but keep processing the remaining items so
                # one bad sample does not abort the whole benchmark run.
                print(f"⚠️ Error processing item {item.get('id', 'unknown')}: {e}")
                results.append(
                    {
                        "id": item.get("id", "unknown"),
                        "question": item.get("question", ""),
                        "error": str(e),
                        "ground_truth": item.get("answers", []),
                    }
                )

        return {
            "results": results,
            "batch_id": batch_data["batch_id"],
            "total_batches": batch_data["total_batches"],
        }
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class ResultsCollector(MapFunction):
    """Collect benchmark results from processed batches and persist them.

    Supports incremental saving (after every batch) or a single final save,
    controlled by the ``save_mode`` configuration key.
    """

    def __init__(self, config: dict):
        self.config = config
        fallback_path = get_output_file("benchmark_results.json", "benchmarks")
        self.output_path = config.get("output_path", str(fallback_path))
        self.save_mode = config.get("save_mode", "incremental")
        self.all_results: list[dict[str, Any]] = []
        self.start_time = time.time()

    def execute(self, batch_result: dict[str, Any]) -> dict[str, Any]:
        """Accumulate one batch of results and save according to save_mode."""
        batch_id = batch_result["batch_id"]
        total_batches = batch_result["total_batches"]
        self.all_results.extend(batch_result["results"])

        completed = batch_id + 1
        print(
            f"✅ Processed batch {completed}/{total_batches}, "
            f"Total results: {len(self.all_results)}"
        )

        # Save after every batch in incremental mode; only on the last
        # batch in final mode.
        is_last = completed == total_batches
        if self.save_mode == "incremental" or (self.save_mode == "final" and is_last):
            self._save_results(completed, total_batches)

        return batch_result

    def _save_results(self, current_batch: int, total_batches: int):
        """Write accumulated results plus run metadata to output_path as JSON."""
        elapsed = time.time() - self.start_time

        payload = {
            "metadata": {
                "pipeline_name": self.config.get("pipeline_name", "unknown"),
                "timestamp": datetime.now().isoformat(),
                "total_samples": len(self.all_results),
                "completed_batches": f"{current_batch}/{total_batches}",
                "elapsed_time_seconds": round(elapsed, 2),
                "config": self.config,
            },
            "results": self.all_results,
        }

        # Ensure output directory exists
        target = Path(self.output_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        with open(target, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

        print(f"💾 Results saved to: {self.output_path}")
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def run_benchmark(config: dict) -> None:
    """Run a benchmark described by *config*.

    Args:
        config: Configuration dictionary with three sections:
            - data: Data loading configuration
            - pipeline: Pipeline configuration
            - output: Output configuration
    """
    banner = "=" * 60
    pipeline_cfg = config.get("pipeline", {})
    data_cfg = config.get("data", {})
    output_cfg = config.get("output", {})

    print(banner)
    print("🚀 Starting RAG Benchmark")
    print(banner)
    print(f"📊 Pipeline: {pipeline_cfg.get('pipeline_name', 'unknown')}")
    print(f"📁 Dataset: {data_cfg.get('data_path', 'unknown')}")
    print(f"💾 Output: {output_cfg.get('output_path', 'default')}")
    print(banner)

    env = LocalEnvironment("benchmark_pipeline")

    # Wire the benchmark pipeline: data source -> pipeline runner -> sink.
    stream = env.from_source(BatchDataLoader, config["data"])
    stream = stream.map(PipelineRunner, config["pipeline"])
    stream.sink(ResultsCollector, {**config["output"], **config["pipeline"]})

    env.submit()
    env.close()

    print(banner)
    print("✅ Benchmark completed!")
    print(banner)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def main():
    """Main entry point for benchmark runner."""
    import argparse

    parser = argparse.ArgumentParser(description="Run RAG pipeline benchmarks")
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to benchmark configuration YAML file",
    )
    parser.add_argument(
        "--pipeline",
        type=str,
        default=None,
        help="Pipeline name (e.g., 'selfrag', 'qa_dense_retrieval_milvus')",
    )
    parser.add_argument(
        "--data",
        type=str,
        default=None,
        help="Path to dataset JSONL file",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to output results file",
    )

    args = parser.parse_args()

    # Resolve configuration: explicit --config wins, otherwise fall back to
    # the default config shipped next to this module.
    if args.config:
        config = load_config(args.config)
    else:
        default_config = Path(__file__).parent / "config" / "benchmark_config.yaml"
        if not default_config.exists():
            raise ValueError(
                "No configuration file specified and default config not found. "
                "Use --config to specify a configuration file."
            )
        config = load_config(str(default_config))

    # Command-line arguments override the matching config entries.
    overrides = (
        ("pipeline", "pipeline_name", args.pipeline),
        ("data", "data_path", args.data),
        ("output", "output_path", args.output),
    )
    for section, key, value in overrides:
        if value:
            config[section][key] = value

    # Quiet the console logger and pick up environment variables from .env.
    CustomLogger.disable_global_console_debug()
    load_dotenv(override=False)

    run_benchmark(config)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Universal Benchmark Configuration
|
|
2
|
+
# This is a template that can be used for any RAG pipeline benchmark
|
|
3
|
+
|
|
4
|
+
# Data configuration
|
|
5
|
+
data:
|
|
6
|
+
data_path: "data/selfrag/selfrag_dev_sample_100.jsonl" # Path to dataset
|
|
7
|
+
batch_size: 100 # Number of samples per batch
|
|
8
|
+
max_samples: 100 # Limit to first N samples (null = use all data)
|
|
9
|
+
# Useful for quick testing or limiting dataset size
|
|
10
|
+
|
|
11
|
+
# Pipeline configuration
|
|
12
|
+
pipeline:
|
|
13
|
+
pipeline_name: "selfrag" # Name of pipeline in implementations/pipelines/
|
|
14
|
+
pipeline_config: # Pipeline-specific configuration
|
|
15
|
+
# This section is passed to the pipeline implementation
|
|
16
|
+
# Structure depends on the specific pipeline
|
|
17
|
+
top_k: 5
|
|
18
|
+
model_name: "mistralai/Mistral-7B-Instruct-v0.1"
|
|
19
|
+
gpu_memory_utilization: 0.8
|
|
20
|
+
temperature: 0.0
|
|
21
|
+
max_tokens: 100
|
|
22
|
+
|
|
23
|
+
# Output configuration
|
|
24
|
+
output:
|
|
25
|
+
output_path: "artifacts/benchmarks/benchmark_results.json"
|
|
26
|
+
save_mode: "incremental" # "incremental" or "final"
|
|
27
|
+
|
|
28
|
+
# Optional: Evaluation metrics configuration
|
|
29
|
+
# evaluation:
|
|
30
|
+
# metrics:
|
|
31
|
+
# - "exact_match"
|
|
32
|
+
# - "f1_score"
|
|
33
|
+
# - "bleu"
|
|
34
|
+
# evaluator_config:
|
|
35
|
+
# # Metric-specific configuration
|