ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. ragbits/evaluate/agent_simulation/__init__.py +87 -0
  2. ragbits/evaluate/agent_simulation/context.py +118 -0
  3. ragbits/evaluate/agent_simulation/conversation.py +333 -0
  4. ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
  5. ragbits/evaluate/agent_simulation/logger.py +165 -0
  6. ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
  7. ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
  8. ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
  9. ragbits/evaluate/agent_simulation/models.py +37 -0
  10. ragbits/evaluate/agent_simulation/results.py +200 -0
  11. ragbits/evaluate/agent_simulation/scenarios.py +129 -0
  12. ragbits/evaluate/agent_simulation/simulation.py +243 -0
  13. ragbits/evaluate/cli.py +150 -0
  14. ragbits/evaluate/config.py +11 -0
  15. ragbits/evaluate/dataloaders/__init__.py +3 -0
  16. ragbits/evaluate/dataloaders/base.py +95 -0
  17. ragbits/evaluate/dataloaders/document_search.py +61 -0
  18. ragbits/evaluate/dataloaders/exceptions.py +25 -0
  19. ragbits/evaluate/dataloaders/gaia.py +78 -0
  20. ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
  21. ragbits/evaluate/dataloaders/human_eval.py +70 -0
  22. ragbits/evaluate/dataloaders/question_answer.py +56 -0
  23. ragbits/evaluate/dataset_generator/pipeline.py +4 -4
  24. ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
  25. ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
  26. ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
  27. ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
  28. ragbits/evaluate/evaluator.py +178 -50
  29. ragbits/evaluate/factories/__init__.py +42 -0
  30. ragbits/evaluate/metrics/__init__.py +2 -23
  31. ragbits/evaluate/metrics/base.py +40 -17
  32. ragbits/evaluate/metrics/document_search.py +40 -23
  33. ragbits/evaluate/metrics/gaia.py +84 -0
  34. ragbits/evaluate/metrics/hotpot_qa.py +51 -0
  35. ragbits/evaluate/metrics/human_eval.py +105 -0
  36. ragbits/evaluate/metrics/question_answer.py +222 -0
  37. ragbits/evaluate/optimizer.py +138 -86
  38. ragbits/evaluate/pipelines/__init__.py +37 -0
  39. ragbits/evaluate/pipelines/base.py +34 -10
  40. ragbits/evaluate/pipelines/document_search.py +72 -67
  41. ragbits/evaluate/pipelines/gaia.py +249 -0
  42. ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
  43. ragbits/evaluate/pipelines/human_eval.py +323 -0
  44. ragbits/evaluate/pipelines/question_answer.py +96 -0
  45. ragbits/evaluate/utils.py +86 -59
  46. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
  47. ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
  48. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
  49. ragbits/evaluate/callbacks/base.py +0 -22
  50. ragbits/evaluate/callbacks/neptune.py +0 -26
  51. ragbits/evaluate/loaders/__init__.py +0 -21
  52. ragbits/evaluate/loaders/base.py +0 -24
  53. ragbits/evaluate/loaders/hf.py +0 -25
  54. ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
  55. /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/pipelines/gaia.py (new file)
@@ -0,0 +1,249 @@
+import json
+import logging
+import time
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import cast
+
+from typing_extensions import Self
+
+from ragbits.agents import Agent
+from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, Usage
+from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult
+
+
+class GaiaData(EvaluationData):
+    """
+    Represents a single GAIA task.
+    """
+
+    task_id: str
+    question: str
+    level: int
+    reference_answer: str
+    file_name: str | None = None
+
+
+@dataclass
+class GaiaResult(EvaluationResult):
+    """
+    Represents the result of evaluating a single GAIA task.
+    """
+
+    task_id: str
+    level: int
+    question: str
+    reference_answer: str
+    predicted_result: str
+    task_success: bool
+    tool_triggered: bool
+    num_tool_calls: int
+    tool_error_count: int
+    total_latency_ms: int
+    tool_names: list[str] | None = None
+
+
+class GaiaPipeline(
+    EvaluationPipeline[Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT], GaiaData, GaiaResult]
+):
+    """GAIA evaluation pipeline for question answering models/agents."""
+
+    def __init__(
+        self,
+        evaluation_target: Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT],
+        *,
+        system_prompt: str | None = None,
+        per_example_log_file: Path | None = None,
+        extended_logs: bool = False,
+        parse_answer_fn: Callable[[str], str] | None = None,
+    ) -> None:
+        super().__init__(evaluation_target=evaluation_target)
+        self.system_prompt = system_prompt
+        self.per_example_log_file = per_example_log_file
+        self.extended_logs = extended_logs
+        self.parse_answer_fn = parse_answer_fn
+        self._init_log_file()
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """Create pipeline from config.
+        Attempts Agent first, falls back to raw LLM construction.
+        """
+        if "evaluation_target" not in config:
+            try:
+                config["evaluation_target"] = Agent.from_config(config)
+            except Exception:
+                config["evaluation_target"] = LLM.from_config(config)
+        return super().from_config(config)
+
+    @staticmethod
+    def _count_tool_errors(tool_calls: list) -> int:
+        """Count tool errors from tool_calls."""
+        tool_error_count = 0
+        for call in tool_calls:
+            try:
+                if isinstance(call.result, dict) and "error" in call.result:
+                    tool_error_count += 1
+            except Exception as exc:
+                logging.getLogger(__name__).debug("Error while parsing tool call result: %s", exc)
+        return tool_error_count
+
+    @staticmethod
+    def _extract_tool_names(tool_calls: list) -> list[str] | None:
+        """Extract tool names from tool_calls."""
+        if not tool_calls:
+            return None
+        tool_names = []
+        for call in tool_calls:
+            try:
+                name = getattr(call, "name", None)
+                if name is None and isinstance(call, dict):
+                    name = call.get("name")
+                if name:
+                    tool_names.append(str(name))
+            except Exception as exc:
+                logging.getLogger(__name__).debug("Tool name extraction error: %s", exc)
+        return tool_names
+
+    async def __call__(self, data: Iterable[GaiaData]) -> Iterable[GaiaResult]:
+        """Generate answer completions per task and evaluate them.
+        Returns list of `GaiaResult`, one per input task.
+        """
+        results: list[GaiaResult] = []
+
+        for row in data:
+            start_time = time.perf_counter()
+
+            prompt_input = row.question
+            debug_traces: list[dict | None] | None = [] if self.extended_logs else None
+
+            try:
+                if self.extended_logs:
+                    content, dbg = await self._generate_with_debug(prompt_input)
+                    if debug_traces is not None:
+                        debug_traces.append(dbg)
+                    tool_calls = cast(list, (dbg or {}).get("tool_calls") or [])
+                else:
+                    content, _, tool_calls = await self._generate_answer(prompt_input)
+
+            except Exception as generation_exc:
+                content = ""
+                tool_calls = []
+                err_msg = f"GenerationError: {generation_exc.__class__.__name__}: {generation_exc}"
+                if self.extended_logs and debug_traces is not None:
+                    debug_traces.append({"error": err_msg})
+
+            end_time = time.perf_counter()
+
+            # Compute metrics
+            predicted_raw = str(content).strip()
+            predicted = self._parse_answer(predicted_raw)
+            reference = (row.reference_answer or "").strip()
+            task_success = self._normalize(predicted) == self._normalize(reference) if reference else False
+
+            tool_triggered = bool(tool_calls)
+            num_tool_calls = len(tool_calls)
+            tool_error_count = GaiaPipeline._count_tool_errors(tool_calls)
+            tool_names = GaiaPipeline._extract_tool_names(tool_calls)
+            total_latency_ms = int((end_time - start_time) * 1000)
+
+            result = GaiaResult(
+                task_id=row.task_id,
+                level=row.level,
+                question=row.question,
+                reference_answer=row.reference_answer,
+                predicted_result=content,
+                task_success=task_success,
+                tool_triggered=tool_triggered,
+                num_tool_calls=num_tool_calls,
+                tool_error_count=tool_error_count,
+                total_latency_ms=total_latency_ms,
+                tool_names=tool_names,
+            )
+            results.append(result)
+            ext_log_str = (
+                json.dumps(debug_traces, ensure_ascii=False, default=str)
+                if (self.extended_logs and debug_traces is not None)
+                else None
+            )
+            self._log_example(row, result, ext_log_str)
+
+        return results
+
+    def _init_log_file(self) -> None:
+        """Ensure the per-example log file exists if logging is enabled."""
+        if self.per_example_log_file is None:
+            return
+        self.per_example_log_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.per_example_log_file, "w", encoding="utf-8") as _:
+            pass
+
+    def _log_example(self, row: GaiaData, result: GaiaResult, extended_log: str | None = None) -> None:
+        """Append a single NDJSON record for debugging if enabled."""
+        if self.per_example_log_file is None:
+            return
+        # per-task tool frequency map from tool names
+        tool_frequency_usage: dict[str, int] = {}
+        if result.tool_names:
+            for name in result.tool_names:
+                tool_frequency_usage[name] = tool_frequency_usage.get(name, 0) + 1
+        record: dict[str, object] = {
+            "task_id": row.task_id,
+            "level": row.level,
+            "question": row.question,
+            "predicted": str(result.predicted_result),
+            "predicted_extracted": self._parse_answer(str(result.predicted_result)),
+            "reference": row.reference_answer,
+            "task_success": result.task_success,
+            "tool_triggered": result.tool_triggered,
+            "num_tool_calls": result.num_tool_calls,
+            "tool_error_count": result.tool_error_count,
+            "total_latency_ms": result.total_latency_ms,
+            "tool_frequency_usage": tool_frequency_usage,
+        }
+        record["extended_debug_logging"] = extended_log or "[]"
+        with open(self.per_example_log_file, "a", encoding="utf-8") as f:
+            f.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+    async def _generate_answer(self, prompt: str) -> tuple[str, Usage, list]:
+        """Generate final answer from Agent or raw LLM and capture usage and tool calls."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            return str(res.content), res.usage, (res.tool_calls or [])
+
+        resp = cast(LLMResponseWithMetadata[str], await target.generate_with_metadata(prompt))
+        return str(resp.content), (resp.usage or Usage()), []
+
+    async def _generate_with_debug(self, prompt: str) -> tuple[str, dict | None]:
+        """Generate answer and capture tool/history/usage for logging (as raw content)."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            dbg = {
+                "history": res.history,
+                "tool_calls": res.tool_calls,
+                "usage": res.usage,
+                "metadata": res.metadata,
+            }
+            return str(res.content), dbg
+        resp = await target.generate(prompt)
+        return str(resp), None
+
+    @staticmethod
+    def _normalize(text: str) -> str:
+        """Basic normalization for answer equality checks: lowercase, strip spaces."""
+        return "".join(ch.lower() for ch in text.strip() if not ch.isspace())
+
+    def _parse_answer(self, text: str) -> str:
+        """Optionally parse final answer from text using provided function.
+        If no parser provided, returns the original text.
+        """
+        if self.parse_answer_fn is None:
+            return text
+        try:
+            return self.parse_answer_fn(text)
+        except Exception as exc:
+            logging.getLogger(__name__).debug("Answer parse error: %s", exc)
+            return text
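
Below is a minimal usage sketch of the new GaiaPipeline, not taken from the package itself: only GaiaData and GaiaPipeline come from the diff above, while the LiteLLM import path and model name are assumptions about ragbits-core, and the sample task and the "FINAL ANSWER" parser are hypothetical.

import asyncio
from pathlib import Path

from ragbits.core.llms.litellm import LiteLLM  # assumed ragbits-core import path
from ragbits.evaluate.pipelines.gaia import GaiaData, GaiaPipeline


async def main() -> None:
    # The evaluation target may be a raw LLM or an Agent; a plain LLM is used here.
    pipeline = GaiaPipeline(
        evaluation_target=LiteLLM(model_name="gpt-4o-mini"),  # hypothetical model choice
        per_example_log_file=Path("logs/gaia.ndjson"),  # one NDJSON record per task
        parse_answer_fn=lambda text: text.rsplit("FINAL ANSWER:", 1)[-1].strip(),  # hypothetical extractor
    )
    data = [
        GaiaData(task_id="task-1", question="What is 2 + 2?", level=1, reference_answer="4"),
    ]
    results = await pipeline(data)  # one GaiaResult per input task
    for result in results:
        print(result.task_id, result.task_success, result.total_latency_ms)


asyncio.run(main())

Task success here is plain normalized string equality against reference_answer, so a parse_answer_fn that strips reasoning down to the final answer is what makes the comparison meaningful.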
ragbits/evaluate/pipelines/hotpot_qa.py (new file)
@@ -0,0 +1,342 @@
+import json
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import cast
+
+from typing_extensions import Self
+
+from ragbits.agents import Agent
+from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, Usage
+from ragbits.document_search import DocumentSearch
+from ragbits.document_search.documents.document import DocumentMeta
+from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult
+
+
+class HotpotQAData(EvaluationData):
+    """
+    Represents a single HotpotQA example.
+    """
+
+    id: str
+    question: str
+    reference_answer: str
+    qtype: str
+    level: str
+    reference_context: list[str]
+
+
+@dataclass
+class HotpotQAResult(EvaluationResult):
+    """
+    Represents the result of evaluating a single HotpotQA example.
+    """
+
+    id: str
+    predicted_result: str
+    reference_answer: str
+    question: str
+    qtype: str
+    level: str
+    predicted_parsed: str | None = None
+    reference_normalized: str | None = None
+    em_value: float = 0.0
+    f1_value: float = 0.0
+
+
+class HotpotQAPipeline(
+    EvaluationPipeline[
+        Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT],
+        HotpotQAData,
+        HotpotQAResult,
+    ]
+):
+    """
+    HotpotQA evaluation pipeline with simple RAG ingestion per batch and multi-hop retrieval.
+    """
+
+    def __init__(
+        self,
+        evaluation_target: Agent[LLMClientOptionsT, None, str] | LLM[LLMClientOptionsT],
+        *,
+        retriever: DocumentSearch,
+        hops: int = 3,
+        per_example_log_file: Path | None = None,
+        extended_logs: bool = False,
+        parse_answer_fn: Callable[[str], str] | None = None,
+        question_generation_prompt_fn: Callable[[str, str], str] | None = None,
+        retrieval_k: int = 3,
+        element_max_chars: int = 500,
+        hop_context_max_chars: int = 1200,
+    ) -> None:
+        super().__init__(evaluation_target=evaluation_target)
+        self.retriever = retriever
+        self.hops = max(1, min(hops, 5))
+        self.per_example_log_file = per_example_log_file
+        self.extended_logs = extended_logs
+        self.parse_answer_fn = parse_answer_fn
+        self.question_generation_prompt_fn = question_generation_prompt_fn
+        self.retrieval_k = max(1, int(retrieval_k))
+        self.element_max_chars = max(50, int(element_max_chars))
+        self.hop_context_max_chars = max(100, int(hop_context_max_chars))
+        self._init_log_file()
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """Create pipeline from config.
+        Attempts Agent first, falls back to raw LLM construction.
+        """
+        if "evaluation_target" not in config:
+            try:
+                config["evaluation_target"] = Agent.from_config(config)
+            except Exception:
+                config["evaluation_target"] = LLM.from_config(config)
+        config["retriever"] = DocumentSearch.from_config(config["document_search"])
+        config["hops"] = int(config.get("hops", 3))
+        config["retrieval_k"] = int(config.get("retrieval_k", 3))
+        config["element_max_chars"] = int(config.get("element_max_chars", 500))
+        config["hop_context_max_chars"] = int(config.get("hop_context_max_chars", 1200))
+        return super().from_config(config)
+
+    async def _ingest_documents(self, data: Iterable[HotpotQAData]) -> None:
+        """Ingest all documents from the data."""
+        documents: list[DocumentMeta] = []
+        for row in data:
+            for content in row.reference_context:
+                documents.append(DocumentMeta.from_literal(content))
+        await self.retriever.ingest(documents)
+
+    async def _perform_multihop_retrieval(self, example: HotpotQAData) -> tuple[str, list[dict]]:
+        """Perform multi-hop retrieval and return accumulated context and hop logs."""
+        accumulated_context: list[str] = []
+        hop_logs: list[dict] = []
+        last_query = example.question
+        for hop_idx in range(self.hops):
+            elements = await self.retriever.search(last_query)
+            text_parts: list[str] = []
+            consumed = 0
+            for element in elements[: self.retrieval_k]:
+                content = getattr(element, "content", "") if hasattr(element, "content") else ""
+                if isinstance(content, str) and content:
+                    snippet = content.strip().replace("\n\n", "\n")[: self.element_max_chars]
+                    budget = max(0, self.hop_context_max_chars - consumed)
+                    take = snippet[:budget]
+                    if take:
+                        text_parts.append(take)
+                        consumed += len(take)
+                if consumed >= self.hop_context_max_chars:
+                    break
+            hop_text = "\n\n".join(text_parts)
+            hop_logs.append({"hop": hop_idx + 1, "question": last_query, "retrieved": hop_text})
+            if hop_text:
+                accumulated_context.append(hop_text)
+                # generate a new question for the next hop
+                if hop_idx < self.hops - 1:
+                    last_query = await self._generate_next_question(
+                        original_question=example.question,
+                        accumulated_context="\n\n".join(accumulated_context),
+                    )
+            else:
+                break
+        return "\n\n".join(accumulated_context), hop_logs
+
+    async def _answer_with_retrieval(self, example: HotpotQAData) -> tuple[str, Usage, list, list[dict], dict]:
+        """Answer a question with multi-hop retrieval."""
+        full_context, hop_logs = await self._perform_multihop_retrieval(example)
+        prompt_input = example.question if not full_context else f"{example.question}\n\nContext:\n{full_context}"
+
+        if self.extended_logs:
+            content, dbg = await self._generate_with_debug(prompt_input)
+            usage = cast(Usage, (dbg or {}).get("usage") or Usage())
+            tool_calls = cast(list, (dbg or {}).get("tool_calls") or [])
+            metadata = cast(dict, (dbg or {}).get("metadata") or {})
+        else:
+            content, usage, tool_calls = await self._generate_answer(prompt_input)
+            metadata = {}
+
+        return str(content), usage, tool_calls, hop_logs, metadata
+
+    async def __call__(self, data: Iterable[HotpotQAData]) -> Iterable[HotpotQAResult]:
+        """Ingest contexts, perform multi-hop retrieval, and answer HotpotQA questions."""
+        data_list = list(data)
+        await self._ingest_documents(data_list)
+
+        results: list[HotpotQAResult] = []
+        for row in data_list:
+            try:
+                predicted_text, usage, tool_calls, hop_logs, metadata = await self._answer_with_retrieval(row)
+            except Exception:
+                predicted_text = ""
+                usage = Usage()
+                tool_calls = []
+                hop_logs = []
+                metadata = {}
+            predicted_extracted = self._parse_answer(predicted_text)
+            ref_norm = self._normalize(row.reference_answer)
+
+            # Compute normalized fields and sample metrics once
+            em = 1.0 if self._normalize(predicted_extracted) == ref_norm else 0.0
+            f1 = self._f1(self._normalize(predicted_extracted), ref_norm)
+
+            result = HotpotQAResult(
+                id=row.id,
+                predicted_result=predicted_text,
+                reference_answer=row.reference_answer,
+                question=row.question,
+                qtype=row.qtype,
+                level=row.level,
+                predicted_parsed=self._normalize(predicted_extracted),
+                reference_normalized=ref_norm,
+                em_value=float(em),
+                f1_value=float(f1),
+            )
+            results.append(result)

+            ext_log_str = None
+            if self.extended_logs:
+                ext_log_str = json.dumps(
+                    [
+                        {
+                            "usage": usage,
+                            "tool_calls": tool_calls,
+                            "hops": hop_logs,
+                            "metadata": metadata,
+                        }
+                    ],
+                    ensure_ascii=False,
+                    default=str,
+                )
+            self._log_example(
+                row=row,
+                predicted_text=predicted_text,
+                predicted_extracted=predicted_extracted,
+                em=result.em_value,
+                f1=result.f1_value,
+                extended_log=ext_log_str,
+            )
+
+        return results
+
+    def _init_log_file(self) -> None:
+        """Ensure the per-example log file exists if logging is enabled."""
+        if self.per_example_log_file is None:
+            return
+        self.per_example_log_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.per_example_log_file, "w", encoding="utf-8") as _:
+            pass
+
+    def _log_example(
+        self,
+        *,
+        row: HotpotQAData,
+        predicted_text: str,
+        predicted_extracted: str,
+        em: float,
+        f1: float,
+        extended_log: str | None = None,
+    ) -> None:
+        """Append a single NDJSON record for debugging if enabled."""
+        if self.per_example_log_file is None:
+            return
+        record: dict[str, object] = {
+            "id": row.id,
+            "question": row.question,
+            "reference": row.reference_answer,
+            "predicted": predicted_text,
+            "predicted_extracted": predicted_extracted,
+            "type": row.qtype,
+            "level": row.level,
+            "em": float(em),
+            "f1": float(f1),
+        }
+        record["extended_debug_logging"] = extended_log or "[]"
+        with open(self.per_example_log_file, "a", encoding="utf-8") as file:
+            file.write(json.dumps(record, ensure_ascii=False) + "\n")
+
+    def _parse_answer(self, text: str) -> str:
+        """Optionally parse final answer from text using provided function.
+        If no parser provided, returns the original text.
+        """
+        if self.parse_answer_fn is None:
+            return text
+        try:
+            return self.parse_answer_fn(text)
+        except Exception as exc:
+            import logging as _logging
+
+            _logging.getLogger(__name__).debug("Answer parse error: %s", exc)
+            return text
+
+    async def _generate_answer(self, prompt: str) -> tuple[str, Usage, list]:
+        """Generate final answer from Agent or raw LLM and capture usage and tool calls."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            return str(res.content), res.usage, (res.tool_calls or [])
+
+        resp = cast(LLMResponseWithMetadata[str], await target.generate_with_metadata(prompt))
+        return str(resp.content), (resp.usage or Usage()), []
+
+    async def _generate_with_debug(self, prompt: str) -> tuple[str, dict | None]:
+        """Generate answer and capture tool/history/usage for logging (as raw content)."""
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            res = await target.run(prompt)
+            dbg = {
+                "history": res.history,
+                "tool_calls": res.tool_calls,
+                "usage": res.usage,
+                "metadata": res.metadata,
+            }
+            return str(res.content), dbg
+        resp = await target.generate(prompt)
+        return str(resp), None
+
+    async def _generate_next_question(self, original_question: str, accumulated_context: str) -> str:
+        """Generate a new follow-up question based on the original question and accumulated context."""
+        if self.question_generation_prompt_fn is None:
+            # default: simple concatenation
+            return f"{original_question}\n\nContext so far:\n{accumulated_context}"
+
+        question_generation_prompt = self.question_generation_prompt_fn(original_question, accumulated_context)
+
+        target = self.evaluation_target
+        if isinstance(target, Agent):
+            resp = await target.llm.generate(question_generation_prompt)
+            return str(resp).strip()
+        resp = await target.generate(question_generation_prompt)
+        return str(resp).strip()
+
+    @staticmethod
+    def _normalize(text: str) -> str:
+        """Basic normalization for answer equality checks: lowercase, strip spaces."""
+        return "".join(ch.lower() for ch in (text or "").strip() if not ch.isspace())
+
+    @staticmethod
+    def _f1(prediction: str, ground_truth: str) -> float:
+        import re as _re
+        from collections import Counter as _Counter
+
+        def tokens(value: str) -> list[str]:
+            value = (value or "").lower()
+            value = _re.sub(r"[^a-z0-9\s]", " ", value)
+            value = _re.sub(r"\b(a|an|the)\b", " ", value)
+            value = _re.sub(r"\s+", " ", value).strip()
+            return value.split()
+
+        pred_tokens = tokens(prediction)
+        gt_tokens = tokens(ground_truth)
+        if not pred_tokens and not gt_tokens:
+            return 1.0
+        if not pred_tokens or not gt_tokens:
+            return 0.0
+
+        pred_counts = _Counter(pred_tokens)
+        gt_counts = _Counter(gt_tokens)
+        common = sum((pred_counts & gt_counts).values())
+        if common == 0:
+            return 0.0
+
+        precision = common / len(pred_tokens)
+        recall = common / len(gt_tokens)
+        return 2 * precision * recall / (precision + recall)
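
A corresponding sketch for the HotpotQA pipeline follows. Only HotpotQAData, HotpotQAPipeline, and the retriever/hops/retrieval_k parameters come from the diff above; the DocumentSearch wiring (LiteLLM embedder plus in-memory vector store) uses assumed ragbits-core import paths that may differ between versions, and the model name, question, and contexts are made up for illustration.

import asyncio

from ragbits.core.embeddings.litellm import LiteLLMEmbedder  # assumed import path
from ragbits.core.llms.litellm import LiteLLM  # assumed import path
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore  # assumed import path
from ragbits.document_search import DocumentSearch
from ragbits.evaluate.pipelines.hotpot_qa import HotpotQAData, HotpotQAPipeline


async def main() -> None:
    # Retriever into which the pipeline ingests each example's reference contexts before answering.
    retriever = DocumentSearch(vector_store=InMemoryVectorStore(embedder=LiteLLMEmbedder()))
    pipeline = HotpotQAPipeline(
        evaluation_target=LiteLLM(model_name="gpt-4o-mini"),  # hypothetical model choice
        retriever=retriever,
        hops=2,  # clamped to 1..5 by the constructor
        retrieval_k=3,  # elements considered per hop
    )
    data = [
        HotpotQAData(
            id="example-1",
            question="What is the capital of the country where the Eiffel Tower is located?",
            reference_answer="Paris",
            qtype="bridge",
            level="easy",
            reference_context=[
                "The Eiffel Tower is a landmark in Paris, France.",
                "Paris is the capital of France.",
            ],
        ),
    ]
    results = await pipeline(data)  # ingests contexts, runs multi-hop retrieval, computes per-example EM/F1
    for result in results:
        print(result.id, result.em_value, result.f1_value)


asyncio.run(main())

Because the default next-hop query is just the original question concatenated with the context gathered so far, passing a question_generation_prompt_fn is the lever for steering the second and later hops toward the missing evidence.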