isage_middleware-0.2.4.3-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
- isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
- isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
- isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
- sage/middleware/__init__.py +59 -0
- sage/middleware/_version.py +6 -0
- sage/middleware/components/__init__.py +30 -0
- sage/middleware/components/extensions_compat.py +141 -0
- sage/middleware/components/sage_db/__init__.py +116 -0
- sage/middleware/components/sage_db/backend.py +136 -0
- sage/middleware/components/sage_db/service.py +15 -0
- sage/middleware/components/sage_flow/__init__.py +76 -0
- sage/middleware/components/sage_flow/python/__init__.py +14 -0
- sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
- sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
- sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
- sage/middleware/components/sage_flow/service.py +14 -0
- sage/middleware/components/sage_mem/__init__.py +83 -0
- sage/middleware/components/sage_sias/__init__.py +59 -0
- sage/middleware/components/sage_sias/continual_learner.py +184 -0
- sage/middleware/components/sage_sias/coreset_selector.py +302 -0
- sage/middleware/components/sage_sias/types.py +94 -0
- sage/middleware/components/sage_tsdb/__init__.py +81 -0
- sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
- sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
- sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
- sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
- sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
- sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
- sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
- sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
- sage/middleware/components/sage_tsdb/service.py +17 -0
- sage/middleware/components/vector_stores/__init__.py +25 -0
- sage/middleware/components/vector_stores/chroma.py +483 -0
- sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
- sage/middleware/components/vector_stores/milvus.py +677 -0
- sage/middleware/operators/__init__.py +56 -0
- sage/middleware/operators/agent/__init__.py +24 -0
- sage/middleware/operators/agent/planning/__init__.py +5 -0
- sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
- sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
- sage/middleware/operators/agent/planning/router.py +107 -0
- sage/middleware/operators/agent/runtime.py +296 -0
- sage/middleware/operators/agentic/__init__.py +41 -0
- sage/middleware/operators/agentic/config.py +254 -0
- sage/middleware/operators/agentic/planning_operator.py +125 -0
- sage/middleware/operators/agentic/refined_searcher.py +132 -0
- sage/middleware/operators/agentic/runtime.py +241 -0
- sage/middleware/operators/agentic/timing_operator.py +125 -0
- sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
- sage/middleware/operators/context/__init__.py +17 -0
- sage/middleware/operators/context/critic_evaluation.py +16 -0
- sage/middleware/operators/context/model_context.py +565 -0
- sage/middleware/operators/context/quality_label.py +12 -0
- sage/middleware/operators/context/search_query_results.py +61 -0
- sage/middleware/operators/context/search_result.py +42 -0
- sage/middleware/operators/context/search_session.py +79 -0
- sage/middleware/operators/filters/__init__.py +26 -0
- sage/middleware/operators/filters/context_sink.py +387 -0
- sage/middleware/operators/filters/context_source.py +376 -0
- sage/middleware/operators/filters/evaluate_filter.py +83 -0
- sage/middleware/operators/filters/tool_filter.py +74 -0
- sage/middleware/operators/llm/__init__.py +18 -0
- sage/middleware/operators/llm/sagellm_generator.py +432 -0
- sage/middleware/operators/rag/__init__.py +147 -0
- sage/middleware/operators/rag/arxiv.py +331 -0
- sage/middleware/operators/rag/chunk.py +13 -0
- sage/middleware/operators/rag/document_loaders.py +23 -0
- sage/middleware/operators/rag/evaluate.py +658 -0
- sage/middleware/operators/rag/generator.py +340 -0
- sage/middleware/operators/rag/index_builder/__init__.py +48 -0
- sage/middleware/operators/rag/index_builder/builder.py +363 -0
- sage/middleware/operators/rag/index_builder/manifest.py +101 -0
- sage/middleware/operators/rag/index_builder/storage.py +131 -0
- sage/middleware/operators/rag/pipeline.py +46 -0
- sage/middleware/operators/rag/profiler.py +59 -0
- sage/middleware/operators/rag/promptor.py +400 -0
- sage/middleware/operators/rag/refiner.py +231 -0
- sage/middleware/operators/rag/reranker.py +364 -0
- sage/middleware/operators/rag/retriever.py +1308 -0
- sage/middleware/operators/rag/searcher.py +37 -0
- sage/middleware/operators/rag/types.py +28 -0
- sage/middleware/operators/rag/writer.py +80 -0
- sage/middleware/operators/tools/__init__.py +71 -0
- sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
- sage/middleware/operators/tools/arxiv_searcher.py +102 -0
- sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
- sage/middleware/operators/tools/image_captioner.py +104 -0
- sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
- sage/middleware/operators/tools/searcher_tool.py +514 -0
- sage/middleware/operators/tools/text_detector.py +185 -0
- sage/middleware/operators/tools/url_text_extractor.py +104 -0
- sage/middleware/py.typed +2 -0
@@ -0,0 +1,658 @@
import re
import string
from collections import Counter

from rouge import Rouge
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

from sage.common.core.functions import MapFunction as MapOperator
from sage.kernel.runtime.communication.packet import StopSignal

# =============================================================================
# RECOMP-style Answer Normalization
# =============================================================================


def normalize_answer(s: str) -> str:
    """RECOMP-style answer normalization.

    Steps:
    1. Lowercase the text
    2. Remove punctuation
    3. Remove articles (a, an, the)
    4. Collapse whitespace

    Args:
        s: Raw answer text

    Returns:
        Normalized answer text
    """

    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_normalized_tokens(s: str) -> list[str]:
    """Return the list of normalized tokens.

    Args:
        s: Raw text

    Returns:
        List of tokens after normalization
    """
    if not s:
        return []
    return normalize_answer(s).split()


def answer_extract(pred: str) -> str:
    """Extract the answer text.

    Supports answers given with an "answer is" prefix.

    Args:
        pred: Prediction text

    Returns:
        Extracted answer text
    """
    prefix = "answer is "
    if prefix in pred.lower():
        idx = pred.lower().rfind(prefix)
        return pred[idx + len(prefix) :].strip()
    return pred.strip()

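A quick sanity check of the two helpers above (illustrative only, not part of the packaged file; the example strings are made up and assume the functions above are in scope):

    # normalize_answer lowercases, strips punctuation and articles, collapses whitespace
    print(normalize_answer("The Quick, Brown Fox!"))  # "quick brown fox"
    # answer_extract keeps everything after the last "answer is"; trailing punctuation
    # survives here and is only stripped later by normalize_answer
    print(answer_extract("I believe the answer is Paris."))  # "Paris."
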
def _get_results_collector():
    """
    Lazily import ResultsCollector to avoid a circular dependency.

    Returns:
        A ResultsCollector instance, or None if it is unavailable
    """
    try:
        from sage.common.utils.results_collector import ResultsCollector

        return ResultsCollector()
    except ImportError:
        return None


class MetricsAggregator:
    """Global metrics aggregator (singleton) for collecting and averaging metrics."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.reset()
        return cls._instance

    def reset(self):
        """Reset all collected statistics."""
        self.metrics = {
            "f1_scores": [],
            "em_scores": [],  # Exact Match scores
            "token_counts": [],
            "retrieve_times": [],
            "refine_times": [],
            "generate_times": [],
            "total_latencies": [],
            "compression_rates": [],
        }
        self.sample_count = 0

    def add_f1(self, score):
        self.metrics["f1_scores"].append(score)

    def add_em(self, score):
        """Add an Exact Match score."""
        self.metrics["em_scores"].append(score)

    def add_token_count(self, count):
        self.metrics["token_counts"].append(count)

    def add_latency(self, retrieve, refine, generate):
        self.metrics["retrieve_times"].append(retrieve)
        self.metrics["refine_times"].append(refine)
        self.metrics["generate_times"].append(generate)
        self.metrics["total_latencies"].append(retrieve + refine + generate)
        self.sample_count += 1  # one sample per latency record

    def add_compression_rate(self, rate):
        self.metrics["compression_rates"].append(rate)

    def print_summary(self):
        """Print the aggregated summary statistics."""
        if self.sample_count == 0:
            print("\n" + "=" * 80)
            print("No samples processed")
            print("=" * 80)
            return

        print("\n" + "=" * 80)
        print(f"SUMMARY STATISTICS ({self.sample_count} samples)")
        print("=" * 80)

        # Exact Match Score
        if self.metrics["em_scores"]:
            avg_em = sum(self.metrics["em_scores"]) / len(self.metrics["em_scores"])
            print(f"\033[92m[Average EM Score] : {avg_em:.4f}\033[0m")

        # F1 Score
        if self.metrics["f1_scores"]:
            avg_f1 = sum(self.metrics["f1_scores"]) / len(self.metrics["f1_scores"])
            print(f"\033[92m[Average F1 Score] : {avg_f1:.4f}\033[0m")

        # Token Count
        if self.metrics["token_counts"]:
            avg_tokens = sum(self.metrics["token_counts"]) / len(self.metrics["token_counts"])
            print(f"\033[92m[Average Token Count] : {avg_tokens:.0f}\033[0m")

        # Latency
        if self.metrics["retrieve_times"]:
            avg_retrieve = sum(self.metrics["retrieve_times"]) / len(self.metrics["retrieve_times"])
            avg_refine = sum(self.metrics["refine_times"]) / len(self.metrics["refine_times"])
            avg_generate = sum(self.metrics["generate_times"]) / len(self.metrics["generate_times"])
            avg_total = sum(self.metrics["total_latencies"]) / len(self.metrics["total_latencies"])

            print(f"\033[92m[Average Retrieve Time] : {avg_retrieve:.2f}s\033[0m")
            print(f"\033[92m[Average Refine Time] : {avg_refine:.2f}s\033[0m")
            print(f"\033[92m[Average Generate Time] : {avg_generate:.2f}s\033[0m")
            avg_min = avg_total / 60
            print(f"\033[92m[Average Total Latency] : {avg_total:.2f}s ({avg_min:.2f}m)\033[0m")

        # Compression Rate
        if self.metrics["compression_rates"]:
            valid_rates = [r for r in self.metrics["compression_rates"] if r > 0]
            if valid_rates:
                avg_compression = sum(valid_rates) / len(valid_rates)
                print(f"\033[92m[Average Compression Rate]: {avg_compression:.2f}×\033[0m")

        print("=" * 80 + "\n")

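MetricsAggregator uses the classic __new__-based singleton pattern, so every evaluator below shares one instance. A minimal illustration (not part of the packaged file):

    agg_a = MetricsAggregator()
    agg_b = MetricsAggregator()
    assert agg_a is agg_b              # __new__ hands back the one shared instance
    agg_a.reset()
    agg_a.add_latency(1.0, 0.5, 2.0)   # the only method that bumps sample_count
    print(agg_b.sample_count)          # 1; state written through one name is visible through the other
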
class F1Evaluate(MapOperator):
    """F1 score evaluator (RECOMP standard).

    Computes the F1 score using RECOMP-style answer normalization.
    Normalization steps: lowercase, remove punctuation, remove articles, fix whitespace.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.aggregator = MetricsAggregator()
        # Whether to extract the text after the "answer is" prefix
        self.extract_answer = config.get("extract_answer", False) if config else False

    def _f1_score(self, pred: str, ref: str) -> float:
        """Compute the F1 score (RECOMP standard) over normalized tokens.

        Args:
            pred: Predicted answer
            ref: Reference answer

        Returns:
            F1 score
        """
        gold_toks = get_normalized_tokens(ref)
        pred_toks = get_normalized_tokens(pred)

        common = Counter(gold_toks) & Counter(pred_toks)
        num_same = sum(common.values())

        if len(gold_toks) == 0 or len(pred_toks) == 0:
            # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
            return float(gold_toks == pred_toks)

        if num_same == 0:
            return 0.0

        precision = 1.0 * num_same / len(pred_toks)
        recall = 1.0 * num_same / len(gold_toks)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def execute(self, data):
        # Handle StopSignal: stay quiet and let CompressionRateEvaluate print the final summary
        if isinstance(data, StopSignal):
            return data

        golds = data.get("references", [])
        pred = data.get("generated", "")

        # Optionally extract the text after "answer is"
        if self.extract_answer:
            pred = answer_extract(pred)

        best = max((self._f1_score(pred, g) for g in golds), default=0.0) if golds else 0.0

        # Add to aggregator
        self.aggregator.add_f1(best)

        # Add to ResultsCollector (if available)
        collector = _get_results_collector()
        if collector is not None:
            sample_id = data.get("sample_id", data.get("_sample_idx"))
            collector.update_sample(sample_id, f1=best)

        print(f"\033[93m[F1] : {best:.4f}\033[0m")
        return data

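A worked example of the token-level F1 above (illustrative, made-up strings; assumes the module's helpers are in scope):

    gold_toks = get_normalized_tokens("the city of Paris")  # ["city", "of", "paris"]
    pred_toks = get_normalized_tokens("Paris")              # ["paris"]
    num_same = sum((Counter(gold_toks) & Counter(pred_toks)).values())  # 1
    precision, recall = num_same / len(pred_toks), num_same / len(gold_toks)  # 1.0, 1/3
    print(2 * precision * recall / (precision + recall))  # F1 = 2*(1/3)/(4/3) = 0.5
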
class EMEvaluate(MapOperator):
    """Exact Match evaluator (RECOMP standard).

    Computes exact match using RECOMP-style answer normalization.
    Normalization steps: lowercase, remove punctuation, remove articles, fix whitespace.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.aggregator = MetricsAggregator()
        # Whether to extract the text after the "answer is" prefix
        self.extract_answer = config.get("extract_answer", False) if config else False

    def _exact_match(self, pred: str, gold: str) -> int:
        """Compute Exact Match (RECOMP standard) on normalized text.

        Args:
            pred: Predicted answer
            gold: Reference answer

        Returns:
            1 if they match, otherwise 0
        """
        return int(normalize_answer(pred) == normalize_answer(gold))

    def execute(self, data):
        # Handle StopSignal: stay quiet and let CompressionRateEvaluate print the final summary
        if isinstance(data, StopSignal):
            return data

        golds = data.get("references", [])
        pred = data.get("generated", "")

        # Optionally extract the text after "answer is"
        if self.extract_answer:
            pred = answer_extract(pred)

        best = max((self._exact_match(pred, g) for g in golds), default=0) if golds else 0

        # Add to aggregator
        self.aggregator.add_em(best)

        print(f"\033[93m[EM] : {best}\033[0m")
        return data

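Because both sides pass through normalize_answer, superficially different strings can still score an exact match (illustrative):

    # Case, punctuation, and articles are all normalized away:
    print(normalize_answer("The Eiffel Tower!") == normalize_answer("eiffel tower"))  # True, so EM = 1
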
class RecallEvaluate(MapOperator):
    """Recall evaluator.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def _get_tokens(self, text: str):
        return text.lower().split()

    def _recall(self, pred: str, ref: str):
        r = Counter(self._get_tokens(ref))
        p = Counter(self._get_tokens(pred))
        if not r:
            return 0.0
        common = r & p
        return float(sum(common.values()) / sum(r.values()))

    def execute(self, data: dict):
        golds = data.get("references", [])
        pred = data.get("generated", "")
        best = max(self._recall(pred, g) for g in golds) if golds else 0.0
        print(f"\033[93m[Recall] : {best:.4f}\033[0m")
        return data

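Counter intersection keeps multiplicity, so repeated reference tokens must each be covered by the prediction; a small illustration (made-up strings, Counter imported at module top):

    r = Counter("big big fish".split())  # {"big": 2, "fish": 1}
    p = Counter("big fish".split())      # {"big": 1, "fish": 1}
    print(sum((r & p).values()) / sum(r.values()))  # 2/3 ≈ 0.6667
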
class BertRecallEvaluate(MapOperator):
    """BERT recall evaluator.

    Scores each reference by the cosine similarity between mean-pooled
    bert-base-uncased embeddings of the prediction and the reference.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.model = AutoModel.from_pretrained("bert-base-uncased")

    def execute(self, data: dict):
        golds = data.get("references", [])
        pred = data.get("generated", "")
        scores = []
        for g in golds:
            encs = self.tokenizer([pred, g], return_tensors="pt", padding=True)
            embs = self.model(**encs).last_hidden_state.mean(dim=1).detach().numpy()
            # Slice to keep 2-D shapes for cosine_similarity
            emb_pred = embs[0:1]  # shape: (1, embedding_dim)
            emb_gold = embs[1:2]  # shape: (1, embedding_dim)
            similarity = cosine_similarity(emb_pred, emb_gold)
            scores.append(float(similarity[0][0]))
        best = max(scores) if scores else 0.0
        print(f"\033[93m[BertRecall] : {best:.4f}\033[0m")
        return data

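What cosine_similarity contributes, shown with stand-in vectors instead of real BERT outputs (the numbers below are made up for illustration):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    emb_pred = np.array([[0.2, 0.9, 0.1]])    # pretend mean-pooled embedding, shape (1, dim)
    emb_gold = np.array([[0.25, 0.85, 0.05]])
    print(cosine_similarity(emb_pred, emb_gold)[0][0])  # close to 1.0 for near-parallel vectors
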
class RougeLEvaluate(MapOperator):
    """ROUGE-L evaluator.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.rouge = Rouge()

    def execute(self, data: dict):
        golds = data.get("references", [])
        pred = data.get("generated", "")
        scores = []
        for g in golds:
            # rouge.get_scores returns a list with one dict per hypothesis
            rouge_result = self.rouge.get_scores(pred, g)
            if rouge_result and isinstance(rouge_result, list):
                scores.append(rouge_result[0]["rouge-l"]["f"])
        best = max(scores) if scores else 0.0
        print(f"\033[93m[ROUGE-L] : {best:.4f}\033[0m")
        return data

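The shape of what rouge.get_scores returns, for readers unfamiliar with the rouge package (example strings are made up):

    from rouge import Rouge

    result = Rouge().get_scores("the cat sat on the mat", "the cat lay on the mat")
    # One dict per hypothesis, keyed by "rouge-1", "rouge-2", "rouge-l",
    # each holding recall "r", precision "p", and F-measure "f"
    print(result[0]["rouge-l"]["f"])
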
class BRSEvaluate(MapOperator):
    """BRS evaluator.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def execute(self, data: dict):
        golds = data.get("references", [])
        pred = data.get("generated", "")
        # set(str) yields character sets, so this is character-level set overlap
        scores = [(len(set(pred) & set(g)) / len(set(g))) if g else 0.0 for g in golds]
        best = max(scores) if scores else 0.0
        print(f"\033[93m[BRS] : {best:.4f}\033[0m")
        return data

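Since the score is built from set(pred) and set(g), it compares unique characters rather than tokens; a tiny illustration with made-up strings:

    pred, gold = "paris", "france"
    # set("paris") & set("france") == {"a", "r"}; len(set("france")) == 6
    print(len(set(pred) & set(gold)) / len(set(gold)))  # 2/6 ≈ 0.3333
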
class AccuracyEvaluate(MapOperator):
    """Accuracy evaluator.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def _normalize_text(self, text: str) -> str:
        """Normalize text for comparison."""
        return text.lower().strip()

    def execute(self, data: dict):
        golds = data.get("references", [])
        pred = data.get("generated", "")

        if not golds or not pred:
            print("\033[93m[Acc] : 0.0000\033[0m")
            return data

        pred_norm = self._normalize_text(pred)

        # Accuracy: check whether the prediction matches any reference
        # (exact match or keyword overlap)
        correct = False
        for gold in golds:
            gold_norm = self._normalize_text(gold)
            # Check for keyword overlap
            gold_words = set(gold_norm.split())
            pred_words = set(pred_norm.split())
            # Count the prediction as correct if it covers enough of the reference's key words
            if gold_words and len(gold_words & pred_words) / len(gold_words) >= 0.3:
                correct = True
                break

        print(f"\033[93m[Acc] : {float(correct):.4f}\033[0m")
        return data

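How the 0.3 overlap threshold plays out (illustrative, made-up strings):

    gold_words = set("barack obama was president".split())
    pred_words = set("it was obama".split())
    # overlap {"was", "obama"} covers 2/4 = 0.5 of the reference words,
    # which clears the 0.3 threshold, so this prediction counts as correct
    print(len(gold_words & pred_words) / len(gold_words))  # 0.5
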
class TokenCountEvaluate(MapOperator):
    """Token count evaluator.

    Counts the tokens of the final prompt fed to the generator, using a real
    tokenizer when available.
    Priority: compressed_context (post-compression) > refining_results > retrieval_results (raw).

    Input format: {"query": str, "compressed_context": str, "refining_results": List[str], ...} or
                  {"query": str, "retrieval_results": List[Dict], ...}
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.aggregator = MetricsAggregator()
        # Use the same tokenizer as REFORM for consistency
        try:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        except Exception:
            self.tokenizer = None

    def execute(self, data):
        # Handle StopSignal
        if isinstance(data, StopSignal):
            return data

        # Prefer compressed_context (the text that is ultimately fed to the generator)
        context = data.get("compressed_context")
        if context:
            # Count tokens with the real tokenizer when available
            if self.tokenizer:
                total_tokens = len(self.tokenizer.encode(context))
            else:
                total_tokens = len(context.split())
        else:
            # Fall back to the older counting scheme
            docs = data.get("refining_results") or data.get("retrieval_results", [])
            total_tokens = 0
            if docs:
                for doc in docs:
                    if isinstance(doc, dict):
                        text = doc.get("text", str(doc))
                    elif isinstance(doc, str):
                        text = doc
                    else:
                        text = str(doc)

                    if self.tokenizer:
                        total_tokens += len(self.tokenizer.encode(text))
                    else:
                        total_tokens += len(text.split())

        # Add to aggregator
        self.aggregator.add_token_count(total_tokens)

        # Add to ResultsCollector (if available)
        collector = _get_results_collector()
        if collector is not None:
            sample_id = data.get("sample_id", data.get("_sample_idx"))
            collector.update_sample(sample_id, token_count=total_tokens)

        print(f"\033[93m[Token Count] : {total_tokens}\033[0m")
        return data

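Note that meta-llama/Llama-3.1-8B-Instruct is typically gated on Hugging Face, so the except branch, and hence the whitespace fallback, is a realistic path in practice. The fallback is simply:

    context = "Retrieved passage one. Retrieved passage two."
    print(len(context.split()))  # 6 whitespace-delimited "tokens", a rough approximation
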
class LatencyEvaluate(MapOperator):
    """Latency evaluator.

    Input format:
        {"query": str, "retrieve_time": float, "refine_time": float,
         "generate_time": float, ...}
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.aggregator = MetricsAggregator()

    def execute(self, data):
        # Handle StopSignal: stay quiet and let CompressionRateEvaluate print the final summary
        if isinstance(data, StopSignal):
            return data

        retrieve_time = data.get("retrieve_time", 0.0)
        refine_time = data.get("refine_time", 0.0)
        generate_time = data.get("generate_time", 0.0)
        total_lat = retrieve_time + refine_time + generate_time

        # Add to aggregator
        self.aggregator.add_latency(retrieve_time, refine_time, generate_time)

        # Add to ResultsCollector (if available)
        collector = _get_results_collector()
        if collector is not None:
            sample_id = data.get("sample_id", data.get("_sample_idx"))
            collector.update_sample(
                sample_id,
                retrieve_time=retrieve_time,
                refine_time=refine_time,
                generate_time=generate_time,
                total_time=total_lat,
            )

        print(f"\033[93m[Retrieve Time] : {retrieve_time:.2f}s\033[0m")
        print(f"\033[93m[Refine Time] : {refine_time:.2f}s\033[0m")
        print(f"\033[93m[Generate Time] : {generate_time:.2f}s\033[0m")
        print(f"\033[93m[Total Latency] : {total_lat:.2f}s\033[0m")
        return data

class ContextRecallEvaluate(MapOperator):
    """Context recall evaluator.

    Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
    """

    def _normalize_text(self, text: str) -> str:
        """Normalize text for comparison."""
        return text.lower().strip()

    def execute(self, data: dict):
        golds = data.get("references", [])
        pred = data.get("generated", "")

        if not golds or not pred:
            print("\033[93m[Context Recall] : 0.0000\033[0m")
            return data

        pred_norm = self._normalize_text(pred)
        pred_words = set(pred_norm.split())

        # Measure how many of each reference's keywords appear in the generated answer
        total_recall = 0.0
        for gold in golds:
            gold_norm = self._normalize_text(gold)
            gold_words = set(gold_norm.split())
            if gold_words:
                # Recall against the current reference
                matched_words = len(gold_words & pred_words)
                recall = matched_words / len(gold_words)
                total_recall = max(total_recall, recall)  # keep the best reference

        print(f"\033[93m[Context Recall] : {total_recall:.4f}\033[0m")
        return data

class CompressionRateEvaluate(MapOperator):
    """Compute the document compression rate.

    Compression rate = (tokens in original documents) / (tokens in compressed documents)

    Input format:
        {"query": str, "retrieval_results": List[Dict],
         "refining_results": List[str], ...}

    Args:
        retrieval_results: Originally retrieved documents (for the original token count)
        refining_results: Compressed document texts (for the compressed token count)
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(**kwargs)
        self.aggregator = MetricsAggregator()

    def _count_tokens(self, docs):
        """Count the total tokens across a list of documents."""
        if not docs:
            return 0
        # Handle documents in different formats
        total = 0
        for doc in docs:
            if isinstance(doc, dict):
                # Dict format: pull out the text field
                text = doc.get("text", doc.get("content", str(doc)))
                total += len(text.split())
            elif isinstance(doc, str):
                # Plain string
                total += len(doc.split())
            else:
                total += len(str(doc).split())
        return total

    def execute(self, data):
        # Handle StopSignal: this operator prints the full summary at the end of the run
        if isinstance(data, StopSignal):
            print("\n")  # blank line as a separator
            self.aggregator.print_summary()
            return data

        # Token count of the originally retrieved documents
        retrieved_docs = data.get("retrieval_results", [])
        retrieved_tokens = self._count_tokens(retrieved_docs)

        # Token count of the compressed documents
        refined_docs = data.get("refining_results", [])
        refined_tokens = self._count_tokens(refined_docs)

        # Compute the compression rate
        if refined_tokens > 0 and retrieved_tokens > 0:
            compression_rate = retrieved_tokens / refined_tokens
        else:
            compression_rate = 0.0

        # Add to aggregator
        self.aggregator.add_compression_rate(compression_rate)

        # Add to ResultsCollector (if available)
        collector = _get_results_collector()
        if collector is not None:
            sample_id = data.get("sample_id", data.get("_sample_idx"))
            collector.update_sample(
                sample_id,
                compression_rate=compression_rate,
                original_tokens=retrieved_tokens,
                compressed_tokens=refined_tokens,
            )

        print(f"\033[93m[Compression Rate] : {compression_rate:.2f}×\033[0m")
        return data
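A worked compression-rate figure, plus a sketch of how these evaluators might be chained on a single sample dict. The sample values and the no-argument constructors are assumptions for illustration; the actual wiring depends on the sage pipeline runtime and the MapFunction base class:

    # Compression rate: 1200 retrieved tokens squeezed into 300 -> 4.00x
    print(f"{1200 / 300:.2f}x")

    # Hypothetical manual wiring; a real deployment would run these inside a sage pipeline
    sample = {
        "query": "What is the capital of France?",
        "generated": "The answer is Paris",
        "references": ["Paris"],
        "retrieval_results": [{"text": "Paris is the capital and largest city of France."}],
        "refining_results": ["Paris is the capital of France."],
        "retrieve_time": 0.4,
        "refine_time": 0.1,
        "generate_time": 1.2,
    }
    for op in (F1Evaluate(), EMEvaluate(), LatencyEvaluate(), CompressionRateEvaluate()):
        sample = op.execute(sample)
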