ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff shows the content of publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- ragbits/evaluate/agent_simulation/__init__.py +87 -0
- ragbits/evaluate/agent_simulation/context.py +118 -0
- ragbits/evaluate/agent_simulation/conversation.py +333 -0
- ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
- ragbits/evaluate/agent_simulation/logger.py +165 -0
- ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
- ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
- ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
- ragbits/evaluate/agent_simulation/models.py +37 -0
- ragbits/evaluate/agent_simulation/results.py +200 -0
- ragbits/evaluate/agent_simulation/scenarios.py +129 -0
- ragbits/evaluate/agent_simulation/simulation.py +243 -0
- ragbits/evaluate/cli.py +150 -0
- ragbits/evaluate/config.py +11 -0
- ragbits/evaluate/dataloaders/__init__.py +3 -0
- ragbits/evaluate/dataloaders/base.py +95 -0
- ragbits/evaluate/dataloaders/document_search.py +61 -0
- ragbits/evaluate/dataloaders/exceptions.py +25 -0
- ragbits/evaluate/dataloaders/gaia.py +78 -0
- ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
- ragbits/evaluate/dataloaders/human_eval.py +70 -0
- ragbits/evaluate/dataloaders/question_answer.py +56 -0
- ragbits/evaluate/dataset_generator/pipeline.py +4 -4
- ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
- ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
- ragbits/evaluate/evaluator.py +178 -50
- ragbits/evaluate/factories/__init__.py +42 -0
- ragbits/evaluate/metrics/__init__.py +2 -23
- ragbits/evaluate/metrics/base.py +40 -17
- ragbits/evaluate/metrics/document_search.py +40 -23
- ragbits/evaluate/metrics/gaia.py +84 -0
- ragbits/evaluate/metrics/hotpot_qa.py +51 -0
- ragbits/evaluate/metrics/human_eval.py +105 -0
- ragbits/evaluate/metrics/question_answer.py +222 -0
- ragbits/evaluate/optimizer.py +138 -86
- ragbits/evaluate/pipelines/__init__.py +37 -0
- ragbits/evaluate/pipelines/base.py +34 -10
- ragbits/evaluate/pipelines/document_search.py +72 -67
- ragbits/evaluate/pipelines/gaia.py +249 -0
- ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
- ragbits/evaluate/pipelines/human_eval.py +323 -0
- ragbits/evaluate/pipelines/question_answer.py +96 -0
- ragbits/evaluate/utils.py +86 -59
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
- ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
- ragbits/evaluate/callbacks/base.py +0 -22
- ragbits/evaluate/callbacks/neptune.py +0 -26
- ragbits/evaluate/loaders/__init__.py +0 -21
- ragbits/evaluate/loaders/base.py +0 -24
- ragbits/evaluate/loaders/hf.py +0 -25
- ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
- /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/pipelines/gaia.py
@@ -0,0 +1,249 @@
+import json
+import logging
+import time
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import cast
+
+from typing_extensions import Self
+
+from ragbits.agents import Agent
+from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, Usage
+from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult
+
+
+class GaiaData(EvaluationData):
+    """
+    Represents a single GAIA task.
+    """
+
+    task_id: str
+    question: str
+    level: int
+    reference_answer: str
+    file_name: str | None = None
+
+
+@dataclass
+class GaiaResult(EvaluationResult):
+    """
+    Represents the result of evaluating a single GAIA task.
+    """
+
+    task_id: str
+    level: int
+    question: str
+    reference_answer: str
+    predicted_result: str
+    task_success: bool
+    tool_triggered: bool
+    num_tool_calls: int
+    tool_error_count: int
+    total_latency_ms: int
+    tool_names: list[str] | None = None
+
+
+class GaiaPipeline(
+    EvaluationPipeline[Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT], GaiaData, GaiaResult]
+):
+    """GAIA evaluation pipeline for question answering models/agents."""
+
+    def __init__(
+        self,
+        evaluation_target: Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT],
+        *,
+        system_prompt: str | None = None,
+        per_example_log_file: Path | None = None,
+        extended_logs: bool = False,
+        parse_answer_fn: Callable[[str], str] | None = None,
+    ) -> None:
+        super().__init__(evaluation_target=evaluation_target)
+        self.system_prompt = system_prompt
+        self.per_example_log_file = per_example_log_file
+        self.extended_logs = extended_logs
+        self.parse_answer_fn = parse_answer_fn
+        self._init_log_file()
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """Create pipeline from config.
+        Attempts Agent first, falls back to raw LLM construction.
+        """
+        if "evaluation_target" not in config:
+            try:
+                config["evaluation_target"] = Agent.from_config(config)
+            except Exception:
+                config["evaluation_target"] = LLM.from_config(config)
+        return super().from_config(config)
+
+    @staticmethod
+    def _count_tool_errors(tool_calls: list) -> int:
+        """Count tool errors from tool_calls."""
+        tool_error_count = 0
+        for call in tool_calls:
+            try:
+                if isinstance(call.result, dict) and "error" in call.result:
+                    tool_error_count += 1
+            except Exception as exc:
+                logging.getLogger(__name__).debug("Error while parsing tool call result: %s", exc)
+        return tool_error_count
+
+    @staticmethod
+    def _extract_tool_names(tool_calls: list) -> list[str] | None:
+        """Extract tool names from tool_calls."""
+        if not tool_calls:
+            return None
+        tool_names = []
+        for call in tool_calls:
+            try:
+                name = getattr(call, "name", None)
+                if name is None and isinstance(call, dict):
+                    name = call.get("name")
+                if name:
+                    tool_names.append(str(name))
+            except Exception as exc:
+                logging.getLogger(__name__).debug("Tool name extraction error: %s", exc)
+        return tool_names
+
+    async def __call__(self, data: Iterable[GaiaData]) -> Iterable[GaiaResult]:
+        """Generate answer completions per task and evaluate them.
+        Returns list of `GaiaResult`, one per input task.
+        """
+        results: list[GaiaResult] = []
+
+        for row in data:
+            start_time = time.perf_counter()
+
+            prompt_input = row.question
+            debug_traces: list[dict | None] | None = [] if self.extended_logs else None
+
+            try:
+                if self.extended_logs:
+                    content, dbg = await self._generate_with_debug(prompt_input)
+                    if debug_traces is not None:
+                        debug_traces.append(dbg)
+                    tool_calls = cast(list, (dbg or {}).get("tool_calls") or [])
+                else:
+                    content, _, tool_calls = await self._generate_answer(prompt_input)
+
+            except Exception as generation_exc:
+                content = ""
+                tool_calls = []
+                err_msg = f"GenerationError: {generation_exc.__class__.__name__}: {generation_exc}"
+                if self.extended_logs and debug_traces is not None:
+                    debug_traces.append({"error": err_msg})
+
+            end_time = time.perf_counter()
+
+            # Compute metrics
+            predicted_raw = str(content).strip()
+            predicted = self._parse_answer(predicted_raw)
+            reference = (row.reference_answer or "").strip()
+            task_success = self._normalize(predicted) == self._normalize(reference) if reference else False
+
+            tool_triggered = bool(tool_calls)
+            num_tool_calls = len(tool_calls)
+            tool_error_count = GaiaPipeline._count_tool_errors(tool_calls)
+            tool_names = GaiaPipeline._extract_tool_names(tool_calls)
+            total_latency_ms = int((end_time - start_time) * 1000)
+
+            result = GaiaResult(
+                task_id=row.task_id,
+                level=row.level,
+                question=row.question,
+                reference_answer=row.reference_answer,
+                predicted_result=content,
+                task_success=task_success,
+                tool_triggered=tool_triggered,
+                num_tool_calls=num_tool_calls,
+                tool_error_count=tool_error_count,
+                total_latency_ms=total_latency_ms,
+                tool_names=tool_names,
+            )
+            results.append(result)
+            ext_log_str = (
+                json.dumps(debug_traces, ensure_ascii=False, default=str)
+                if (self.extended_logs and debug_traces is not None)
+                else None
+            )
+            self._log_example(row, result, ext_log_str)
+
+        return results
+
+    def _init_log_file(self) -> None:
+        """Ensure the per-example log file exists if logging is enabled."""
+        if self.per_example_log_file is None:
+            return
+        self.per_example_log_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.per_example_log_file, "w", encoding="utf-8") as _:
+            pass
+
+    def _log_example(self, row: GaiaData, result: GaiaResult, extended_log: str | None = None) -> None:
+        """Append a single NDJSON record for debugging if enabled."""
+        if self.per_example_log_file is None:
+            return
+        # per-task tool frequency map from tool names
+        tool_frequency_usage: dict[str, int] = {}
+        if result.tool_names:
+            for name in result.tool_names:
+                tool_frequency_usage[name] = tool_frequency_usage.get(name, 0) + 1
+        record: dict[str, object] = {
+            "task_id": row.task_id,
+            "level": row.level,
+            "question": row.question,
+            "predicted": str(result.predicted_result),
+            "predicted_extracted": self._parse_answer(str(result.predicted_result)),
+            "reference": row.reference_answer,
+            "task_success": result.task_success,
+            "tool_triggered": result.tool_triggered,
+            "num_tool_calls": result.num_tool_calls,
+            "tool_error_count": result.tool_error_count,
+            "total_latency_ms": result.total_latency_ms,
+            "tool_frequency_usage": tool_frequency_usage,
+        }
+        record["extended_debug_logging"] = extended_log or "[]"
+        with open(self.per_example_log_file, "a", encoding="utf-8") as f:
+            f.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+    async def _generate_answer(self, prompt: str) -> tuple[str, Usage, list]:
+        """Generate final answer from Agent or raw LLM and capture usage and tool calls."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            return str(res.content), res.usage, (res.tool_calls or [])
+
+        resp = cast(LLMResponseWithMetadata[str], await target.generate_with_metadata(prompt))
+        return str(resp.content), (resp.usage or Usage()), []
+
+    async def _generate_with_debug(self, prompt: str) -> tuple[str, dict | None]:
+        """Generate answer and capture tool/history/usage for logging (as raw content)."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            dbg = {
+                "history": res.history,
+                "tool_calls": res.tool_calls,
+                "usage": res.usage,
+                "metadata": res.metadata,
+            }
+            return str(res.content), dbg
+        resp = await target.generate(prompt)
+        return str(resp), None
+
+    @staticmethod
+    def _normalize(text: str) -> str:
+        """Basic normalization for answer equality checks: lowercase, strip spaces."""
+        return "".join(ch.lower() for ch in text.strip() if not ch.isspace())
+
+    def _parse_answer(self, text: str) -> str:
+        """Optionally parse final answer from text using provided function.
+        If no parser provided, returns the original text.
+        """
+        if self.parse_answer_fn is None:
+            return text
+        try:
+            return self.parse_answer_fn(text)
+        except Exception as exc:
+            logging.getLogger(__name__).debug("Answer parse error: %s", exc)
+            return text
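For orientation, here is a minimal usage sketch of the new GAIA pipeline, based only on the signatures visible in this diff. The LiteLLM class, its import path, and the "FINAL ANSWER" parser are assumptions for illustration, not part of the diff:

import asyncio
from pathlib import Path

from ragbits.core.llms.litellm import LiteLLM  # assumed concrete LLM; any ragbits LLM or Agent works
from ragbits.evaluate.pipelines.gaia import GaiaData, GaiaPipeline


def extract_final_answer(text: str) -> str:
    # Hypothetical parser: assumes the model is prompted to end with "FINAL ANSWER: ...".
    marker = "FINAL ANSWER:"
    return text.split(marker, 1)[1].strip() if marker in text else text


async def main() -> None:
    pipeline = GaiaPipeline(
        evaluation_target=LiteLLM(model_name="gpt-4o-mini"),
        per_example_log_file=Path("logs/gaia.ndjson"),
        parse_answer_fn=extract_final_answer,
    )
    tasks = [GaiaData(task_id="t-1", question="What is 2 + 2?", level=1, reference_answer="4")]
    results = await pipeline(tasks)
    print([r.task_success for r in results])


asyncio.run(main())

Answers are compared by exact match after lowercasing and whitespace removal (_normalize), so parse_answer_fn is the main lever for aligning free-form completions with GAIA's short reference answers.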
ragbits/evaluate/pipelines/hotpot_qa.py
@@ -0,0 +1,342 @@
+import json
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import cast
+
+from typing_extensions import Self
+
+from ragbits.agents import Agent
+from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, Usage
+from ragbits.document_search import DocumentSearch
+from ragbits.document_search.documents.document import DocumentMeta
+from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult
+
+
+class HotpotQAData(EvaluationData):
+    """
+    Represents a single HotpotQA example.
+    """
+
+    id: str
+    question: str
+    reference_answer: str
+    qtype: str
+    level: str
+    reference_context: list[str]
+
+
+@dataclass
+class HotpotQAResult(EvaluationResult):
+    """
+    Represents the result of evaluating a single HotpotQA example.
+    """
+
+    id: str
+    predicted_result: str
+    reference_answer: str
+    question: str
+    qtype: str
+    level: str
+    predicted_parsed: str | None = None
+    reference_normalized: str | None = None
+    em_value: float = 0.0
+    f1_value: float = 0.0
+
+
+class HotpotQAPipeline(
+    EvaluationPipeline[
+        Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT],
+        HotpotQAData,
+        HotpotQAResult,
+    ]
+):
+    """
+    HotpotQA evaluation pipeline with simple RAG ingestion per batch and multi-hop retrieval.
+    """
+
+    def __init__(
+        self,
+        evaluation_target: Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT],
+        *,
+        retriever: DocumentSearch,
+        hops: int = 3,
+        per_example_log_file: Path | None = None,
+        extended_logs: bool = False,
+        parse_answer_fn: Callable[[str], str] | None = None,
+        question_generation_prompt_fn: Callable[[str, str], str] | None = None,
+        retrieval_k: int = 3,
+        element_max_chars: int = 500,
+        hop_context_max_chars: int = 1200,
+    ) -> None:
+        super().__init__(evaluation_target=evaluation_target)
+        self.retriever = retriever
+        self.hops = max(1, min(hops, 5))
+        self.per_example_log_file = per_example_log_file
+        self.extended_logs = extended_logs
+        self.parse_answer_fn = parse_answer_fn
+        self.question_generation_prompt_fn = question_generation_prompt_fn
+        self.retrieval_k = max(1, int(retrieval_k))
+        self.element_max_chars = max(50, int(element_max_chars))
+        self.hop_context_max_chars = max(100, int(hop_context_max_chars))
+        self._init_log_file()
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """Create pipeline from config.
+        Attempts Agent first, falls back to raw LLM construction.
+        """
+        if "evaluation_target" not in config:
+            try:
+                config["evaluation_target"] = Agent.from_config(config)
+            except Exception:
+                config["evaluation_target"] = LLM.from_config(config)
+        config["retriever"] = DocumentSearch.from_config(config["document_search"])
+        config["hops"] = int(config.get("hops", 3))
+        config["retrieval_k"] = int(config.get("retrieval_k", 3))
+        config["element_max_chars"] = int(config.get("element_max_chars", 500))
+        config["hop_context_max_chars"] = int(config.get("hop_context_max_chars", 1200))
+        return super().from_config(config)
+
+    async def _ingest_documents(self, data: Iterable[HotpotQAData]) -> None:
+        """Ingest all documents from the data."""
+        documents: list[DocumentMeta] = []
+        for row in data:
+            for content in row.reference_context:
+                documents.append(DocumentMeta.from_literal(content))
+        await self.retriever.ingest(documents)
+
+    async def _perform_multihop_retrieval(self, example: HotpotQAData) -> tuple[str, list[dict]]:
+        """Perform multi-hop retrieval and return accumulated context and hop logs."""
+        accumulated_context: list[str] = []
+        hop_logs: list[dict] = []
+        last_query = example.question
+        for hop_idx in range(self.hops):
+            elements = await self.retriever.search(last_query)
+            text_parts: list[str] = []
+            consumed = 0
+            for element in elements[: self.retrieval_k]:
+                content = getattr(element, "content", "") if hasattr(element, "content") else ""
+                if isinstance(content, str) and content:
+                    snippet = content.strip().replace("\n\n", "\n")[: self.element_max_chars]
+                    budget = max(0, self.hop_context_max_chars - consumed)
+                    take = snippet[:budget]
+                    if take:
+                        text_parts.append(take)
+                        consumed += len(take)
+                if consumed >= self.hop_context_max_chars:
+                    break
+            hop_text = "\n\n".join(text_parts)
+            hop_logs.append({"hop": hop_idx + 1, "question": last_query, "retrieved": hop_text})
+            if hop_text:
+                accumulated_context.append(hop_text)
+            # generate a new question for the next hop
+            if hop_idx < self.hops - 1:
+                last_query = await self._generate_next_question(
+                    original_question=example.question,
+                    accumulated_context="\n\n".join(accumulated_context),
+                )
+            else:
+                break
+        return "\n\n".join(accumulated_context), hop_logs
+
+    async def _answer_with_retrieval(self, example: HotpotQAData) -> tuple[str, Usage, list, list[dict], dict]:
+        """Answer a question with multi-hop retrieval."""
+        full_context, hop_logs = await self._perform_multihop_retrieval(example)
+        prompt_input = example.question if not full_context else f"{example.question}\n\nContext:\n{full_context}"
+
+        if self.extended_logs:
+            content, dbg = await self._generate_with_debug(prompt_input)
+            usage = cast(Usage, (dbg or {}).get("usage") or Usage())
+            tool_calls = cast(list, (dbg or {}).get("tool_calls") or [])
+            metadata = cast(dict, (dbg or {}).get("metadata") or {})
+        else:
+            content, usage, tool_calls = await self._generate_answer(prompt_input)
+            metadata = {}
+
+        return str(content), usage, tool_calls, hop_logs, metadata
+
+    async def __call__(self, data: Iterable[HotpotQAData]) -> Iterable[HotpotQAResult]:
+        """Ingest contexts, perform multi-hop retrieval, and answer HotpotQA questions."""
+        data_list = list(data)
+        await self._ingest_documents(data_list)
+
+        results: list[HotpotQAResult] = []
+        for row in data_list:
+            try:
+                predicted_text, usage, tool_calls, hop_logs, metadata = await self._answer_with_retrieval(row)
+            except Exception:
+                predicted_text = ""
+                usage = Usage()
+                tool_calls = []
+                hop_logs = []
+                metadata = {}
+            predicted_extracted = self._parse_answer(predicted_text)
+            ref_norm = self._normalize(row.reference_answer)
+
+            # Compute normalized fields and sample metrics once
+            em = 1.0 if self._normalize(predicted_extracted) == ref_norm else 0.0
+            f1 = self._f1(self._normalize(predicted_extracted), ref_norm)
+
+            result = HotpotQAResult(
+                id=row.id,
+                predicted_result=predicted_text,
+                reference_answer=row.reference_answer,
+                question=row.question,
+                qtype=row.qtype,
+                level=row.level,
+                predicted_parsed=self._normalize(predicted_extracted),
+                reference_normalized=ref_norm,
+                em_value=float(em),
+                f1_value=float(f1),
+            )
+            results.append(result)
+
+            ext_log_str = None
+            if self.extended_logs:
+                ext_log_str = json.dumps(
+                    [
+                        {
+                            "usage": usage,
+                            "tool_calls": tool_calls,
+                            "hops": hop_logs,
+                            "metadata": metadata,
+                        }
+                    ],
+                    ensure_ascii=False,
+                    default=str,
+                )
+            self._log_example(
+                row=row,
+                predicted_text=predicted_text,
+                predicted_extracted=predicted_extracted,
+                em=result.em_value,
+                f1=result.f1_value,
+                extended_log=ext_log_str,
+            )
+
+        return results
+
+    def _init_log_file(self) -> None:
+        """Ensure the per-example log file exists if logging is enabled."""
+        if self.per_example_log_file is None:
+            return
+        self.per_example_log_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.per_example_log_file, "w", encoding="utf-8") as _:
+            pass
+
+    def _log_example(
+        self,
+        *,
+        row: HotpotQAData,
+        predicted_text: str,
+        predicted_extracted: str,
+        em: float,
+        f1: float,
+        extended_log: str | None = None,
+    ) -> None:
+        """Append a single NDJSON record for debugging if enabled."""
+        if self.per_example_log_file is None:
+            return
+        record: dict[str, object] = {
+            "id": row.id,
+            "question": row.question,
+            "reference": row.reference_answer,
+            "predicted": predicted_text,
+            "predicted_extracted": predicted_extracted,
+            "type": row.qtype,
+            "level": row.level,
+            "em": float(em),
+            "f1": float(f1),
+        }
+        record["extended_debug_logging"] = extended_log or "[]"
+        with open(self.per_example_log_file, "a", encoding="utf-8") as file:
+            file.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+    def _parse_answer(self, text: str) -> str:
+        """Optionally parse final answer from text using provided function.
+        If no parser provided, returns the original text.
+        """
+        if self.parse_answer_fn is None:
+            return text
+        try:
+            return self.parse_answer_fn(text)
+        except Exception as exc:
+            import logging as _logging
+
+            _logging.getLogger(__name__).debug("Answer parse error: %s", exc)
+            return text
+
+    async def _generate_answer(self, prompt: str) -> tuple[str, Usage, list]:
+        """Generate final answer from Agent or raw LLM and capture usage and tool calls."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            return str(res.content), res.usage, (res.tool_calls or [])
+
+        resp = cast(LLMResponseWithMetadata[str], await target.generate_with_metadata(prompt))
+        return str(resp.content), (resp.usage or Usage()), []
+
+    async def _generate_with_debug(self, prompt: str) -> tuple[str, dict | None]:
+        """Generate answer and capture tool/history/usage for logging (as raw content)."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            dbg = {
+                "history": res.history,
+                "tool_calls": res.tool_calls,
+                "usage": res.usage,
+                "metadata": res.metadata,
+            }
+            return str(res.content), dbg
+        resp = await target.generate(prompt)
+        return str(resp), None
+
+    async def _generate_next_question(self, original_question: str, accumulated_context: str) -> str:
+        """Generate a new follow-up question based on the original question and accumulated context."""
+        if self.question_generation_prompt_fn is None:
+            # default: simple concatenation
+            return f"{original_question}\n\nContext so far:\n{accumulated_context}"
+
+        question_generation_prompt = self.question_generation_prompt_fn(original_question, accumulated_context)
+
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            resp = await target.llm.generate(question_generation_prompt)
+            return str(resp).strip()
+        resp = await target.generate(question_generation_prompt)
+        return str(resp).strip()
+
+    @staticmethod
+    def _normalize(text: str) -> str:
+        """Basic normalization for answer equality checks: lowercase, strip spaces."""
+        return "".join(ch.lower() for ch in (text or "").strip() if not ch.isspace())
+
+    @staticmethod
+    def _f1(prediction: str, ground_truth: str) -> float:
+        import re as _re
+        from collections import Counter as _Counter
+
+        def tokens(value: str) -> list[str]:
+            value = (value or "").lower()
+            value = _re.sub(r"[^a-z0-9\s]", " ", value)
+            value = _re.sub(r"\b(a|an|the)\b", " ", value)
+            value = _re.sub(r"\s+", " ", value).strip()
+            return value.split()
+
+        pred_tokens = tokens(prediction)
+        gt_tokens = tokens(ground_truth)
+        if not pred_tokens and not gt_tokens:
+            return 1.0
+        if not pred_tokens or not gt_tokens:
+            return 0.0
+
+        pred_counts = _Counter(pred_tokens)
+        gt_counts = _Counter(gt_tokens)
+        common = sum((pred_counts & gt_counts).values())
+        if common == 0:
+            return 0.0
+
+        precision = common / len(pred_tokens)
+        recall = common / len(gt_tokens)
+        return 2 * precision * recall / (precision + recall)
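A similar sketch for the HotpotQA pipeline, again using only what this diff exposes. Construction of the DocumentSearch retriever and the LLM is left to the caller, and next_question_prompt is a hypothetical hook written for illustration:

from ragbits.agents import Agent
from ragbits.core.llms.base import LLM
from ragbits.document_search import DocumentSearch
from ragbits.evaluate.pipelines.hotpot_qa import HotpotQAData, HotpotQAPipeline, HotpotQAResult


def next_question_prompt(original_question: str, accumulated_context: str) -> str:
    # Hypothetical template for the follow-up query generated between hops.
    return (
        "Given the original question and the evidence gathered so far, "
        "write one focused follow-up search query.\n\n"
        f"Question: {original_question}\n\nEvidence:\n{accumulated_context}"
    )


async def evaluate(llm: Agent | LLM, retriever: DocumentSearch) -> list[HotpotQAResult]:
    # `llm` is any ragbits LLM or Agent; `retriever` is an already-configured DocumentSearch.
    pipeline = HotpotQAPipeline(
        evaluation_target=llm,
        retriever=retriever,
        hops=2,
        retrieval_k=3,
        question_generation_prompt_fn=next_question_prompt,
    )
    rows = [
        HotpotQAData(
            id="ex-1",
            question="In which country was the author of 'Example Book' born?",
            reference_answer="France",
            qtype="bridge",
            level="medium",
            reference_context=[
                "'Example Book' was written by Jane Doe.",
                "Jane Doe was born in France.",
            ],
        )
    ]
    return list(await pipeline(rows))

Each call first ingests every reference_context passage into the retriever, then answers with up to `hops` retrieval rounds; per-example EM and F1 are stored on the HotpotQAResult and, if per_example_log_file is set, appended to an NDJSON log.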