isage_middleware-0.2.4.3-cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. isage_middleware-0.2.4.3.dist-info/METADATA +266 -0
  2. isage_middleware-0.2.4.3.dist-info/RECORD +94 -0
  3. isage_middleware-0.2.4.3.dist-info/WHEEL +5 -0
  4. isage_middleware-0.2.4.3.dist-info/top_level.txt +1 -0
  5. sage/middleware/__init__.py +59 -0
  6. sage/middleware/_version.py +6 -0
  7. sage/middleware/components/__init__.py +30 -0
  8. sage/middleware/components/extensions_compat.py +141 -0
  9. sage/middleware/components/sage_db/__init__.py +116 -0
  10. sage/middleware/components/sage_db/backend.py +136 -0
  11. sage/middleware/components/sage_db/service.py +15 -0
  12. sage/middleware/components/sage_flow/__init__.py +76 -0
  13. sage/middleware/components/sage_flow/python/__init__.py +14 -0
  14. sage/middleware/components/sage_flow/python/micro_service/__init__.py +4 -0
  15. sage/middleware/components/sage_flow/python/micro_service/sage_flow_service.py +88 -0
  16. sage/middleware/components/sage_flow/python/sage_flow.py +30 -0
  17. sage/middleware/components/sage_flow/service.py +14 -0
  18. sage/middleware/components/sage_mem/__init__.py +83 -0
  19. sage/middleware/components/sage_sias/__init__.py +59 -0
  20. sage/middleware/components/sage_sias/continual_learner.py +184 -0
  21. sage/middleware/components/sage_sias/coreset_selector.py +302 -0
  22. sage/middleware/components/sage_sias/types.py +94 -0
  23. sage/middleware/components/sage_tsdb/__init__.py +81 -0
  24. sage/middleware/components/sage_tsdb/python/__init__.py +21 -0
  25. sage/middleware/components/sage_tsdb/python/_sage_tsdb.pyi +17 -0
  26. sage/middleware/components/sage_tsdb/python/algorithms/__init__.py +17 -0
  27. sage/middleware/components/sage_tsdb/python/algorithms/base.py +51 -0
  28. sage/middleware/components/sage_tsdb/python/algorithms/out_of_order_join.py +248 -0
  29. sage/middleware/components/sage_tsdb/python/algorithms/window_aggregator.py +296 -0
  30. sage/middleware/components/sage_tsdb/python/micro_service/__init__.py +7 -0
  31. sage/middleware/components/sage_tsdb/python/micro_service/sage_tsdb_service.py +365 -0
  32. sage/middleware/components/sage_tsdb/python/sage_tsdb.py +523 -0
  33. sage/middleware/components/sage_tsdb/service.py +17 -0
  34. sage/middleware/components/vector_stores/__init__.py +25 -0
  35. sage/middleware/components/vector_stores/chroma.py +483 -0
  36. sage/middleware/components/vector_stores/chroma_adapter.py +185 -0
  37. sage/middleware/components/vector_stores/milvus.py +677 -0
  38. sage/middleware/operators/__init__.py +56 -0
  39. sage/middleware/operators/agent/__init__.py +24 -0
  40. sage/middleware/operators/agent/planning/__init__.py +5 -0
  41. sage/middleware/operators/agent/planning/llm_adapter.py +41 -0
  42. sage/middleware/operators/agent/planning/planner_adapter.py +98 -0
  43. sage/middleware/operators/agent/planning/router.py +107 -0
  44. sage/middleware/operators/agent/runtime.py +296 -0
  45. sage/middleware/operators/agentic/__init__.py +41 -0
  46. sage/middleware/operators/agentic/config.py +254 -0
  47. sage/middleware/operators/agentic/planning_operator.py +125 -0
  48. sage/middleware/operators/agentic/refined_searcher.py +132 -0
  49. sage/middleware/operators/agentic/runtime.py +241 -0
  50. sage/middleware/operators/agentic/timing_operator.py +125 -0
  51. sage/middleware/operators/agentic/tool_selection_operator.py +127 -0
  52. sage/middleware/operators/context/__init__.py +17 -0
  53. sage/middleware/operators/context/critic_evaluation.py +16 -0
  54. sage/middleware/operators/context/model_context.py +565 -0
  55. sage/middleware/operators/context/quality_label.py +12 -0
  56. sage/middleware/operators/context/search_query_results.py +61 -0
  57. sage/middleware/operators/context/search_result.py +42 -0
  58. sage/middleware/operators/context/search_session.py +79 -0
  59. sage/middleware/operators/filters/__init__.py +26 -0
  60. sage/middleware/operators/filters/context_sink.py +387 -0
  61. sage/middleware/operators/filters/context_source.py +376 -0
  62. sage/middleware/operators/filters/evaluate_filter.py +83 -0
  63. sage/middleware/operators/filters/tool_filter.py +74 -0
  64. sage/middleware/operators/llm/__init__.py +18 -0
  65. sage/middleware/operators/llm/sagellm_generator.py +432 -0
  66. sage/middleware/operators/rag/__init__.py +147 -0
  67. sage/middleware/operators/rag/arxiv.py +331 -0
  68. sage/middleware/operators/rag/chunk.py +13 -0
  69. sage/middleware/operators/rag/document_loaders.py +23 -0
  70. sage/middleware/operators/rag/evaluate.py +658 -0
  71. sage/middleware/operators/rag/generator.py +340 -0
  72. sage/middleware/operators/rag/index_builder/__init__.py +48 -0
  73. sage/middleware/operators/rag/index_builder/builder.py +363 -0
  74. sage/middleware/operators/rag/index_builder/manifest.py +101 -0
  75. sage/middleware/operators/rag/index_builder/storage.py +131 -0
  76. sage/middleware/operators/rag/pipeline.py +46 -0
  77. sage/middleware/operators/rag/profiler.py +59 -0
  78. sage/middleware/operators/rag/promptor.py +400 -0
  79. sage/middleware/operators/rag/refiner.py +231 -0
  80. sage/middleware/operators/rag/reranker.py +364 -0
  81. sage/middleware/operators/rag/retriever.py +1308 -0
  82. sage/middleware/operators/rag/searcher.py +37 -0
  83. sage/middleware/operators/rag/types.py +28 -0
  84. sage/middleware/operators/rag/writer.py +80 -0
  85. sage/middleware/operators/tools/__init__.py +71 -0
  86. sage/middleware/operators/tools/arxiv_paper_searcher.py +175 -0
  87. sage/middleware/operators/tools/arxiv_searcher.py +102 -0
  88. sage/middleware/operators/tools/duckduckgo_searcher.py +105 -0
  89. sage/middleware/operators/tools/image_captioner.py +104 -0
  90. sage/middleware/operators/tools/nature_news_fetcher.py +224 -0
  91. sage/middleware/operators/tools/searcher_tool.py +514 -0
  92. sage/middleware/operators/tools/text_detector.py +185 -0
  93. sage/middleware/operators/tools/url_text_extractor.py +104 -0
  94. sage/middleware/py.typed +2 -0
sage/middleware/operators/rag/evaluate.py
@@ -0,0 +1,658 @@
+ import re
+ import string
+ from collections import Counter
+
+ from rouge import Rouge
+ from sklearn.metrics.pairwise import cosine_similarity
+ from transformers import AutoModel, AutoTokenizer
+
+ from sage.common.core.functions import MapFunction as MapOperator
+ from sage.kernel.runtime.communication.packet import StopSignal
+
+ # =============================================================================
+ # RECOMP-style Answer Normalization
+ # =============================================================================
+
+
+ def normalize_answer(s: str) -> str:
+     """RECOMP-style answer normalization.
+
+     Steps:
+     1. Lowercase
+     2. Remove punctuation
+     3. Remove articles (a, an, the)
+     4. Fix whitespace
+
+     Args:
+         s: Raw answer text
+
+     Returns:
+         The normalized answer text
+     """
+
+     def remove_articles(text: str) -> str:
+         return re.sub(r"\b(a|an|the)\b", " ", text)
+
+     def white_space_fix(text: str) -> str:
+         return " ".join(text.split())
+
+     def remove_punc(text: str) -> str:
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text: str) -> str:
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def get_normalized_tokens(s: str) -> list[str]:
+     """Return the list of normalized tokens.
+
+     Args:
+         s: Raw text
+
+     Returns:
+         The normalized token list
+     """
+     if not s:
+         return []
+     return normalize_answer(s).split()
+
+
+ def answer_extract(pred: str) -> str:
+     """Extract the answer text.
+
+     Supports answers given with an "answer is" prefix.
+
+     Args:
+         pred: Prediction text
+
+     Returns:
+         The extracted answer text
+     """
+     prefix = "answer is "
+     if prefix in pred.lower():
+         idx = pred.lower().rfind(prefix)
+         return pred[idx + len(prefix) :].strip()
+     return pred.strip()
+
+
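As a quick orientation (this example is not part of the package), the three helpers above behave as follows:

    normalize_answer("The answer is: Paris!")       # -> "answer is paris"
    get_normalized_tokens("An Apple, a day.")       # -> ["apple", "day"]
    answer_extract("I think the answer is Paris.")  # -> "Paris."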
+ def _get_results_collector():
+     """
+     Lazily import ResultsCollector to avoid a circular dependency.
+
+     Returns:
+         A ResultsCollector instance, or None if it is unavailable
+     """
+     try:
+         from sage.common.utils.results_collector import ResultsCollector
+
+         return ResultsCollector()
+     except ImportError:
+         return None
+
+
+ class MetricsAggregator:
+     """Global metrics aggregator that collects metrics and computes averages"""
+
+     _instance = None
+
+     def __new__(cls):
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+             cls._instance.reset()
+         return cls._instance
+
+     def reset(self):
+         """Reset all accumulated statistics"""
+         self.metrics = {
+             "f1_scores": [],
+             "em_scores": [],  # Exact Match scores
+             "token_counts": [],
+             "retrieve_times": [],
+             "refine_times": [],
+             "generate_times": [],
+             "total_latencies": [],
+             "compression_rates": [],
+         }
+         self.sample_count = 0
+
+     def add_f1(self, score):
+         self.metrics["f1_scores"].append(score)
+
+     def add_em(self, score):
+         """Record an Exact Match score"""
+         self.metrics["em_scores"].append(score)
+
+     def add_token_count(self, count):
+         self.metrics["token_counts"].append(count)
+
+     def add_latency(self, retrieve, refine, generate):
+         self.metrics["retrieve_times"].append(retrieve)
+         self.metrics["refine_times"].append(refine)
+         self.metrics["generate_times"].append(generate)
+         self.metrics["total_latencies"].append(retrieve + refine + generate)
+         self.sample_count += 1
+
+     def add_compression_rate(self, rate):
+         self.metrics["compression_rates"].append(rate)
+
+     def print_summary(self):
+         """Print the summary statistics"""
+         if self.sample_count == 0:
+             print("\n" + "=" * 80)
+             print("No samples processed")
+             print("=" * 80)
+             return
+
+         print("\n" + "=" * 80)
+         print(f"SUMMARY STATISTICS ({self.sample_count} samples)")
+         print("=" * 80)
+
+         # Exact Match Score
+         if self.metrics["em_scores"]:
+             avg_em = sum(self.metrics["em_scores"]) / len(self.metrics["em_scores"])
+             print(f"\033[92m[Average EM Score] : {avg_em:.4f}\033[0m")
+
+         # F1 Score
+         if self.metrics["f1_scores"]:
+             avg_f1 = sum(self.metrics["f1_scores"]) / len(self.metrics["f1_scores"])
+             print(f"\033[92m[Average F1 Score] : {avg_f1:.4f}\033[0m")
+
+         # Token Count
+         if self.metrics["token_counts"]:
+             avg_tokens = sum(self.metrics["token_counts"]) / len(self.metrics["token_counts"])
+             print(f"\033[92m[Average Token Count] : {avg_tokens:.0f}\033[0m")
+
+         # Latency
+         if self.metrics["retrieve_times"]:
+             avg_retrieve = sum(self.metrics["retrieve_times"]) / len(self.metrics["retrieve_times"])
+             avg_refine = sum(self.metrics["refine_times"]) / len(self.metrics["refine_times"])
+             avg_generate = sum(self.metrics["generate_times"]) / len(self.metrics["generate_times"])
+             avg_total = sum(self.metrics["total_latencies"]) / len(self.metrics["total_latencies"])
+
+             print(f"\033[92m[Average Retrieve Time] : {avg_retrieve:.2f}s\033[0m")
+             print(f"\033[92m[Average Refine Time] : {avg_refine:.2f}s\033[0m")
+             print(f"\033[92m[Average Generate Time] : {avg_generate:.2f}s\033[0m")
+             avg_min = avg_total / 60
+             print(f"\033[92m[Average Total Latency] : {avg_total:.2f}s ({avg_min:.2f}m)\033[0m")
+
+         # Compression Rate
+         if self.metrics["compression_rates"]:
+             valid_rates = [r for r in self.metrics["compression_rates"] if r > 0]
+             if valid_rates:
+                 avg_compression = sum(valid_rates) / len(valid_rates)
+                 print(f"\033[92m[Average Compression Rate]: {avg_compression:.2f}×\033[0m")
+
+         print("=" * 80 + "\n")
+
+
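MetricsAggregator makes itself a process-wide singleton through __new__: the first call creates and resets the instance, and every later MetricsAggregator() returns that same object, so all the evaluators below feed one shared set of accumulators. A minimal sketch of that behavior (illustrative only):

    agg_a = MetricsAggregator()
    agg_b = MetricsAggregator()
    assert agg_a is agg_b                    # one shared instance
    agg_a.add_f1(0.5)
    len(agg_b.metrics["f1_scores"])          # 1 -- state is shared

Note that reset() runs only on first construction, so constructing the aggregator again mid-run does not clear previously collected scores.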
+ class F1Evaluate(MapOperator):
+     """F1 score evaluator (RECOMP standard).
+
+     Computes the F1 score using RECOMP-style answer normalization.
+     Normalization steps: lowercase, remove punctuation, remove articles, fix whitespace.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.aggregator = MetricsAggregator()
+         # Whether to extract the answer following an "answer is" prefix
+         self.extract_answer = config.get("extract_answer", False) if config else False
+
+     def _f1_score(self, pred: str, ref: str) -> float:
+         """Compute the F1 score (RECOMP standard).
+
+         Computed over normalized tokens.
+
+         Args:
+             pred: Predicted answer
+             ref: Reference answer
+
+         Returns:
+             The F1 score
+         """
+         gold_toks = get_normalized_tokens(ref)
+         pred_toks = get_normalized_tokens(pred)
+
+         common = Counter(gold_toks) & Counter(pred_toks)
+         num_same = sum(common.values())
+
+         if len(gold_toks) == 0 or len(pred_toks) == 0:
+             # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
+             return float(gold_toks == pred_toks)
+
+         if num_same == 0:
+             return 0.0
+
+         precision = 1.0 * num_same / len(pred_toks)
+         recall = 1.0 * num_same / len(gold_toks)
+         f1 = (2 * precision * recall) / (precision + recall)
+         return f1
+
+     def execute(self, data):
+         # Handle StopSignal - print nothing here; CompressionRateEvaluate emits the final summary
+         if isinstance(data, StopSignal):
+             return data
+
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+
+         # Optional: extract the answer after "answer is"
+         if self.extract_answer:
+             pred = answer_extract(pred)
+
+         best = max((self._f1_score(pred, g) for g in golds), default=0.0) if golds else 0.0
+
+         # Add to aggregator
+         self.aggregator.add_f1(best)
+
+         # Add to ResultsCollector (if available)
+         collector = _get_results_collector()
+         if collector is not None:
+             sample_id = data.get("sample_id", data.get("_sample_idx"))
+             collector.update_sample(sample_id, f1=best)
+
+         print(f"\033[93m[F1] : {best:.4f}\033[0m")
+         return data
+
+
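To make the scoring concrete, a hand-worked pass through _f1_score (values are illustrative, not from the package):

    # pred = "Paris, France"     -> normalized tokens ["paris", "france"]
    # ref  = "the capital Paris" -> normalized tokens ["capital", "paris"]
    # common = {"paris": 1}, num_same = 1
    # precision = 1/2, recall = 1/2
    # F1 = (2 * 0.5 * 0.5) / (0.5 + 0.5) = 0.5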
+ class EMEvaluate(MapOperator):
+     """Exact Match evaluator (RECOMP standard).
+
+     Performs exact matching using RECOMP-style answer normalization.
+     Normalization steps: lowercase, remove punctuation, remove articles, fix whitespace.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.aggregator = MetricsAggregator()
+         # Whether to extract the answer following an "answer is" prefix
+         self.extract_answer = config.get("extract_answer", False) if config else False
+
+     def _exact_match(self, pred: str, gold: str) -> int:
+         """Compute Exact Match (RECOMP standard).
+
+         Exact comparison of the normalized texts.
+
+         Args:
+             pred: Predicted answer
+             gold: Reference answer
+
+         Returns:
+             1 on a match, otherwise 0
+         """
+         return int(normalize_answer(pred) == normalize_answer(gold))
+
+     def execute(self, data):
+         # Handle StopSignal - print nothing here; CompressionRateEvaluate emits the final summary
+         if isinstance(data, StopSignal):
+             return data
+
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+
+         # Optional: extract the answer after "answer is"
+         if self.extract_answer:
+             pred = answer_extract(pred)
+
+         best = max((self._exact_match(pred, g) for g in golds), default=0) if golds else 0
+
+         # Add to aggregator
+         self.aggregator.add_em(best)
+
+         print(f"\033[93m[EM] : {best}\033[0m")
+         return data
+
+
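A minimal usage sketch for the RECOMP-style evaluators (assuming, as the code suggests, that MapFunction needs no required constructor arguments):

    em = EMEvaluate(config={"extract_answer": True})
    sample = {"generated": "So the answer is Paris.", "references": ["Paris"]}
    em.execute(sample)
    # answer_extract trims the prediction to "Paris."; both sides normalize
    # to "paris", so this prints [EM] : 1

Unlike F1Evaluate, EMEvaluate records only to the aggregator and does not report per-sample scores to the ResultsCollector.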
+ class RecallEvaluate(MapOperator):
+     """Recall evaluator.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def _get_tokens(self, text: str):
+         return text.lower().split()
+
+     def _recall(self, pred: str, ref: str):
+         r = Counter(self._get_tokens(ref))
+         p = Counter(self._get_tokens(pred))
+         if not r:
+             return 0.0
+         common = r & p
+         return float(sum(common.values()) / sum(r.values()))
+
+     def execute(self, data: dict):
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+         best = max(self._recall(pred, g) for g in golds) if golds else 0.0
+         print(f"\033[93m[Recall] : {best:.4f}\033[0m")
+         return data
+
+
+ class BertRecallEvaluate(MapOperator):
+     """BERT recall evaluator.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+         self.model = AutoModel.from_pretrained("bert-base-uncased")
+
+     def execute(self, data: dict):
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+         scores = []
+         for g in golds:
+             encs = self.tokenizer([pred, g], return_tensors="pt", padding=True)
+             embs = self.model(**encs).last_hidden_state.mean(dim=1).detach().numpy()
+             # Convert to numpy arrays explicitly for cosine_similarity
+             emb_pred = embs[0:1]  # Shape: (1, embedding_dim)
+             emb_gold = embs[1:2]  # Shape: (1, embedding_dim)
+             similarity = cosine_similarity(emb_pred, emb_gold)
+             scores.append(float(similarity[0][0]))
+         best = max(scores) if scores else 0.0
+         print(f"\033[93m[BertRecall] : {best:.4f}\033[0m")
+         return data
+
+
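One detail worth noting in the BERT scorer: last_hidden_state.mean(dim=1) averages over every position, including the [PAD] tokens that padding=True introduces when the two texts differ in length. A masked mean (a common alternative, not what this package does) would exclude padding inside execute:

    out = self.model(**encs)
    mask = encs["attention_mask"].unsqueeze(-1)     # (batch, seq, 1)
    pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)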
+ class RougeLEvaluate(MapOperator):
+     """ROUGE-L evaluator.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.rouge = Rouge()
+
+     def execute(self, data: dict):
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+         scores = []
+         for g in golds:
+             # rouge.get_scores returns a list with one dict
+             rouge_result = self.rouge.get_scores(pred, g)
+             if rouge_result and isinstance(rouge_result, list):
+                 scores.append(rouge_result[0]["rouge-l"]["f"])
+         best = max(scores) if scores else 0.0
+         print(f"\033[93m[ROUGE-L] : {best:.4f}\033[0m")
+         return data
+
+
+ class BRSEvaluate(MapOperator):
+     """BRS evaluator.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def execute(self, data: dict):
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+         # set() over a string yields characters, so this measures how much of
+         # each reference's character set the prediction covers
+         scores = [(len(set(pred) & set(g)) / len(set(g))) if g else 0.0 for g in golds]
+         best = max(scores) if scores else 0.0
+         print(f"\033[93m[BRS] : {best:.4f}\033[0m")
+         return data
+
+
+ class AccuracyEvaluate(MapOperator):
+     """Accuracy evaluator.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def _normalize_text(self, text: str) -> str:
+         """Normalize text for comparison"""
+         return text.lower().strip()
+
+     def execute(self, data: dict):
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+
+         if not golds or not pred:
+             print("\033[93m[Acc] : 0.0000\033[0m")
+             return data
+
+         pred_norm = self._normalize_text(pred)
+
+         # Accuracy: check whether the prediction matches any reference
+         # (exact match or keyword match)
+         correct = False
+         for gold in golds:
+             gold_norm = self._normalize_text(gold)
+             # Check for keyword overlap
+             gold_words = set(gold_norm.split())
+             pred_words = set(pred_norm.split())
+             # Count the prediction as correct if it covers enough of the
+             # reference's key words
+             if gold_words and len(gold_words & pred_words) / len(gold_words) >= 0.3:
+                 correct = True
+                 break
+
+         print(f"\033[93m[Acc] : {float(correct):.4f}\033[0m")
+         return data
+
+
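The 0.3 threshold makes AccuracyEvaluate a keyword-coverage check rather than exact matching: a prediction counts as correct once it contains at least 30% of some reference's unique words. A hand-worked illustration:

    # reference : "the eiffel tower in paris"  -> 5 unique words
    # prediction: "it is in paris"             -> overlap {"in", "paris"} = 2
    # coverage = 2 / 5 = 0.4 >= 0.3            -> prints [Acc] : 1.0000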
+ class TokenCountEvaluate(MapOperator):
+     """Token count evaluator.
+
+     Counts the tokens of the final prompt fed to the generator (using a real tokenizer).
+     Priority: compressed_context (compressed) > refining_results > retrieval_results (raw).
+
+     Input format: {"query": str, "compressed_context": str, "refining_results": List[str], ...} or
+                   {"query": str, "retrieval_results": List[Dict], ...}
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.aggregator = MetricsAggregator()
+         # Use the same tokenizer as REFORM for consistency
+         try:
+             from transformers import AutoTokenizer
+
+             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+         except Exception:
+             self.tokenizer = None
+
+     def execute(self, data):
+         # Handle StopSignal
+         if isinstance(data, StopSignal):
+             return data
+
+         # Prefer compressed_context (the text ultimately fed to the generator)
+         context = data.get("compressed_context")
+         if context:
+             # Count tokens with the real tokenizer
+             if self.tokenizer:
+                 total_tokens = len(self.tokenizer.encode(context))
+             else:
+                 total_tokens = len(context.split())
+         else:
+             # Fall back to the legacy counting method
+             docs = data.get("refining_results") or data.get("retrieval_results", [])
+             total_tokens = 0
+             if docs:
+                 for doc in docs:
+                     if isinstance(doc, dict):
+                         text = doc.get("text", str(doc))
+                     elif isinstance(doc, str):
+                         text = doc
+                     else:
+                         text = str(doc)
+
+                     if self.tokenizer:
+                         total_tokens += len(self.tokenizer.encode(text))
+                     else:
+                         total_tokens += len(text.split())
+
+         # Add to aggregator
+         self.aggregator.add_token_count(total_tokens)
+
+         # Add to ResultsCollector (if available)
+         collector = _get_results_collector()
+         if collector is not None:
+             sample_id = data.get("sample_id", data.get("_sample_idx"))
+             collector.update_sample(sample_id, token_count=total_tokens)
+
+         print(f"\033[93m[Token Count] : {total_tokens}\033[0m")
+         return data
+
+
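Note the broad except around tokenizer loading: meta-llama/Llama-3.1-8B-Instruct is a gated Hugging Face checkpoint, so environments without access to it silently fall back to whitespace counting, and token counts are only comparable between runs that resolved the same tokenizer. The gap is substantial (illustrative, not measured):

    text = "Retrieval-augmented generation (RAG) pipelines."
    len(text.split())   # 5 whitespace "tokens"
    # a subword tokenizer emits noticeably more ids for the same string,
    # so per-run averages shift whenever the fallback is taken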
+ class LatencyEvaluate(MapOperator):
+     """Latency evaluator.
+
+     Input format:
+         {"query": str, "retrieve_time": float, "refine_time": float,
+          "generate_time": float, ...}
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.aggregator = MetricsAggregator()
+
+     def execute(self, data):
+         # Handle StopSignal - print nothing here; CompressionRateEvaluate emits the final summary
+         if isinstance(data, StopSignal):
+             return data
+
+         retrieve_time = data.get("retrieve_time", 0.0)
+         refine_time = data.get("refine_time", 0.0)
+         generate_time = data.get("generate_time", 0.0)
+         total_lat = retrieve_time + refine_time + generate_time
+
+         # Add to aggregator
+         self.aggregator.add_latency(retrieve_time, refine_time, generate_time)
+
+         # Add to ResultsCollector (if available)
+         collector = _get_results_collector()
+         if collector is not None:
+             sample_id = data.get("sample_id", data.get("_sample_idx"))
+             collector.update_sample(
+                 sample_id,
+                 retrieve_time=retrieve_time,
+                 refine_time=refine_time,
+                 generate_time=generate_time,
+                 total_time=total_lat,
+             )
+
+         print(f"\033[93m[Retrieve Time] : {retrieve_time:.2f}s\033[0m")
+         print(f"\033[93m[Refine Time] : {refine_time:.2f}s\033[0m")
+         print(f"\033[93m[Generate Time] : {generate_time:.2f}s\033[0m")
+         print(f"\033[93m[Total Latency] : {total_lat:.2f}s\033[0m")
+         return data
+
+
+ class ContextRecallEvaluate(MapOperator):
+     """Context recall evaluator.
+
+     Input format: {"query": str, "results": List[Any], "generated": str, "references": List[str]}
+     """
+
+     def _normalize_text(self, text: str) -> str:
+         """Normalize text for comparison"""
+         return text.lower().strip()
+
+     def execute(self, data: dict):
+         golds = data.get("references", [])
+         pred = data.get("generated", "")
+
+         if not golds or not pred:
+             print("\033[93m[Context Recall] : 0.0000\033[0m")
+             return data
+
+         pred_norm = self._normalize_text(pred)
+         pred_words = set(pred_norm.split())
+
+         # Measure how many of each reference's key words appear in the
+         # generated answer
+         total_recall = 0.0
+         for gold in golds:
+             gold_norm = self._normalize_text(gold)
+             gold_words = set(gold_norm.split())
+             if gold_words:
+                 # Recall against the current reference
+                 matched_words = len(gold_words & pred_words)
+                 recall = matched_words / len(gold_words)
+                 total_recall = max(total_recall, recall)  # keep the maximum
+
+         print(f"\033[93m[Context Recall] : {total_recall:.4f}\033[0m")
+         return data
+
+
+ class CompressionRateEvaluate(MapOperator):
+     """Compute the document compression rate.
+
+     compression rate = original document tokens / compressed document tokens
+
+     Input format:
+         {"query": str, "retrieval_results": List[Dict],
+          "refining_results": List[str], ...}
+
+     Args:
+         retrieval_results: Originally retrieved documents (for the original token count)
+         refining_results: Compressed document texts (for the compressed token count)
+     """
+
+     def __init__(self, config=None, **kwargs):
+         super().__init__(**kwargs)
+         self.aggregator = MetricsAggregator()
+
+     def _count_tokens(self, docs):
+         """Total token count across a list of documents"""
+         if not docs:
+             return 0
+         # Handle documents in different formats
+         total = 0
+         for doc in docs:
+             if isinstance(doc, dict):
+                 # Dict format: extract the text field
+                 text = doc.get("text", doc.get("content", str(doc)))
+                 total += len(text.split())
+             elif isinstance(doc, str):
+                 # Plain string
+                 total += len(doc.split())
+             else:
+                 total += len(str(doc).split())
+         return total
+
+     def execute(self, data):
+         # Handle StopSignal - print the full summary statistics at the end
+         if isinstance(data, StopSignal):
+             print("\n")  # blank line separator
+             self.aggregator.print_summary()
+             return data
+
+         # Token count of the originally retrieved documents
+         retrieved_docs = data.get("retrieval_results", [])
+         retrieved_tokens = self._count_tokens(retrieved_docs)
+
+         # Token count of the compressed documents
+         refined_docs = data.get("refining_results", [])
+         refined_tokens = self._count_tokens(refined_docs)
+
+         # Compute the compression rate
+         if refined_tokens > 0 and retrieved_tokens > 0:
+             compression_rate = retrieved_tokens / refined_tokens
+         else:
+             compression_rate = 0.0
+
+         # Add to aggregator
+         self.aggregator.add_compression_rate(compression_rate)
+
+         # Add to ResultsCollector (if available)
+         collector = _get_results_collector()
+         if collector is not None:
+             sample_id = data.get("sample_id", data.get("_sample_idx"))
+             collector.update_sample(
+                 sample_id,
+                 compression_rate=compression_rate,
+                 original_tokens=retrieved_tokens,
+                 compressed_tokens=refined_tokens,
+             )
+
+         print(f"\033[93m[Compression Rate] : {compression_rate:.2f}×\033[0m")
+         return data
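Because CompressionRateEvaluate is the only operator that prints the aggregated summary when a StopSignal arrives, it is evidently intended to sit last in the evaluation chain. The metric itself, worked through with hypothetical counts:

    # retrieval_results -> 800 whitespace tokens across the retrieved docs
    # refining_results  -> 200 whitespace tokens after compression
    # compression_rate = 800 / 200 = 4.0, printed as "4.00×"
    # if either count is 0 the rate is recorded as 0.0, and print_summary
    # excludes non-positive rates from the average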