isage-rag-benchmark 0.1.0.1__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_rag_benchmark-0.1.0.1.dist-info/METADATA +63 -0
- isage_rag_benchmark-0.1.0.1.dist-info/RECORD +48 -0
- isage_rag_benchmark-0.1.0.1.dist-info/WHEEL +5 -0
- isage_rag_benchmark-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_rag_benchmark-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark_rag/__init__.py +16 -0
- sage/benchmark_rag/_version.py +4 -0
- sage/benchmark_rag/config/config_bm25s.yaml +51 -0
- sage/benchmark_rag/config/config_dense_milvus.yaml +61 -0
- sage/benchmark_rag/config/config_hf.yaml +43 -0
- sage/benchmark_rag/config/config_mixed.yaml +53 -0
- sage/benchmark_rag/config/config_monitoring_demo.yaml +59 -0
- sage/benchmark_rag/config/config_multiplex.yaml +79 -0
- sage/benchmark_rag/config/config_qa_chroma.yaml +51 -0
- sage/benchmark_rag/config/config_ray.yaml +57 -0
- sage/benchmark_rag/config/config_refiner.yaml +75 -0
- sage/benchmark_rag/config/config_rerank.yaml +56 -0
- sage/benchmark_rag/config/config_selfrag.yaml +24 -0
- sage/benchmark_rag/config/config_source.yaml +30 -0
- sage/benchmark_rag/config/config_source_local.yaml +21 -0
- sage/benchmark_rag/config/config_sparse_milvus.yaml +49 -0
- sage/benchmark_rag/evaluation/__init__.py +10 -0
- sage/benchmark_rag/evaluation/benchmark_runner.py +337 -0
- sage/benchmark_rag/evaluation/config/benchmark_config.yaml +35 -0
- sage/benchmark_rag/evaluation/evaluate_results.py +389 -0
- sage/benchmark_rag/implementations/__init__.py +31 -0
- sage/benchmark_rag/implementations/pipelines/__init__.py +24 -0
- sage/benchmark_rag/implementations/pipelines/qa_bm25_retrieval.py +55 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval.py +56 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_chroma.py +71 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_milvus.py +78 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_mixed.py +58 -0
- sage/benchmark_rag/implementations/pipelines/qa_dense_retrieval_ray.py +174 -0
- sage/benchmark_rag/implementations/pipelines/qa_hf_model.py +57 -0
- sage/benchmark_rag/implementations/pipelines/qa_monitoring_demo.py +139 -0
- sage/benchmark_rag/implementations/pipelines/qa_multimodal_fusion.py +318 -0
- sage/benchmark_rag/implementations/pipelines/qa_multiplex.py +91 -0
- sage/benchmark_rag/implementations/pipelines/qa_refiner.py +91 -0
- sage/benchmark_rag/implementations/pipelines/qa_rerank.py +76 -0
- sage/benchmark_rag/implementations/pipelines/qa_sparse_retrieval_milvus.py +76 -0
- sage/benchmark_rag/implementations/pipelines/selfrag.py +226 -0
- sage/benchmark_rag/implementations/tools/__init__.py +17 -0
- sage/benchmark_rag/implementations/tools/build_chroma_index.py +261 -0
- sage/benchmark_rag/implementations/tools/build_milvus_dense_index.py +86 -0
- sage/benchmark_rag/implementations/tools/build_milvus_index.py +59 -0
- sage/benchmark_rag/implementations/tools/build_milvus_sparse_index.py +85 -0
- sage/benchmark_rag/implementations/tools/loaders/document_loaders.py +42 -0
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
import string
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from sage.common.config.output_paths import get_output_file
|
|
11
|
+
|
|
12
|
+
# ============================================================================
|
|
13
|
+
# 文本标准化模块
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
# English stop words removed by normalize_text_advanced().
# Stop-word filtering is intentionally disabled (empty set), so only ASCII
# punctuation and the articles "a"/"an"/"the" are stripped during
# normalization. Populate this set to enable stop-word filtering.
STOP_WORDS: set[str] = set()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def normalize_text_basic(text: str) -> str:
    """Lightly normalize text for substring matching.

    Strips list markers such as ``"1. "`` at the start of a line, collapses
    every whitespace run (including newlines and tabs) into a single space and
    lowercases the result. Punctuation is kept.

    Args:
        text: Raw text.

    Returns:
        The normalized text.
    """
    # Only strip "N." when it starts a line and is followed by whitespace, so
    # decimals ("3.14") and sentence-final numbers ("in 1990.") are preserved.
    text = re.sub(r"^\s*\d+\.\s+", "", text, flags=re.MULTILINE)
    # str.split() without arguments splits on any whitespace, which also
    # removes newlines and leading/trailing blanks.
    return " ".join(text.split()).lower()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def normalize_text_advanced(text: str, *, lowercase: bool = True) -> str:
    """Normalize an answer string for exact-match / F1 / accuracy comparison.

    Mirrors the common SQuAD-style answer normalization: lowercase, remove
    ASCII punctuation, remove the articles "a"/"an"/"the", drop any word
    listed in ``STOP_WORDS`` and collapse whitespace.

    Args:
        text: Raw text.
        lowercase: Lowercase before the other steps (default). Pass ``False``
            to reproduce the previous case-sensitive behaviour.

    Returns:
        The normalized tokens joined by single spaces.
    """
    if lowercase:
        # Must run before article removal: the article pattern below is
        # lowercase-only, so "The" would otherwise survive while "the" is removed.
        text = text.lower()

    # Remove ASCII punctuation characters.
    text = "".join(ch for ch in text if ch not in string.punctuation)

    # Remove the articles a / an / the (whole words only).
    text = re.sub(r"\b(a|an|the)\b", " ", text)

    # Drop stop words (the set may be empty) and normalize whitespace.
    return " ".join(word for word in text.split() if word not in STOP_WORDS)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# ============================================================================
|
|
78
|
+
# 评估指标计算模块
|
|
79
|
+
# ============================================================================
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def compute_f1(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between a prediction and a single reference answer.

    Both strings are normalized with normalize_text_advanced() and split on
    whitespace; overlapping tokens are counted with multiplicity (multiset
    intersection). Returns 0.0 when either side has no tokens or when no
    token overlaps.
    """
    predicted = normalize_text_advanced(prediction).split()
    reference = normalize_text_advanced(ground_truth).split()
    if not (predicted and reference):
        return 0.0

    overlap = sum((Counter(predicted) & Counter(reference)).values())
    if not overlap:
        return 0.0

    p = overlap / len(predicted)
    r = overlap / len(reference)
    return 2 * p * r / (p + r)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def compute_exact_match(prediction: str, ground_truth: str) -> int:
    """Return 1 when both strings normalize to identical text, otherwise 0."""
    normalized_prediction = normalize_text_advanced(prediction)
    normalized_truth = normalize_text_advanced(ground_truth)
    return 1 if normalized_prediction == normalized_truth else 0
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def compute_accuracy_single(prediction: str, ground_truths: list[str]) -> float:
    """Containment accuracy of one prediction against its reference answers.

    The prediction scores 1.0 when the normalized form of any reference
    answer occurs as a substring of the normalized prediction. This is a
    character-level containment test, not a word-level one. Otherwise the
    score is 0.0.

    References that normalize to an empty string (e.g. ``""``, ``"the"`` or
    pure punctuation) are ignored. The empty string is a substring of every
    text, so keeping them would mark any prediction as correct.

    Args:
        prediction: Model output.
        ground_truths: Acceptable reference answers.

    Returns:
        1.0 or 0.0.
    """
    norm_pred = normalize_text_advanced(prediction)
    for gt in ground_truths:
        norm_gt = normalize_text_advanced(gt)
        if norm_gt and norm_gt in norm_pred:
            return 1.0
    return 0.0
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def evaluate_predictions(
    predictions: list[str], ground_truths: list[list[str]], metric: str = "accuracy"
) -> dict[str, float]:
    """Score a batch of predictions against their reference answers.

    For "f1" and "exact_match" each prediction takes the best score over its
    references. A sample with no references scores 0. Scores are averaged
    over samples and reported as percentages (0-100). An empty batch yields
    0.0 rather than NaN.

    Args:
        predictions: Model outputs, one per sample.
        ground_truths: Reference answers per sample, aligned with ``predictions``.
        metric: "accuracy", "f1", "exact_match" or "all". Any other value
            produces an empty dict.

    Returns:
        Mapping from metric name to percentage score.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` differ in length.
    """
    if len(predictions) != len(ground_truths):
        # zip() would silently drop the surplus samples and skew the averages.
        raise ValueError(
            f"predictions ({len(predictions)}) and ground_truths "
            f"({len(ground_truths)}) must have the same length"
        )

    def _percent(scores: list) -> float:
        # np.mean([]) is NaN (with a RuntimeWarning), and NaN is not valid
        # JSON, so an empty batch reports 0.0 instead.
        return float(100 * np.mean(scores)) if scores else 0.0

    pairs = list(zip(predictions, ground_truths))
    results: dict[str, float] = {}

    if metric in ("accuracy", "all"):
        results["accuracy"] = _percent(
            [compute_accuracy_single(pred, truths) for pred, truths in pairs]
        )

    if metric in ("f1", "all"):
        # Best F1 over all references of each sample.
        results["f1"] = _percent(
            [
                max(compute_f1(pred, gt) for gt in truths) if truths else 0.0
                for pred, truths in pairs
            ]
        )

    if metric in ("exact_match", "all"):
        # Best EM over all references of each sample.
        results["exact_match"] = _percent(
            [
                max(compute_exact_match(pred, gt) for gt in truths) if truths else 0
                for pred, truths in pairs
            ]
        )

    return results
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def load_results(file_path: str) -> dict[str, Any]:
    """Read an inference results file (UTF-8 JSON) and return its parsed content."""
    with open(file_path, "r", encoding="utf-8") as handle:
        data = json.load(handle)
    return data
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def calculate_overall_scores(results_data: dict[str, Any], metric: str = "all") -> dict[str, Any]:
    """Compute aggregate scores for an inference results file (no per-sample scores).

    The function accepts several field names:
    - The prediction is read from ``prediction``, falling back to
      ``model_output`` when that is missing or empty.
    - The references come from ``ground_truth``, which may be a single
      string or a list of strings.

    Args:
        results_data: Parsed results file; must contain a ``results`` list and
            may contain ``metadata``.
        metric: "accuracy", "f1", "exact_match" or "all".

    Returns:
        Dict with ``metadata`` (copied through), ``overall_scores`` and a
        ``summary`` holding the sample count and the metric name.
    """
    results = results_data["results"]

    predictions: list[str] = []
    ground_truths: list[list[str]] = []
    for item in results:
        # Fall back to "" when neither field carries a value (missing or null)
        # so the text normalizers never receive None.
        predictions.append(item.get("prediction") or item.get("model_output") or "")

        gt = item.get("ground_truth")
        if gt is None:
            # A missing or null reference means "no answers": the sample scores 0
            # instead of crashing the metric code.
            gt = []
        elif isinstance(gt, str):
            gt = [gt]
        ground_truths.append(gt)

    return {
        "metadata": results_data.get("metadata", {}),
        "overall_scores": evaluate_predictions(predictions, ground_truths, metric),
        "summary": {
            "total_samples": len(results),
            "evaluation_metric": metric,
        },
    }
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def analyze_retrieval_quality(
    evaluation_result: dict[str, Any], results_data: dict[str, Any]
) -> dict[str, Any]:
    """Summarize retrieval statistics for an inference results file.

    Contexts are read as follows:
    - Per sample, they come from ``retrieved_docs``, falling back to
      ``retrieved_context`` when that is missing or empty.
    - A dict context contributes its ``text`` field; any other value is
      converted with ``str``.

    A sample counts as relevant when the normalized form of any ground-truth
    answer (normalize_text_basic) occurs as a substring of a normalized
    context.

    Args:
        evaluation_result: Output of calculate_overall_scores(). It is not
            read here and is kept for interface compatibility.
        results_data: Parsed results file containing a ``results`` list.

    Returns:
        Dict with these keys (rates are fractions in [0, 1]):
        - ``total_samples``: number of samples.
        - ``samples_with_context``: samples that have any context.
        - ``context_coverage``: ``samples_with_context / total_samples``.
        - ``avg_context_count``: mean number of contexts over samples that
          have any.
        - ``context_relevance_rate``: share of samples *with context* whose
          context contains an answer.
    """
    detailed_results = results_data["results"]
    total_samples = len(detailed_results)
    context_lengths: list[int] = []
    context_relevance_scores: list[float] = []

    for item in detailed_results:
        contexts = item.get("retrieved_docs") or item.get("retrieved_context") or []
        if not contexts:
            continue

        context_lengths.append(len(contexts))

        ground_truth = item.get("ground_truth")
        if ground_truth is None:
            ground_truth = []
        elif isinstance(ground_truth, str):
            ground_truth = [ground_truth]

        # Skip answers that normalize to "": the empty string is a substring of
        # every context and would mark the sample relevant unconditionally.
        answers = [a for a in (normalize_text_basic(gt) for gt in ground_truth) if a]

        found_in_context = False
        if answers:
            # Normalize each context once, not once per answer.
            normalized_contexts = [
                normalize_text_basic(
                    (ctx.get("text") or "") if isinstance(ctx, dict) else str(ctx)
                )
                for ctx in contexts
            ]
            found_in_context = any(
                answer in context for answer in answers for context in normalized_contexts
            )

        context_relevance_scores.append(1.0 if found_in_context else 0.0)

    samples_with_context = len(context_lengths)
    return {
        "total_samples": total_samples,
        "samples_with_context": samples_with_context,
        "context_coverage": (samples_with_context / total_samples if total_samples > 0 else 0.0),
        # Convert to native floats so the report is plain JSON-friendly data.
        "avg_context_count": float(np.mean(context_lengths)) if context_lengths else 0.0,
        "context_relevance_rate": (
            float(np.mean(context_relevance_scores)) if context_relevance_scores else 0.0
        ),
    }
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def print_evaluation_summary(evaluation_result: dict[str, Any]):
    """Print a human-readable summary of an evaluation result to stdout.

    Shows selected metadata (pipeline name, timestamp, model, top-k, sample
    count) when present, the evaluated sample count and metric, every overall
    score as a percentage and, if present, the retrieval analysis.

    Args:
        evaluation_result: Dict with ``overall_scores`` and ``summary`` and
            optionally ``metadata`` and ``retrieval_analysis``.
    """
    metadata = evaluation_result.get("metadata", {})
    scores = evaluation_result["overall_scores"]
    summary = evaluation_result["summary"]

    print("\n" + "=" * 60)
    print("📊 评估结果摘要")
    print("=" * 60)

    # Configuration details copied from the inference run's metadata.
    if metadata:
        print("🔧 配置信息:")
        if "pipeline_name" in metadata:
            print(f" Pipeline: {metadata['pipeline_name']}")
        if "timestamp" in metadata:
            print(f" 时间: {metadata['timestamp']}")
        if "config" in metadata:
            config = metadata["config"]
            if "pipeline" in config and "pipeline_config" in config["pipeline"]:
                pipeline_config = config["pipeline"]["pipeline_config"]
                if "model_name" in pipeline_config:
                    print(f" 模型: {pipeline_config['model_name']}")
                if "top_k" in pipeline_config:
                    print(f" Top-K: {pipeline_config['top_k']}")
        meta_summary = metadata.get("summary")
        if "total_samples" in metadata:
            print(f" 样本数: {metadata['total_samples']}")
        elif isinstance(meta_summary, dict) and "total_samples" in meta_summary:
            # Read the count from the metadata's own summary. It previously
            # checked metadata["summary"] but printed the evaluation summary.
            print(f" 样本数: {meta_summary['total_samples']}")

    print(f"\n📊 总样本数: {summary['total_samples']}")
    print(f"📏 评估指标: {summary['evaluation_metric']}")

    print("\n📈 整体性能指标:")
    for metric, score in scores.items():
        print(f" {metric.upper()}: {score:.2f}%")

    # Retrieval quality statistics (rates are stored as fractions, shown as %).
    if "retrieval_analysis" in evaluation_result:
        retrieval_stats = evaluation_result["retrieval_analysis"]
        print("\n🔍 检索质量分析:")
        print(f" 上下文覆盖率: {100 * retrieval_stats['context_coverage']:.2f}%")
        print(f" 平均检索数量: {retrieval_stats['avg_context_count']:.2f}")
        print(f" 上下文相关性: {100 * retrieval_stats['context_relevance_rate']:.2f}%")

    print("=" * 60)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def main():
    """Command-line entry point: evaluate a RAG inference results file.

    Steps:
    1. Load the results JSON given by ``--results``.
    2. Compute the requested metric(s) and the retrieval statistics.
    3. Print a summary.
    4. Write the full report as JSON to ``--output``, or, when that is not
       given, to ``evaluation_<input stem>.json`` resolved through
       ``get_output_file(..., "benchmarks")``.
    """
    parser = argparse.ArgumentParser(description="评估RAG推理结果")
    parser.add_argument("--results", "-r", type=str, required=True, help="推理结果文件路径")
    parser.add_argument(
        "--metric",
        choices=["accuracy", "f1", "exact_match", "all"],
        default="all",
        help="评估指标",
    )
    parser.add_argument("--output", "-o", type=str, help="输出评估结果文件路径")

    args = parser.parse_args()

    # Load the inference results.
    print(f"📥 正在加载推理结果: {args.results}")
    results_data = load_results(args.results)

    # Aggregate scores.
    print(f"🔄 正在计算评估指标: {args.metric}")
    evaluation_result = calculate_overall_scores(results_data, args.metric)

    # Retrieval statistics (coverage is 0 when no sample carries retrieved contexts).
    print("🔍 正在分析检索质量...")
    evaluation_result["retrieval_analysis"] = analyze_retrieval_quality(
        evaluation_result, results_data
    )

    print_evaluation_summary(evaluation_result)

    if args.output:
        output_path = args.output
    else:
        output_filename = f"evaluation_{Path(args.results).stem}.json"
        output_path = get_output_file(output_filename, "benchmarks")

    # A user-supplied --output may point into a directory that does not exist yet.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(evaluation_result, f, indent=2, ensure_ascii=False)

    print(f"\n✅ 评估结果已保存到: {output_path}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SAGE RAG Examples - 检索增强生成示例
|
|
3
|
+
|
|
4
|
+
这个模块包含了各种 RAG (Retrieval-Augmented Generation) 的实现示例。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# 导入所有RAG示例(这里需要根据实际文件调整)
|
|
8
|
+
# from . import rag_simple
|
|
9
|
+
# from . import qa_dense_retrieval
|
|
10
|
+
|
|
11
|
+
"""RAG implementations for benchmarking.
|
|
12
|
+
|
|
13
|
+
This module contains various RAG implementation approaches for performance comparison:
|
|
14
|
+
|
|
15
|
+
Pipelines (pipelines/):
|
|
16
|
+
- Dense retrieval (ChromaDB, Milvus, FAISS)
|
|
17
|
+
- Sparse retrieval (BM25, Milvus sparse)
|
|
18
|
+
- Hybrid retrieval (dense + sparse)
|
|
19
|
+
- Multimodal fusion (text + image + video)
|
|
20
|
+
- Reranking strategies
|
|
21
|
+
- Query refinement
|
|
22
|
+
|
|
23
|
+
Tools (tools/):
|
|
24
|
+
- Index building utilities (ChromaDB, Milvus)
|
|
25
|
+
- Document loaders
|
|
26
|
+
- Data preparation scripts
|
|
27
|
+
|
|
28
|
+
See subdirectory READMEs for detailed usage.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
__all__: list[str] = []
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""RAG Pipeline Implementations.
|
|
2
|
+
|
|
3
|
+
This module contains various RAG pipeline implementations for benchmarking:
|
|
4
|
+
|
|
5
|
+
Retrieval Methods:
|
|
6
|
+
- Dense retrieval (embedding-based)
|
|
7
|
+
- Sparse retrieval (BM25, sparse vectors)
|
|
8
|
+
- Hybrid retrieval (combining dense + sparse)
|
|
9
|
+
|
|
10
|
+
Vector Databases:
|
|
11
|
+
- Milvus (dense, sparse, hybrid)
|
|
12
|
+
- ChromaDB (local vector database)
|
|
13
|
+
- FAISS (efficient similarity search)
|
|
14
|
+
|
|
15
|
+
Advanced Features:
|
|
16
|
+
- Multimodal fusion (text + image + video)
|
|
17
|
+
- Reranking strategies
|
|
18
|
+
- Query refinement
|
|
19
|
+
- Distributed processing (Ray)
|
|
20
|
+
|
|
21
|
+
Each pipeline can be run independently for testing or used in benchmark experiments.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
__all__: list[str] = []
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from sage.common.utils.config.loader import load_config
|
|
5
|
+
from sage.kernel.api.local_environment import LocalEnvironment
|
|
6
|
+
from sage.libs.foundation.io.sink import TerminalSink
|
|
7
|
+
from sage.libs.foundation.io.source import FileSource
|
|
8
|
+
from sage.middleware.operators.rag import OpenAIGenerator, QAPromptor
|
|
9
|
+
|
|
10
|
+
# from sage.middleware.operators.rag import BM25sRetriever # 这个类不存在
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def pipeline_run():
|
|
14
|
+
"""创建并运行数据处理管道"""
|
|
15
|
+
# 检查是否在测试模式下运行
|
|
16
|
+
if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
|
|
17
|
+
print("🧪 Test mode detected - qa_bm25_retrieval example")
|
|
18
|
+
print("✅ Test passed: Example structure validated (BM25sRetriever not available)")
|
|
19
|
+
return
|
|
20
|
+
|
|
21
|
+
env = LocalEnvironment()
|
|
22
|
+
# env.set_memory(config=None)
|
|
23
|
+
# 构建数据处理流程
|
|
24
|
+
query_stream = env.from_source(FileSource, config["source"])
|
|
25
|
+
# query_and_chunks_stream = query_stream.map(BM25sRetriever, config["retriever"]) # 不可用
|
|
26
|
+
query_and_chunks_stream = query_stream # 跳过检索步骤
|
|
27
|
+
prompt_stream = query_and_chunks_stream.map(QAPromptor, config["promptor"])
|
|
28
|
+
response_stream = prompt_stream.map(OpenAIGenerator, config["generator"]["vllm"])
|
|
29
|
+
response_stream.sink(TerminalSink, config["sink"])
|
|
30
|
+
# 提交管道并运行
|
|
31
|
+
env.submit()
|
|
32
|
+
# 启动管道
|
|
33
|
+
|
|
34
|
+
# time.sleep(100) # 等待管道运行
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
if __name__ == "__main__":
    # Skip the real run under the example test harness.
    if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
        print("🧪 Test mode detected - qa_bm25_retrieval example")
        print("✅ Test passed: Example structure validated (BM25sRetriever not available)")
        sys.exit(0)

    # This file lives in sage/benchmark_rag/implementations/pipelines/ while the
    # YAML configs ship in sage/benchmark_rag/config/, two directories up.
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "..", "config", "config_bm25s.yaml"
    )
    if not os.path.exists(config_path):
        print(f"❌ Configuration file not found: {config_path}")
        print("Please create the configuration file first.")
        sys.exit(1)

    # pipeline_run() reads this module-level `config`.
    config = load_config(config_path)
    pipeline_run()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from sage.common.utils.config.loader import load_config
|
|
6
|
+
from sage.kernel.api.local_environment import LocalEnvironment
|
|
7
|
+
from sage.libs.foundation.io.sink import TerminalSink
|
|
8
|
+
from sage.libs.foundation.io.source import FileSource
|
|
9
|
+
from sage.middleware.operators.rag import OpenAIGenerator, QAPromptor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def pipeline_run():
|
|
13
|
+
"""创建并运行数据处理管道"""
|
|
14
|
+
# 检查是否在测试模式下运行
|
|
15
|
+
if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
|
|
16
|
+
print("🧪 Test mode detected - qa_dense_retrieval example")
|
|
17
|
+
print("✅ Test passed: Example structure validated")
|
|
18
|
+
return
|
|
19
|
+
|
|
20
|
+
# env = LocalBatchEnvironment() #DEBUG and Batch -- Client 拥有后续程序的全部handler(包括JM)
|
|
21
|
+
env = LocalEnvironment(
|
|
22
|
+
"JM-IP"
|
|
23
|
+
) # Deployment to JM. -- Client 不拥有后续程序的全部handler(包括JM)
|
|
24
|
+
|
|
25
|
+
# Batch Environment.
|
|
26
|
+
|
|
27
|
+
(
|
|
28
|
+
env.from_source(FileSource, config["source"]) # 处理且处理一整个file 一次。
|
|
29
|
+
# .map(MilvusDenseRetriever, config["retriever"]) # 需要配置文件
|
|
30
|
+
.map(QAPromptor, config["promptor"])
|
|
31
|
+
.map(OpenAIGenerator, config["generator"]["vllm"])
|
|
32
|
+
.sink(TerminalSink, config["sink"]) # TM (JVM) --> 会打印在某一台机器的console里
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
env.submit()
|
|
36
|
+
time.sleep(5)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
if __name__ == "__main__":
    # Skip the real run under the example test harness.
    if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
        print("🧪 Test mode detected - qa_dense_retrieval example")
        print("✅ Test passed: Example structure validated")
        sys.exit(0)

    # This file lives in sage/benchmark_rag/implementations/pipelines/ while the
    # YAML configs ship in sage/benchmark_rag/config/, two directories up.
    # NOTE(review): no "config.yaml" appears in this package's file list; the
    # intended file may be one of the config_*.yaml variants. Confirm.
    config_path = os.path.join(os.path.dirname(__file__), "..", "..", "config", "config.yaml")
    if not os.path.exists(config_path):
        print(f"❌ Configuration file not found: {config_path}")
        print("Please create the configuration file first.")
        sys.exit(1)

    # pipeline_run() reads this module-level `config`.
    config = load_config(config_path)
    pipeline_run()
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from sage.common.utils.config.loader import load_config
|
|
4
|
+
from sage.kernel.api.local_environment import LocalEnvironment
|
|
5
|
+
from sage.libs.foundation.io.batch import JSONLBatch
|
|
6
|
+
from sage.libs.foundation.io.sink import TerminalSink
|
|
7
|
+
from sage.middleware.operators.rag import ChromaRetriever, OpenAIGenerator, QAPromptor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def pipeline_run(config: dict) -> None:
    """Build and run the ChromaDB-backed RAG QA pipeline.

    The flow is: JSONLBatch source -> ChromaRetriever -> QAPromptor ->
    OpenAIGenerator -> TerminalSink. A short configuration summary is printed
    first. The environment is always closed, even if building or running the
    pipeline raises.

    Args:
        config: Configuration dict with ``source`` (incl. ``data_path``),
            ``retriever`` (incl. ``dimension``, ``top_k``,
            ``chroma.collection_name`` and ``embedding.method``), ``promptor``,
            ``generator.vllm`` and ``sink`` sections.

    Raises:
        KeyError: If a required configuration key is missing.
    """
    print("=== 启动基于 ChromaDB 的 RAG 问答系统 ===")
    print("配置信息:")
    print(f" - 源文件: {config['source']['data_path']}")
    print(f" - 向量维度: {config['retriever']['dimension']}")
    print(f" - Top-K: {config['retriever']['top_k']}")
    print(f" - 集合名称: {config['retriever']['chroma']['collection_name']}")
    print(f" - 嵌入模型: {config['retriever']['embedding']['method']}")

    env = LocalEnvironment()
    try:
        (
            env.from_batch(JSONLBatch, config["source"])
            .map(ChromaRetriever, config["retriever"])
            .map(QAPromptor, config["promptor"])
            .map(OpenAIGenerator, config["generator"]["vllm"])
            .sink(TerminalSink, config["sink"])
        )

        print("正在提交并运行管道...")
        env.submit(autostop=True)
    finally:
        # Release the environment even when pipeline construction or execution fails.
        env.close()
    print("=== RAG 问答系统运行完成 ===")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
if __name__ == "__main__":
    import sys

    # Skip the real run under the example test harness.
    if os.getenv("SAGE_EXAMPLES_MODE") == "test" or os.getenv("SAGE_TEST_MODE") == "true":
        print("🧪 Test mode detected - qa_dense_retrieval_chroma example")
        print("✅ Test passed: Example structure validated")
        sys.exit(0)

    # Resolve the config relative to this file rather than the current working
    # directory. config_qa_chroma.yaml ships in sage/benchmark_rag/config/, two
    # directories above implementations/pipelines/.
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "..", "config", "config_qa_chroma.yaml"
    )
    if not os.path.exists(config_path):
        print(f"配置文件不存在: {config_path}")
        print("Please create the configuration file first.")
        sys.exit(1)

    # The full config is intentionally not dumped to the console: it may contain
    # credentials (e.g. generator API keys), and pipeline_run() prints a summary.
    config = load_config(config_path)

    # Check the knowledge-base file, if one is configured (path relative to the CWD).
    knowledge_file = config["retriever"]["chroma"].get("knowledge_file")
    if knowledge_file:
        if not os.path.exists(knowledge_file):
            print(f"警告:知识库文件不存在: {knowledge_file}")
            print("请确保知识库文件存在于指定路径")
        else:
            print(f"找到知识库文件: {knowledge_file}")

    pipeline_run(config)
|