rnsr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnsr/__init__.py +118 -0
- rnsr/__main__.py +242 -0
- rnsr/agent/__init__.py +218 -0
- rnsr/agent/cross_doc_navigator.py +767 -0
- rnsr/agent/graph.py +1557 -0
- rnsr/agent/llm_cache.py +575 -0
- rnsr/agent/navigator_api.py +497 -0
- rnsr/agent/provenance.py +772 -0
- rnsr/agent/query_clarifier.py +617 -0
- rnsr/agent/reasoning_memory.py +736 -0
- rnsr/agent/repl_env.py +709 -0
- rnsr/agent/rlm_navigator.py +2108 -0
- rnsr/agent/self_reflection.py +602 -0
- rnsr/agent/variable_store.py +308 -0
- rnsr/benchmarks/__init__.py +118 -0
- rnsr/benchmarks/comprehensive_benchmark.py +733 -0
- rnsr/benchmarks/evaluation_suite.py +1210 -0
- rnsr/benchmarks/finance_bench.py +147 -0
- rnsr/benchmarks/pdf_merger.py +178 -0
- rnsr/benchmarks/performance.py +321 -0
- rnsr/benchmarks/quality.py +321 -0
- rnsr/benchmarks/runner.py +298 -0
- rnsr/benchmarks/standard_benchmarks.py +995 -0
- rnsr/client.py +560 -0
- rnsr/document_store.py +394 -0
- rnsr/exceptions.py +74 -0
- rnsr/extraction/__init__.py +172 -0
- rnsr/extraction/candidate_extractor.py +357 -0
- rnsr/extraction/entity_extractor.py +581 -0
- rnsr/extraction/entity_linker.py +825 -0
- rnsr/extraction/grounded_extractor.py +722 -0
- rnsr/extraction/learned_types.py +599 -0
- rnsr/extraction/models.py +232 -0
- rnsr/extraction/relationship_extractor.py +600 -0
- rnsr/extraction/relationship_patterns.py +511 -0
- rnsr/extraction/relationship_validator.py +392 -0
- rnsr/extraction/rlm_extractor.py +589 -0
- rnsr/extraction/rlm_unified_extractor.py +990 -0
- rnsr/extraction/tot_validator.py +610 -0
- rnsr/extraction/unified_extractor.py +342 -0
- rnsr/indexing/__init__.py +60 -0
- rnsr/indexing/knowledge_graph.py +1128 -0
- rnsr/indexing/kv_store.py +313 -0
- rnsr/indexing/persistence.py +323 -0
- rnsr/indexing/semantic_retriever.py +237 -0
- rnsr/indexing/semantic_search.py +320 -0
- rnsr/indexing/skeleton_index.py +395 -0
- rnsr/ingestion/__init__.py +161 -0
- rnsr/ingestion/chart_parser.py +569 -0
- rnsr/ingestion/document_boundary.py +662 -0
- rnsr/ingestion/font_histogram.py +334 -0
- rnsr/ingestion/header_classifier.py +595 -0
- rnsr/ingestion/hierarchical_cluster.py +515 -0
- rnsr/ingestion/layout_detector.py +356 -0
- rnsr/ingestion/layout_model.py +379 -0
- rnsr/ingestion/ocr_fallback.py +177 -0
- rnsr/ingestion/pipeline.py +936 -0
- rnsr/ingestion/semantic_fallback.py +417 -0
- rnsr/ingestion/table_parser.py +799 -0
- rnsr/ingestion/text_builder.py +460 -0
- rnsr/ingestion/tree_builder.py +402 -0
- rnsr/ingestion/vision_retrieval.py +965 -0
- rnsr/ingestion/xy_cut.py +555 -0
- rnsr/llm.py +733 -0
- rnsr/models.py +167 -0
- rnsr/py.typed +2 -0
- rnsr-0.1.0.dist-info/METADATA +592 -0
- rnsr-0.1.0.dist-info/RECORD +72 -0
- rnsr-0.1.0.dist-info/WHEEL +5 -0
- rnsr-0.1.0.dist-info/entry_points.txt +2 -0
- rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
- rnsr-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,995 @@
"""
Standard RAG Benchmarks for RNSR Evaluation

This module provides integration with established RAG and retrieval benchmarks
to validate RNSR's claims of improved document parsing and traversal.

Key Benchmarks:
1. RAGAS - Standard RAG evaluation metrics (faithfulness, relevance, etc.)
2. BEIR - Information retrieval benchmark (17+ datasets)
3. HotpotQA - Multi-hop question answering
4. MuSiQue - Multi-hop questions via single-hop composition

These benchmarks help demonstrate RNSR's advantages:
- Hierarchical tree traversal vs flat chunk retrieval
- Multi-hop reasoning capabilities
- Context preservation in complex documents
"""

from __future__ import annotations

import json
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal

import structlog

logger = structlog.get_logger(__name__)


# =============================================================================
# Baseline RAG Systems for Comparison
# =============================================================================

@dataclass
class BaselineResult:
    """Result from a baseline RAG system."""

    answer: str
    retrieved_chunks: list[str]
    retrieval_time_s: float
    generation_time_s: float
    total_time_s: float
    method: str
    metadata: dict[str, Any] = field(default_factory=dict)


class BaselineRAG(ABC):
    """Abstract base class for baseline RAG implementations."""

    @abstractmethod
    def query(self, question: str, document_path: Path) -> BaselineResult:
        """Answer a question using the baseline method."""
        pass

    @abstractmethod
    def name(self) -> str:
        """Return the name of this baseline."""
        pass


class NaiveChunkRAG(BaselineRAG):
    """
    Naive chunking baseline - the standard RAG approach.

    Chunks document into fixed-size segments, embeds them,
    retrieves top-k by similarity, and generates answer.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        top_k: int = 5,
        embedding_model: str = "text-embedding-3-small",
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.top_k = top_k
        self.embedding_model = embedding_model

    def name(self) -> str:
        return f"naive_chunk_{self.chunk_size}"

    def query(self, question: str, document_path: Path) -> BaselineResult:
        """Query using naive chunking."""
        import fitz  # type: ignore[import-not-found]  # PyMuPDF

        start_total = time.perf_counter()

        # Extract text
        doc = fitz.open(document_path)
        full_text = ""
        for page in doc:
            text = page.get_text()
            if isinstance(text, str):
                full_text += text
        doc.close()

        # Naive chunking
        chunks = []
        for i in range(0, len(full_text), self.chunk_size - self.chunk_overlap):
            chunk = full_text[i:i + self.chunk_size]
            if chunk.strip():
                chunks.append(chunk)

        # Embed and retrieve (simplified - would use actual embeddings)
        start_retrieval = time.perf_counter()

        # For now, use simple keyword matching as proxy
        # In production, use actual embeddings
        question_words = set(question.lower().split())
        scored_chunks = []
        for chunk in chunks:
            chunk_words = set(chunk.lower().split())
            score = len(question_words & chunk_words) / max(len(question_words), 1)
            scored_chunks.append((score, chunk))

        scored_chunks.sort(reverse=True, key=lambda x: x[0])
        retrieved = [c for _, c in scored_chunks[:self.top_k]]

        retrieval_time = time.perf_counter() - start_retrieval

        # Generate answer (placeholder - would use LLM)
        start_generation = time.perf_counter()
        context = "\n\n".join(retrieved)
        answer = f"[Baseline answer based on {len(retrieved)} chunks]"
        generation_time = time.perf_counter() - start_generation

        total_time = time.perf_counter() - start_total

        return BaselineResult(
            answer=answer,
            retrieved_chunks=retrieved,
            retrieval_time_s=retrieval_time,
            generation_time_s=generation_time,
            total_time_s=total_time,
            method=self.name(),
            metadata={
                "total_chunks": len(chunks),
                "chunk_size": self.chunk_size,
            }
        )
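
# Example usage (illustrative only; "report.pdf" is a hypothetical local file):
#
#     baseline = NaiveChunkRAG(chunk_size=512, top_k=5)
#     result = baseline.query("What was total revenue in 2023?", Path("report.pdf"))
#     print(result.method, result.total_time_s, len(result.retrieved_chunks))
#
# Note that retrieval here is a keyword-overlap proxy and the answer string is a
# placeholder (see the comments above); the timing fields are still measured.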


class SemanticChunkRAG(BaselineRAG):
    """
    Semantic chunking baseline - splits on semantic boundaries.

    Uses sentence embeddings to detect topic shifts and
    creates more coherent chunks than naive splitting.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.7,
        top_k: int = 5,
    ):
        self.similarity_threshold = similarity_threshold
        self.top_k = top_k

    def name(self) -> str:
        return "semantic_chunk"

    def query(self, question: str, document_path: Path) -> BaselineResult:
        """Query using semantic chunking."""
        # Placeholder implementation
        start_total = time.perf_counter()

        # Would implement semantic boundary detection here
        # For now, return placeholder result

        return BaselineResult(
            answer="[Semantic baseline placeholder]",
            retrieved_chunks=[],
            retrieval_time_s=0.0,
            generation_time_s=0.0,
            total_time_s=time.perf_counter() - start_total,
            method=self.name(),
        )


# =============================================================================
# Standard Benchmark Datasets
# =============================================================================

@dataclass
class BenchmarkQuestion:
    """A question from a standard benchmark."""

    id: str
    question: str
    answer: str
    supporting_facts: list[str] = field(default_factory=list)
    context: list[str] = field(default_factory=list)
    reasoning_type: str = "single-hop"
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class BenchmarkDataset:
    """A standard benchmark dataset."""

    name: str
    description: str
    questions: list[BenchmarkQuestion]
    metrics: list[str]
    source_url: str

    def __len__(self) -> int:
        return len(self.questions)

    def sample(self, n: int) -> list[BenchmarkQuestion]:
        """Get a random sample of questions."""
        import random
        return random.sample(self.questions, min(n, len(self.questions)))


class BenchmarkLoader:
    """Load standard benchmark datasets."""

    @staticmethod
    def load_hotpotqa(
        split: Literal["train", "dev_distractor", "dev_fullwiki"] = "dev_distractor",
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load HotpotQA dataset for multi-hop QA evaluation.

        HotpotQA features:
        - Natural multi-hop questions
        - Strong supervision for supporting facts
        - Explainable reasoning chains

        Download: http://curtis.ml.cmu.edu/datasets/hotpot/
        """
        try:
            from datasets import load_dataset  # type: ignore[import-not-found]

            # Map the requested split onto the HuggingFace "hotpot_qa" configs:
            # "dev_fullwiki" -> fullwiki/validation, "dev_distractor" ->
            # distractor/validation, "train" -> distractor/train.
            config = "fullwiki" if split == "dev_fullwiki" else "distractor"
            hf_split = "train" if split == "train" else "validation"
            dataset = load_dataset("hotpot_qa", config, split=hf_split)

            questions = []
            for i, item in enumerate(dataset):
                if max_samples and i >= max_samples:
                    break

                questions.append(BenchmarkQuestion(
                    id=item["id"],
                    question=item["question"],
                    answer=item["answer"],
                    supporting_facts=item.get("supporting_facts", {}).get("title", []),
                    context=[
                        " ".join(sentences)
                        for sentences in item.get("context", {}).get("sentences", [])
                    ],
                    reasoning_type="multi-hop",
                    metadata={
                        "type": item.get("type", "unknown"),
                        "level": item.get("level", "unknown"),
                    }
                ))

            return BenchmarkDataset(
                name="HotpotQA",
                description="Multi-hop question answering with supporting facts",
                questions=questions,
                metrics=["answer_em", "answer_f1", "support_em", "support_f1"],
                source_url="https://hotpotqa.github.io/",
            )

        except ImportError:
            logger.warning("datasets library not installed, returning empty dataset")
            return BenchmarkDataset(
                name="HotpotQA",
                description="Multi-hop QA (not loaded - install 'datasets')",
                questions=[],
                metrics=["answer_em", "answer_f1", "support_em", "support_f1"],
                source_url="https://hotpotqa.github.io/",
            )

    @staticmethod
    def load_musique(
        variant: Literal["ans", "full"] = "ans",
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load MuSiQue dataset for compositional multi-hop QA.

        MuSiQue features:
        - Questions composed from single-hop questions
        - Harder disconnected reasoning required
        - 2-4 hop questions

        Download: https://github.com/StonyBrookNLP/musique
        """
        try:
            from datasets import load_dataset  # type: ignore[import-not-found]

            dataset = load_dataset(
                "dgslibiern/musique_ans" if variant == "ans" else "dgslibiern/musique_full",
                split="validation"
            )

            questions = []
            for i, item in enumerate(dataset):
                if max_samples and i >= max_samples:
                    break

                questions.append(BenchmarkQuestion(
                    id=item.get("id", str(i)),
                    question=item["question"],
                    answer=item.get("answer", ""),
                    supporting_facts=[],
                    context=item.get("paragraphs", []),
                    reasoning_type="multi-hop-compositional",
                    metadata={
                        "answerable": item.get("answerable", True),
                    }
                ))

            return BenchmarkDataset(
                name=f"MuSiQue-{variant.upper()}",
                description="Compositional multi-hop questions",
                questions=questions,
                metrics=["answer_f1", "support_f1"],
                source_url="https://github.com/StonyBrookNLP/musique",
            )

        except ImportError:
            logger.warning("datasets library not installed")
            return BenchmarkDataset(
                name=f"MuSiQue-{variant.upper()}",
                description="MuSiQue (not loaded - install 'datasets')",
                questions=[],
                metrics=["answer_f1", "support_f1"],
                source_url="https://github.com/StonyBrookNLP/musique",
            )

    @staticmethod
    def load_beir_dataset(
        dataset_name: str = "nfcorpus",
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load a BEIR benchmark dataset for retrieval evaluation.

        Available datasets:
        - msmarco, trec-covid, nfcorpus, bioasq, nq, hotpotqa
        - fiqa, arguana, webis-touche2020, cqadupstack, quora
        - dbpedia-entity, scidocs, fever, climate-fever, scifact

        See: https://github.com/beir-cellar/beir
        """
        try:
            from beir import util  # type: ignore[import-not-found]
            from beir.datasets.data_loader import GenericDataLoader  # type: ignore[import-not-found]

            data_path = util.download_and_unzip(
                f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip",
                "benchmark_data"
            )

            corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

            questions = []
            for i, (qid, query) in enumerate(queries.items()):
                if max_samples and i >= max_samples:
                    break

                relevant_docs = qrels.get(qid, {})
                context = [corpus[doc_id]["text"] for doc_id in relevant_docs if doc_id in corpus]

                questions.append(BenchmarkQuestion(
                    id=qid,
                    question=query,
                    answer="",  # BEIR is retrieval-focused, not QA
                    context=context[:5],
                    reasoning_type="retrieval",
                    metadata={"relevance_scores": relevant_docs}
                ))

            return BenchmarkDataset(
                name=f"BEIR-{dataset_name}",
                description=f"BEIR retrieval benchmark: {dataset_name}",
                questions=questions,
                metrics=["ndcg@10", "map", "recall@100", "precision@10"],
                source_url="https://github.com/beir-cellar/beir",
            )

        except ImportError:
            logger.warning("beir library not installed")
            return BenchmarkDataset(
                name=f"BEIR-{dataset_name}",
                description="BEIR (not loaded - install 'beir')",
                questions=[],
                metrics=["ndcg@10", "map", "recall@100"],
                source_url="https://github.com/beir-cellar/beir",
            )

    @staticmethod
    def load_qasper(
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load QASPER-style scientific paper QA.

        Note: Original QASPER dataset uses deprecated format.
        Using SciQ as a scientific reasoning alternative.

        For true QASPER testing, download from:
        https://allenai.org/data/qasper
        """
        try:
            from datasets import load_dataset  # type: ignore[import-not-found]

            # Use SciQ as scientific QA alternative (QASPER is deprecated)
            dataset = load_dataset("allenai/sciq", split="validation")

            questions = []
            for i, item in enumerate(dataset):
                if max_samples and i >= max_samples:
                    break

                # SciQ has question, correct_answer, support (context)
                support = item.get("support", "")
                question = item.get("question", "")
                answer = item.get("correct_answer", "")

                # Skip if no support context
                if not support:
                    continue

                questions.append(BenchmarkQuestion(
                    id=str(i),
                    question=question,
                    answer=answer,
                    supporting_facts=[],
                    context=[support],
                    reasoning_type="scientific",
                    metadata={
                        "distractor1": item.get("distractor1", ""),
                        "distractor2": item.get("distractor2", ""),
                        "distractor3": item.get("distractor3", ""),
                    }
                ))

            return BenchmarkDataset(
                name="SciQ",
                description="Scientific reasoning QA with supporting context",
                questions=questions,
                metrics=["answer_f1", "answer_em"],
                source_url="https://allenai.org/data/sciq",
            )

        except Exception as e:
            logger.warning("sciq_load_failed", error=str(e))
            return BenchmarkDataset(
                name="SciQ",
                description=f"SciQ (load failed: {str(e)[:50]})",
                questions=[],
                metrics=["answer_f1"],
                source_url="https://allenai.org/data/sciq",
            )

    @staticmethod
    def load_quality(
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load QuALITY dataset for long document QA.

        QuALITY features (ideal for RNSR):
        - Long articles (2,000-8,000 words)
        - Multiple-choice questions
        - Requires reading entire document
        - Tests long-range comprehension

        Paper: Pang et al., NAACL 2022
        URL: https://github.com/nyu-mll/quality
        """
        try:
            from datasets import load_dataset  # type: ignore[import-not-found]

            # Use emozilla/quality which is available on HuggingFace
            dataset = load_dataset("emozilla/quality", split="validation")

            questions = []
            for item in dataset:
                if max_samples and len(questions) >= max_samples:
                    break

                article = item.get("article", "")
                question = item.get("question", "")
                options = item.get("options", [])
                gold_label = item.get("answer", 0)
                is_hard = item.get("hard", False)

                # Format answer as the correct option
                answer = options[gold_label] if gold_label < len(options) else ""

                questions.append(BenchmarkQuestion(
                    id=str(len(questions)),
                    question=question,
                    answer=answer,
                    supporting_facts=[],
                    context=[article],  # Full article as context
                    reasoning_type="long-document",
                    metadata={
                        "options": options,
                        "gold_label": gold_label,
                        "is_hard": is_hard,
                        "article_length": len(article.split()),
                    }
                ))

            return BenchmarkDataset(
                name="QuALITY",
                description="Long document multiple-choice QA",
                questions=questions,
                metrics=["accuracy", "answer_em"],
                source_url="https://github.com/nyu-mll/quality",
            )

        except Exception as e:
            logger.warning("quality_load_failed", error=str(e))
            return BenchmarkDataset(
                name="QuALITY",
                description=f"QuALITY (load failed: {str(e)[:50]})",
                questions=[],
                metrics=["accuracy", "answer_em"],
                source_url="https://github.com/nyu-mll/quality",
            )

    @staticmethod
    def load_financebench(
        split: str = "train",
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load FinanceBench dataset.

        FinanceBench features:
        - Financial QA over complex PDFs
        - Requires table/chart understanding
        - Document-level retrieval
        """
        try:
            from rnsr.benchmarks.finance_bench import FinanceBenchLoader
            return FinanceBenchLoader.load(split=split, max_samples=max_samples)
        except Exception as e:
            logger.error("Failed to load FinanceBench", error=str(e))
            return BenchmarkDataset(
                name="FinanceBench",
                description="Financial QA (Failed to load)",
                questions=[],
                metrics=[],
                source_url=""
            )


    @staticmethod
    def load_narrative_qa(
        max_samples: int | None = None,
    ) -> BenchmarkDataset:
        """
        Load NarrativeQA dataset for very long document QA.

        NarrativeQA features (stress test for RNSR):
        - Full books and movie scripts
        - Very long context (10k-100k+ words)
        - Tests extreme long-range comprehension

        Paper: Kočiský et al., TACL 2018
        URL: https://github.com/deepmind/narrativeqa
        """
        try:
            from datasets import load_dataset  # type: ignore[import-not-found]

            dataset = load_dataset("narrativeqa", split="validation")

            questions = []
            for item in dataset:
                if max_samples and len(questions) >= max_samples:
                    break

                # NarrativeQA has summaries as proxy for full documents
                document = item.get("document", {})
                summary = document.get("summary", {}).get("text", "")

                question = item.get("question", {}).get("text", "")
                answers = item.get("answers", [])
                answer = answers[0].get("text", "") if answers else ""

                questions.append(BenchmarkQuestion(
                    id=item.get("document", {}).get("id", str(len(questions))),
                    question=question,
                    answer=answer,
                    supporting_facts=[],
                    context=[summary],  # Using summary as proxy
                    reasoning_type="narrative",
                    metadata={
                        "kind": document.get("kind", ""),
                        "all_answers": [a.get("text", "") for a in answers],
                    }
                ))

            return BenchmarkDataset(
                name="NarrativeQA",
                description="Very long document QA (books/scripts)",
                questions=questions,
                metrics=["answer_f1", "rouge_l"],
                source_url="https://github.com/deepmind/narrativeqa",
            )

        except ImportError:
            logger.warning("datasets library not installed")
            return BenchmarkDataset(
                name="NarrativeQA",
                description="NarrativeQA (not loaded - install 'datasets')",
                questions=[],
                metrics=["answer_f1", "rouge_l"],
                source_url="https://github.com/deepmind/narrativeqa",
            )
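
# Quick reference for the loaders above (illustrative; each call degrades to an
# empty BenchmarkDataset when the optional dependency or download is missing):
#
#     hotpot  = BenchmarkLoader.load_hotpotqa(max_samples=100)      # needs 'datasets'
#     musique = BenchmarkLoader.load_musique(variant="ans")         # needs 'datasets'
#     beir    = BenchmarkLoader.load_beir_dataset("scifact")        # needs 'beir'
#     sciq    = BenchmarkLoader.load_qasper(max_samples=100)        # SciQ stand-in
#     quality = BenchmarkLoader.load_quality(max_samples=50)        # long-document QA
#     narr    = BenchmarkLoader.load_narrative_qa(max_samples=50)   # books/scripts
#
#     for q in hotpot.sample(10):
#         print(q.reasoning_type, q.question)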


# =============================================================================
# RAGAS Metrics Integration
# =============================================================================

@dataclass
class RAGASMetrics:
    """Standard RAGAS evaluation metrics."""

    faithfulness: float = 0.0
    answer_relevancy: float = 0.0
    context_precision: float = 0.0
    context_recall: float = 0.0
    context_relevancy: float = 0.0
    answer_correctness: float = 0.0

    def overall_score(self) -> float:
        """Compute weighted overall score."""
        weights = {
            "faithfulness": 0.2,
            "answer_relevancy": 0.2,
            "context_precision": 0.15,
            "context_recall": 0.15,
            "context_relevancy": 0.15,
            "answer_correctness": 0.15,
        }

        total = 0.0
        for metric, weight in weights.items():
            total += getattr(self, metric) * weight

        return total

    def to_dict(self) -> dict[str, float]:
        return {
            "faithfulness": self.faithfulness,
            "answer_relevancy": self.answer_relevancy,
            "context_precision": self.context_precision,
            "context_recall": self.context_recall,
            "context_relevancy": self.context_relevancy,
            "answer_correctness": self.answer_correctness,
            "overall": self.overall_score(),
        }
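
# The overall_score() weights sum to 1.0 (0.2 + 0.2 + 4 * 0.15), so the overall
# score stays on the same 0-1 scale as the individual metrics. For example, a
# run with faithfulness=0.9, answer_relevancy=0.8 and the remaining four
# metrics at 0.7 scores 0.2*0.9 + 0.2*0.8 + 0.6*0.7 = 0.76.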


class RAGASEvaluator:
    """
    Evaluate RAG systems using RAGAS metrics.

    RAGAS (Retrieval Augmented Generation Assessment) provides
    standard metrics for evaluating RAG pipelines:

    - Faithfulness: Is the answer grounded in the context?
    - Answer Relevancy: Does the answer address the question?
    - Context Precision: Are retrieved contexts relevant?
    - Context Recall: Are all relevant contexts retrieved?

    See: https://github.com/explodinggradients/ragas
    """

    def __init__(
        self,
        llm_provider: str = "gemini",
        llm_model: str = "gemini-2.5-flash",
    ):
        self.llm_provider = llm_provider
        self.llm_model = llm_model

    def evaluate(
        self,
        question: str,
        answer: str,
        contexts: list[str],
        ground_truth: str | None = None,
    ) -> RAGASMetrics:
        """
        Evaluate a single RAG response using RAGAS metrics.
        """
        try:
            from ragas import evaluate  # type: ignore[import-not-found]
            from ragas.metrics import (  # type: ignore[import-not-found]
                faithfulness,
                answer_relevancy,
                context_precision,
                context_recall,
            )
            from datasets import Dataset  # type: ignore[import-not-found]

            # Prepare data
            data = {
                "question": [question],
                "answer": [answer],
                "contexts": [contexts],
            }
            if ground_truth:
                data["ground_truth"] = [ground_truth]

            dataset = Dataset.from_dict(data)

            # Run evaluation
            metrics = [faithfulness, answer_relevancy, context_precision]
            if ground_truth:
                metrics.append(context_recall)

            result = evaluate(dataset, metrics=metrics)

            return RAGASMetrics(
                faithfulness=result.get("faithfulness", 0.0),
                answer_relevancy=result.get("answer_relevancy", 0.0),
                context_precision=result.get("context_precision", 0.0),
                context_recall=result.get("context_recall", 0.0) if ground_truth else 0.0,
            )

        except ImportError:
            logger.warning("ragas library not installed, returning zero metrics")
            return RAGASMetrics()

    def evaluate_batch(
        self,
        questions: list[str],
        answers: list[str],
        contexts: list[list[str]],
        ground_truths: list[str] | None = None,
    ) -> RAGASMetrics:
        """Evaluate a batch of responses and return aggregated metrics."""
        all_metrics = []

        for i in range(len(questions)):
            gt = ground_truths[i] if ground_truths else None
            metrics = self.evaluate(
                questions[i],
                answers[i],
                contexts[i],
                gt,
            )
            all_metrics.append(metrics)

        # Aggregate
        if not all_metrics:
            return RAGASMetrics()

        return RAGASMetrics(
            faithfulness=sum(m.faithfulness for m in all_metrics) / len(all_metrics),
            answer_relevancy=sum(m.answer_relevancy for m in all_metrics) / len(all_metrics),
            context_precision=sum(m.context_precision for m in all_metrics) / len(all_metrics),
            context_recall=sum(m.context_recall for m in all_metrics) / len(all_metrics),
            context_relevancy=sum(m.context_relevancy for m in all_metrics) / len(all_metrics),
            answer_correctness=sum(m.answer_correctness for m in all_metrics) / len(all_metrics),
        )
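
# Example (illustrative sketch; requires the optional 'ragas' and 'datasets'
# packages - if they are missing, evaluate() logs a warning and returns an
# all-zero RAGASMetrics):
#
#     evaluator = RAGASEvaluator()
#     scores = evaluator.evaluate(
#         question="Who founded the company?",
#         answer="It was founded by Jane Doe in 1998.",
#         contexts=["Jane Doe founded the company in 1998 in Austin."],
#         ground_truth="Jane Doe",
#     )
#     print(scores.to_dict())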


# =============================================================================
# Multi-Hop Reasoning Metrics (for HotpotQA/MuSiQue)
# =============================================================================

@dataclass
class MultiHopMetrics:
    """Metrics for multi-hop reasoning evaluation."""

    answer_em: float = 0.0   # Exact match
    answer_f1: float = 0.0   # Token-level F1
    support_em: float = 0.0  # Supporting fact EM
    support_f1: float = 0.0  # Supporting fact F1
    joint_em: float = 0.0    # Joint answer + support EM
    joint_f1: float = 0.0    # Joint answer + support F1

    def to_dict(self) -> dict[str, float]:
        return {
            "answer_em": self.answer_em,
            "answer_f1": self.answer_f1,
            "support_em": self.support_em,
            "support_f1": self.support_f1,
            "joint_em": self.joint_em,
            "joint_f1": self.joint_f1,
        }


def normalize_answer(s: str) -> str:
    """Normalize answer for comparison."""
    import re
    import string

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
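
# Worked example: normalize_answer("The Eiffel Tower!") lowercases the text,
# strips punctuation, drops the article "the", and collapses whitespace,
# yielding "eiffel tower" - so surface differences such as casing or a leading
# article no longer break exact-match comparisons.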


def compute_em(prediction: str, ground_truth: str) -> float:
    """Compute exact match score."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def compute_f1(prediction: str, ground_truth: str) -> float:
    """Compute token-level F1 score (SQuAD/HotpotQA-style, counting duplicates)."""
    from collections import Counter

    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()

    # Multiset overlap: repeated tokens are only credited as often as they
    # appear in both strings.
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0.0

    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)

    return 2 * precision * recall / (precision + recall)
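
# Worked example: compute_em("the cat sat", "A cat sat") returns 1.0 because
# both sides normalize to "cat sat". compute_f1("the cat sat", "a cat sat down")
# gives 0.8: the token overlap is 2, precision = 2/2, recall = 2/3, and
# F1 = 2 * (1 * 2/3) / (1 + 2/3).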


def evaluate_multihop(
    predictions: list[dict[str, Any]],
    ground_truths: list[BenchmarkQuestion],
) -> MultiHopMetrics:
    """
    Evaluate multi-hop QA predictions against ground truth.

    Args:
        predictions: List of {"answer": str, "supporting_facts": list[str]}
        ground_truths: List of BenchmarkQuestion with answers and supporting facts
    """
    answer_ems = []
    answer_f1s = []
    support_ems = []
    support_f1s = []

    for pred, gold in zip(predictions, ground_truths):
        # Answer metrics
        answer_ems.append(compute_em(pred.get("answer", ""), gold.answer))
        answer_f1s.append(compute_f1(pred.get("answer", ""), gold.answer))

        # Supporting facts metrics
        pred_facts = set(pred.get("supporting_facts", []))
        gold_facts = set(gold.supporting_facts)

        if gold_facts:
            support_em = float(pred_facts == gold_facts)

            common = pred_facts & gold_facts
            prec = len(common) / len(pred_facts) if pred_facts else 0
            rec = len(common) / len(gold_facts) if gold_facts else 0
            support_f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0

            support_ems.append(support_em)
            support_f1s.append(support_f1)

    n = len(predictions)

    return MultiHopMetrics(
        answer_em=sum(answer_ems) / n if n else 0,
        answer_f1=sum(answer_f1s) / n if n else 0,
        support_em=sum(support_ems) / len(support_ems) if support_ems else 0,
        support_f1=sum(support_f1s) / len(support_f1s) if support_f1s else 0,
        joint_em=(sum(answer_ems) / n) * (sum(support_ems) / len(support_ems)) if n and support_ems else 0,
        joint_f1=(sum(answer_f1s) / n) * (sum(support_f1s) / len(support_f1s)) if n and support_f1s else 0,
    )
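
# Example (illustrative): scoring predictions against HotpotQA-style gold
# questions - each prediction is a plain dict with an "answer" key and an
# optional "supporting_facts" list (assumes the 'datasets' extra is installed):
#
#     gold = BenchmarkLoader.load_hotpotqa(max_samples=2).questions
#     preds = [{"answer": "Paris", "supporting_facts": []} for _ in gold]
#     print(evaluate_multihop(preds, gold).to_dict())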


# =============================================================================
# RNSR vs Baseline Comparison
# =============================================================================

@dataclass
class ComparisonResult:
    """Result of comparing RNSR against a baseline."""

    dataset_name: str
    rnsr_metrics: dict[str, float]
    baseline_metrics: dict[str, float]
    baseline_name: str
    improvement: dict[str, float]           # RNSR - baseline for each metric
    relative_improvement: dict[str, float]  # (RNSR - baseline) / baseline

    def summary(self) -> str:
        """Generate human-readable summary."""
        lines = [
            f"\n{'='*60}",
            f"Comparison: RNSR vs {self.baseline_name}",
            f"Dataset: {self.dataset_name}",
            f"{'='*60}",
            "",
            f"{'Metric':<25} {'RNSR':>10} {'Baseline':>10} {'Δ':>10} {'%':>10}",
            "-" * 65,
        ]

        for metric in self.rnsr_metrics:
            rnsr_val = self.rnsr_metrics.get(metric, 0)
            base_val = self.baseline_metrics.get(metric, 0)
            delta = self.improvement.get(metric, 0)
            rel = self.relative_improvement.get(metric, 0) * 100

            lines.append(f"{metric:<25} {rnsr_val:>10.3f} {base_val:>10.3f} {delta:>+10.3f} {rel:>+9.1f}%")

        lines.append("=" * 65)
        return "\n".join(lines)


def compare_rnsr_vs_baseline(
    rnsr_results: dict[str, float],
    baseline_results: dict[str, float],
    dataset_name: str,
    baseline_name: str,
) -> ComparisonResult:
    """Compare RNSR results against a baseline."""
    improvement = {}
    relative_improvement = {}

    for metric in rnsr_results:
        rnsr_val = rnsr_results.get(metric, 0)
        base_val = baseline_results.get(metric, 0)

        improvement[metric] = rnsr_val - base_val
        if base_val > 0:
            relative_improvement[metric] = (rnsr_val - base_val) / base_val
        else:
            relative_improvement[metric] = 0.0

    return ComparisonResult(
        dataset_name=dataset_name,
        rnsr_metrics=rnsr_results,
        baseline_metrics=baseline_results,
        baseline_name=baseline_name,
        improvement=improvement,
        relative_improvement=relative_improvement,
    )
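
# Example (illustrative numbers only, not measured results):
#
#     comparison = compare_rnsr_vs_baseline(
#         rnsr_results={"answer_f1": 0.72, "answer_em": 0.55},
#         baseline_results={"answer_f1": 0.60, "answer_em": 0.41},
#         dataset_name="HotpotQA",
#         baseline_name="naive_chunk_512",
#     )
#     print(comparison.summary())   # per-metric delta and relative improvement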


# =============================================================================
# Exports
# =============================================================================

__all__ = [
    # Baselines
    "BaselineRAG",
    "BaselineResult",
    "NaiveChunkRAG",
    "SemanticChunkRAG",

    # Benchmarks
    "BenchmarkQuestion",
    "BenchmarkDataset",
    "BenchmarkLoader",

    # RAGAS
    "RAGASMetrics",
    "RAGASEvaluator",

    # Multi-hop
    "MultiHopMetrics",
    "evaluate_multihop",
    "compute_em",
    "compute_f1",

    # Comparison
    "ComparisonResult",
    "compare_rnsr_vs_baseline",
]
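

# ---------------------------------------------------------------------------
# Illustrative smoke test (a sketch, not part of the published API): load a
# small HotpotQA sample and score echo "predictions" with the multi-hop
# metrics defined above. The download needs the optional 'datasets'
# dependency; without it the loader returns an empty dataset and this demo
# simply reports zero questions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo = BenchmarkLoader.load_hotpotqa(max_samples=5)
    print(f"Loaded {len(demo)} questions from {demo.name}")

    # Predictions that echo the gold answers and supporting facts should score
    # EM/F1 of 1.0 whenever any questions were loaded.
    preds = [
        {"answer": q.answer, "supporting_facts": q.supporting_facts}
        for q in demo.questions
    ]
    if preds:
        print(evaluate_multihop(preds, demo.questions).to_dict())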