rnsr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rnsr/__init__.py +118 -0
  2. rnsr/__main__.py +242 -0
  3. rnsr/agent/__init__.py +218 -0
  4. rnsr/agent/cross_doc_navigator.py +767 -0
  5. rnsr/agent/graph.py +1557 -0
  6. rnsr/agent/llm_cache.py +575 -0
  7. rnsr/agent/navigator_api.py +497 -0
  8. rnsr/agent/provenance.py +772 -0
  9. rnsr/agent/query_clarifier.py +617 -0
  10. rnsr/agent/reasoning_memory.py +736 -0
  11. rnsr/agent/repl_env.py +709 -0
  12. rnsr/agent/rlm_navigator.py +2108 -0
  13. rnsr/agent/self_reflection.py +602 -0
  14. rnsr/agent/variable_store.py +308 -0
  15. rnsr/benchmarks/__init__.py +118 -0
  16. rnsr/benchmarks/comprehensive_benchmark.py +733 -0
  17. rnsr/benchmarks/evaluation_suite.py +1210 -0
  18. rnsr/benchmarks/finance_bench.py +147 -0
  19. rnsr/benchmarks/pdf_merger.py +178 -0
  20. rnsr/benchmarks/performance.py +321 -0
  21. rnsr/benchmarks/quality.py +321 -0
  22. rnsr/benchmarks/runner.py +298 -0
  23. rnsr/benchmarks/standard_benchmarks.py +995 -0
  24. rnsr/client.py +560 -0
  25. rnsr/document_store.py +394 -0
  26. rnsr/exceptions.py +74 -0
  27. rnsr/extraction/__init__.py +172 -0
  28. rnsr/extraction/candidate_extractor.py +357 -0
  29. rnsr/extraction/entity_extractor.py +581 -0
  30. rnsr/extraction/entity_linker.py +825 -0
  31. rnsr/extraction/grounded_extractor.py +722 -0
  32. rnsr/extraction/learned_types.py +599 -0
  33. rnsr/extraction/models.py +232 -0
  34. rnsr/extraction/relationship_extractor.py +600 -0
  35. rnsr/extraction/relationship_patterns.py +511 -0
  36. rnsr/extraction/relationship_validator.py +392 -0
  37. rnsr/extraction/rlm_extractor.py +589 -0
  38. rnsr/extraction/rlm_unified_extractor.py +990 -0
  39. rnsr/extraction/tot_validator.py +610 -0
  40. rnsr/extraction/unified_extractor.py +342 -0
  41. rnsr/indexing/__init__.py +60 -0
  42. rnsr/indexing/knowledge_graph.py +1128 -0
  43. rnsr/indexing/kv_store.py +313 -0
  44. rnsr/indexing/persistence.py +323 -0
  45. rnsr/indexing/semantic_retriever.py +237 -0
  46. rnsr/indexing/semantic_search.py +320 -0
  47. rnsr/indexing/skeleton_index.py +395 -0
  48. rnsr/ingestion/__init__.py +161 -0
  49. rnsr/ingestion/chart_parser.py +569 -0
  50. rnsr/ingestion/document_boundary.py +662 -0
  51. rnsr/ingestion/font_histogram.py +334 -0
  52. rnsr/ingestion/header_classifier.py +595 -0
  53. rnsr/ingestion/hierarchical_cluster.py +515 -0
  54. rnsr/ingestion/layout_detector.py +356 -0
  55. rnsr/ingestion/layout_model.py +379 -0
  56. rnsr/ingestion/ocr_fallback.py +177 -0
  57. rnsr/ingestion/pipeline.py +936 -0
  58. rnsr/ingestion/semantic_fallback.py +417 -0
  59. rnsr/ingestion/table_parser.py +799 -0
  60. rnsr/ingestion/text_builder.py +460 -0
  61. rnsr/ingestion/tree_builder.py +402 -0
  62. rnsr/ingestion/vision_retrieval.py +965 -0
  63. rnsr/ingestion/xy_cut.py +555 -0
  64. rnsr/llm.py +733 -0
  65. rnsr/models.py +167 -0
  66. rnsr/py.typed +2 -0
  67. rnsr-0.1.0.dist-info/METADATA +592 -0
  68. rnsr-0.1.0.dist-info/RECORD +72 -0
  69. rnsr-0.1.0.dist-info/WHEEL +5 -0
  70. rnsr-0.1.0.dist-info/entry_points.txt +2 -0
  71. rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
  72. rnsr-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,995 @@
+ """
+ Standard RAG Benchmarks for RNSR Evaluation
+
+ This module provides integration with established RAG and retrieval benchmarks
+ to validate RNSR's claims of improved document parsing and traversal.
+
+ Key Benchmarks:
+ 1. RAGAS - Standard RAG evaluation metrics (faithfulness, relevance, etc.)
+ 2. BEIR - Information retrieval benchmark (17+ datasets)
+ 3. HotpotQA - Multi-hop question answering
+ 4. MuSiQue - Multi-hop questions via single-hop composition
+
+ These benchmarks help demonstrate RNSR's advantages:
+ - Hierarchical tree traversal vs flat chunk retrieval
+ - Multi-hop reasoning capabilities
+ - Context preservation in complex documents
+ """
+
+ from __future__ import annotations
+
+ import json
+ import time
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Literal
+
+ import structlog
+
+ logger = structlog.get_logger(__name__)
+
+
+ # =============================================================================
+ # Baseline RAG Systems for Comparison
+ # =============================================================================
+
+ @dataclass
+ class BaselineResult:
+     """Result from a baseline RAG system."""
+
+     answer: str
+     retrieved_chunks: list[str]
+     retrieval_time_s: float
+     generation_time_s: float
+     total_time_s: float
+     method: str
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+
+ class BaselineRAG(ABC):
+     """Abstract base class for baseline RAG implementations."""
+
+     @abstractmethod
+     def query(self, question: str, document_path: Path) -> BaselineResult:
+         """Answer a question using the baseline method."""
+         pass
+
+     @abstractmethod
+     def name(self) -> str:
+         """Return the name of this baseline."""
+         pass
+
+
+ class NaiveChunkRAG(BaselineRAG):
+     """
+     Naive chunking baseline - the standard RAG approach.
+
+     Chunks document into fixed-size segments, embeds them,
+     retrieves top-k by similarity, and generates answer.
+     """
+
+     def __init__(
+         self,
+         chunk_size: int = 512,
+         chunk_overlap: int = 50,
+         top_k: int = 5,
+         embedding_model: str = "text-embedding-3-small",
+     ):
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+         self.top_k = top_k
+         self.embedding_model = embedding_model
+
+     def name(self) -> str:
+         return f"naive_chunk_{self.chunk_size}"
+
+     def query(self, question: str, document_path: Path) -> BaselineResult:
+         """Query using naive chunking."""
+         import fitz  # type: ignore[import-not-found]  # PyMuPDF
+
+         start_total = time.perf_counter()
+
+         # Extract text
+         doc = fitz.open(document_path)
+         full_text = ""
+         for page in doc:
+             text = page.get_text()
+             if isinstance(text, str):
+                 full_text += text
+         doc.close()
+
+         # Naive chunking
+         chunks = []
+         for i in range(0, len(full_text), self.chunk_size - self.chunk_overlap):
+             chunk = full_text[i:i + self.chunk_size]
+             if chunk.strip():
+                 chunks.append(chunk)
+
+         # Embed and retrieve (simplified - would use actual embeddings)
+         start_retrieval = time.perf_counter()
+
+         # For now, use simple keyword matching as proxy
+         # In production, use actual embeddings
+         question_words = set(question.lower().split())
+         scored_chunks = []
+         for chunk in chunks:
+             chunk_words = set(chunk.lower().split())
+             score = len(question_words & chunk_words) / max(len(question_words), 1)
+             scored_chunks.append((score, chunk))
+
+         scored_chunks.sort(reverse=True, key=lambda x: x[0])
+         retrieved = [c for _, c in scored_chunks[:self.top_k]]
+
+         retrieval_time = time.perf_counter() - start_retrieval
+
+         # Generate answer (placeholder - would use LLM)
+         start_generation = time.perf_counter()
+         context = "\n\n".join(retrieved)
+         answer = f"[Baseline answer based on {len(retrieved)} chunks]"
+         generation_time = time.perf_counter() - start_generation
+
+         total_time = time.perf_counter() - start_total
+
+         return BaselineResult(
+             answer=answer,
+             retrieved_chunks=retrieved,
+             retrieval_time_s=retrieval_time,
+             generation_time_s=generation_time,
+             total_time_s=total_time,
+             method=self.name(),
+             metadata={
+                 "total_chunks": len(chunks),
+                 "chunk_size": self.chunk_size,
+             }
+         )
+
+
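# Hypothetical usage sketch (illustrative only, not part of the packaged file):
# run the keyword-proxy baseline over a local PDF and inspect the timing split.
# "report.pdf" is a placeholder path; PyMuPDF (fitz) must be installed.
def _example_naive_chunk_baseline() -> None:
    baseline = NaiveChunkRAG(chunk_size=512, chunk_overlap=50, top_k=5)
    result = baseline.query("What was total revenue in 2023?", Path("report.pdf"))
    # BaselineResult exposes the retrieved chunks plus a per-stage timing split.
    print(result.method, f"{result.total_time_s:.2f}s")
    print("chunks retrieved:", len(result.retrieved_chunks), "of", result.metadata["total_chunks"])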
+ class SemanticChunkRAG(BaselineRAG):
+     """
+     Semantic chunking baseline - splits on semantic boundaries.
+
+     Uses sentence embeddings to detect topic shifts and
+     creates more coherent chunks than naive splitting.
+     """
+
+     def __init__(
+         self,
+         similarity_threshold: float = 0.7,
+         top_k: int = 5,
+     ):
+         self.similarity_threshold = similarity_threshold
+         self.top_k = top_k
+
+     def name(self) -> str:
+         return "semantic_chunk"
+
+     def query(self, question: str, document_path: Path) -> BaselineResult:
+         """Query using semantic chunking."""
+         # Placeholder implementation
+         start_total = time.perf_counter()
+
+         # Would implement semantic boundary detection here
+         # For now, return placeholder result
+
+         return BaselineResult(
+             answer="[Semantic baseline placeholder]",
+             retrieved_chunks=[],
+             retrieval_time_s=0.0,
+             generation_time_s=0.0,
+             total_time_s=time.perf_counter() - start_total,
+             method=self.name(),
+         )
+
+
+ # =============================================================================
+ # Standard Benchmark Datasets
+ # =============================================================================
+
+ @dataclass
+ class BenchmarkQuestion:
+     """A question from a standard benchmark."""
+
+     id: str
+     question: str
+     answer: str
+     supporting_facts: list[str] = field(default_factory=list)
+     context: list[str] = field(default_factory=list)
+     reasoning_type: str = "single-hop"
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class BenchmarkDataset:
+     """A standard benchmark dataset."""
+
+     name: str
+     description: str
+     questions: list[BenchmarkQuestion]
+     metrics: list[str]
+     source_url: str
+
+     def __len__(self) -> int:
+         return len(self.questions)
+
+     def sample(self, n: int) -> list[BenchmarkQuestion]:
+         """Get a random sample of questions."""
+         import random
+         return random.sample(self.questions, min(n, len(self.questions)))
+
+
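# Hypothetical construction sketch (illustrative only, not part of the packaged
# file): the loaders below return this same shape, but a hand-rolled dataset
# exercises BenchmarkQuestion/BenchmarkDataset the same way for quick local tests.
def _example_manual_dataset() -> BenchmarkDataset:
    questions = [
        BenchmarkQuestion(
            id="demo-1",
            question="Which section defines the termination clause?",
            answer="Section 9.2",
            reasoning_type="single-hop",
        )
    ]
    dataset = BenchmarkDataset(
        name="demo",
        description="Hand-built smoke-test dataset",
        questions=questions,
        metrics=["answer_em", "answer_f1"],
        source_url="",
    )
    # sample() never returns more questions than exist.
    assert len(dataset) == 1 and dataset.sample(5)[0].id == "demo-1"
    return dataset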
+ class BenchmarkLoader:
+     """Load standard benchmark datasets."""
+
+     @staticmethod
+     def load_hotpotqa(
+         split: Literal["train", "dev_distractor", "dev_fullwiki"] = "dev_distractor",
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load HotpotQA dataset for multi-hop QA evaluation.
+
+         HotpotQA features:
+         - Natural multi-hop questions
+         - Strong supervision for supporting facts
+         - Explainable reasoning chains
+
+         Download: http://curtis.ml.cmu.edu/datasets/hotpot/
+         """
+         try:
+             from datasets import load_dataset  # type: ignore[import-not-found]
+
+             dataset = load_dataset("hotpot_qa", "distractor", split="validation")
+
+             questions = []
+             for i, item in enumerate(dataset):
+                 if max_samples and i >= max_samples:
+                     break
+
+                 questions.append(BenchmarkQuestion(
+                     id=item["id"],
+                     question=item["question"],
+                     answer=item["answer"],
+                     supporting_facts=item.get("supporting_facts", {}).get("title", []),
+                     context=[
+                         " ".join(sentences)
+                         for sentences in item.get("context", {}).get("sentences", [])
+                     ],
+                     reasoning_type="multi-hop",
+                     metadata={
+                         "type": item.get("type", "unknown"),
+                         "level": item.get("level", "unknown"),
+                     }
+                 ))
+
+             return BenchmarkDataset(
+                 name="HotpotQA",
+                 description="Multi-hop question answering with supporting facts",
+                 questions=questions,
+                 metrics=["answer_em", "answer_f1", "support_em", "support_f1"],
+                 source_url="https://hotpotqa.github.io/",
+             )
+
+         except ImportError:
+             logger.warning("datasets library not installed, returning empty dataset")
+             return BenchmarkDataset(
+                 name="HotpotQA",
+                 description="Multi-hop QA (not loaded - install 'datasets')",
+                 questions=[],
+                 metrics=["answer_em", "answer_f1", "support_em", "support_f1"],
+                 source_url="https://hotpotqa.github.io/",
+             )
+
+     @staticmethod
+     def load_musique(
+         variant: Literal["ans", "full"] = "ans",
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load MuSiQue dataset for compositional multi-hop QA.
+
+         MuSiQue features:
+         - Questions composed from single-hop questions
+         - Harder: built to resist disconnected-reasoning shortcuts
+         - 2-4 hop questions
+
+         Download: https://github.com/StonyBrookNLP/musique
+         """
+         try:
+             from datasets import load_dataset  # type: ignore[import-not-found]
+
+             dataset = load_dataset(
+                 "dgslibiern/musique_ans" if variant == "ans" else "dgslibiern/musique_full",
+                 split="validation"
+             )
+
+             questions = []
+             for i, item in enumerate(dataset):
+                 if max_samples and i >= max_samples:
+                     break
+
+                 questions.append(BenchmarkQuestion(
+                     id=item.get("id", str(i)),
+                     question=item["question"],
+                     answer=item.get("answer", ""),
+                     supporting_facts=[],
+                     context=item.get("paragraphs", []),
+                     reasoning_type="multi-hop-compositional",
+                     metadata={
+                         "answerable": item.get("answerable", True),
+                     }
+                 ))
+
+             return BenchmarkDataset(
+                 name=f"MuSiQue-{variant.upper()}",
+                 description="Compositional multi-hop questions",
+                 questions=questions,
+                 metrics=["answer_f1", "support_f1"],
+                 source_url="https://github.com/StonyBrookNLP/musique",
+             )
+
+         except ImportError:
+             logger.warning("datasets library not installed")
+             return BenchmarkDataset(
+                 name=f"MuSiQue-{variant.upper()}",
+                 description="MuSiQue (not loaded - install 'datasets')",
+                 questions=[],
+                 metrics=["answer_f1", "support_f1"],
+                 source_url="https://github.com/StonyBrookNLP/musique",
+             )
+
+     @staticmethod
+     def load_beir_dataset(
+         dataset_name: str = "nfcorpus",
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load a BEIR benchmark dataset for retrieval evaluation.
+
+         Available datasets:
+         - msmarco, trec-covid, nfcorpus, bioasq, nq, hotpotqa
+         - fiqa, arguana, webis-touche2020, cqadupstack, quora
+         - dbpedia-entity, scidocs, fever, climate-fever, scifact
+
+         See: https://github.com/beir-cellar/beir
+         """
+         try:
+             from beir import util  # type: ignore[import-not-found]
+             from beir.datasets.data_loader import GenericDataLoader  # type: ignore[import-not-found]
+
+             data_path = util.download_and_unzip(
+                 f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip",
+                 "benchmark_data"
+             )
+
+             corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
+
+             questions = []
+             for i, (qid, query) in enumerate(queries.items()):
+                 if max_samples and i >= max_samples:
+                     break
+
+                 relevant_docs = qrels.get(qid, {})
+                 context = [corpus[doc_id]["text"] for doc_id in relevant_docs if doc_id in corpus]
+
+                 questions.append(BenchmarkQuestion(
+                     id=qid,
+                     question=query,
+                     answer="",  # BEIR is retrieval-focused, not QA
+                     context=context[:5],
+                     reasoning_type="retrieval",
+                     metadata={"relevance_scores": relevant_docs}
+                 ))
+
+             return BenchmarkDataset(
+                 name=f"BEIR-{dataset_name}",
+                 description=f"BEIR retrieval benchmark: {dataset_name}",
+                 questions=questions,
+                 metrics=["ndcg@10", "map", "recall@100", "precision@10"],
+                 source_url="https://github.com/beir-cellar/beir",
+             )
+
+         except ImportError:
+             logger.warning("beir library not installed")
+             return BenchmarkDataset(
+                 name=f"BEIR-{dataset_name}",
+                 description="BEIR (not loaded - install 'beir')",
+                 questions=[],
+                 metrics=["ndcg@10", "map", "recall@100"],
+                 source_url="https://github.com/beir-cellar/beir",
+             )
+
+     @staticmethod
+     def load_qasper(
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load QASPER-style scientific paper QA.
+
+         Note: Original QASPER dataset uses deprecated format.
+         Using SciQ as a scientific reasoning alternative.
+
+         For true QASPER testing, download from:
+         https://allenai.org/data/qasper
+         """
+         try:
+             from datasets import load_dataset  # type: ignore[import-not-found]
+
+             # Use SciQ as scientific QA alternative (QASPER is deprecated)
+             dataset = load_dataset("allenai/sciq", split="validation")
+
+             questions = []
+             for i, item in enumerate(dataset):
+                 if max_samples and i >= max_samples:
+                     break
+
+                 # SciQ has question, correct_answer, support (context)
+                 support = item.get("support", "")
+                 question = item.get("question", "")
+                 answer = item.get("correct_answer", "")
+
+                 # Skip if no support context
+                 if not support:
+                     continue
+
+                 questions.append(BenchmarkQuestion(
+                     id=str(i),
+                     question=question,
+                     answer=answer,
+                     supporting_facts=[],
+                     context=[support],
+                     reasoning_type="scientific",
+                     metadata={
+                         "distractor1": item.get("distractor1", ""),
+                         "distractor2": item.get("distractor2", ""),
+                         "distractor3": item.get("distractor3", ""),
+                     }
+                 ))
+
+             return BenchmarkDataset(
+                 name="SciQ",
+                 description="Scientific reasoning QA with supporting context",
+                 questions=questions,
+                 metrics=["answer_f1", "answer_em"],
+                 source_url="https://allenai.org/data/sciq",
+             )
+
+         except Exception as e:
+             logger.warning("sciq_load_failed", error=str(e))
+             return BenchmarkDataset(
+                 name="SciQ",
+                 description=f"SciQ (load failed: {str(e)[:50]})",
+                 questions=[],
+                 metrics=["answer_f1"],
+                 source_url="https://allenai.org/data/sciq",
+             )
+
+     @staticmethod
+     def load_quality(
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load QuALITY dataset for long document QA.
+
+         QuALITY features (ideal for RNSR):
+         - Long articles (2,000-8,000 words)
+         - Multiple-choice questions
+         - Requires reading entire document
+         - Tests long-range comprehension
+
+         Paper: Pang et al., NAACL 2022
+         URL: https://github.com/nyu-mll/quality
+         """
+         try:
+             from datasets import load_dataset  # type: ignore[import-not-found]
+
+             # Use emozilla/quality which is available on HuggingFace
+             dataset = load_dataset("emozilla/quality", split="validation")
+
+             questions = []
+             for item in dataset:
+                 if max_samples and len(questions) >= max_samples:
+                     break
+
+                 article = item.get("article", "")
+                 question = item.get("question", "")
+                 options = item.get("options", [])
+                 gold_label = item.get("answer", 0)
+                 is_hard = item.get("hard", False)
+
+                 # Format answer as the correct option
+                 answer = options[gold_label] if gold_label < len(options) else ""
+
+                 questions.append(BenchmarkQuestion(
+                     id=str(len(questions)),
+                     question=question,
+                     answer=answer,
+                     supporting_facts=[],
+                     context=[article],  # Full article as context
+                     reasoning_type="long-document",
+                     metadata={
+                         "options": options,
+                         "gold_label": gold_label,
+                         "is_hard": is_hard,
+                         "article_length": len(article.split()),
+                     }
+                 ))
+
+             return BenchmarkDataset(
+                 name="QuALITY",
+                 description="Long document multiple-choice QA",
+                 questions=questions,
+                 metrics=["accuracy", "answer_em"],
+                 source_url="https://github.com/nyu-mll/quality",
+             )
+
+         except Exception as e:
+             logger.warning("quality_load_failed", error=str(e))
+             return BenchmarkDataset(
+                 name="QuALITY",
+                 description=f"QuALITY (load failed: {str(e)[:50]})",
+                 questions=[],
+                 metrics=["accuracy", "answer_em"],
+                 source_url="https://github.com/nyu-mll/quality",
+             )
+
+     @staticmethod
+     def load_financebench(
+         split: str = "train",
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load FinanceBench dataset.
+
+         FinanceBench features:
+         - Financial QA over complex PDFs
+         - Requires table/chart understanding
+         - Document-level retrieval
+         """
+         try:
+             from rnsr.benchmarks.finance_bench import FinanceBenchLoader
+             return FinanceBenchLoader.load(split=split, max_samples=max_samples)
+         except Exception as e:
+             logger.error("Failed to load FinanceBench", error=str(e))
+             return BenchmarkDataset(
+                 name="FinanceBench",
+                 description="Financial QA (Failed to load)",
+                 questions=[],
+                 metrics=[],
+                 source_url=""
+             )
+
+
+     @staticmethod
+     def load_narrative_qa(
+         max_samples: int | None = None,
+     ) -> BenchmarkDataset:
+         """
+         Load NarrativeQA dataset for very long document QA.
+
+         NarrativeQA features (stress test for RNSR):
+         - Full books and movie scripts
+         - Very long context (10k-100k+ words)
+         - Tests extreme long-range comprehension
+
+         Paper: Kočiský et al., TACL 2018
+         URL: https://github.com/deepmind/narrativeqa
+         """
+         try:
+             from datasets import load_dataset  # type: ignore[import-not-found]
+
+             dataset = load_dataset("narrativeqa", split="validation")
+
+             questions = []
+             for item in dataset:
+                 if max_samples and len(questions) >= max_samples:
+                     break
+
+                 # NarrativeQA has summaries as proxy for full documents
+                 document = item.get("document", {})
+                 summary = document.get("summary", {}).get("text", "")
+
+                 question = item.get("question", {}).get("text", "")
+                 answers = item.get("answers", [])
+                 answer = answers[0].get("text", "") if answers else ""
+
+                 questions.append(BenchmarkQuestion(
+                     id=item.get("document", {}).get("id", str(len(questions))),
+                     question=question,
+                     answer=answer,
+                     supporting_facts=[],
+                     context=[summary],  # Using summary as proxy
+                     reasoning_type="narrative",
+                     metadata={
+                         "kind": document.get("kind", ""),
+                         "all_answers": [a.get("text", "") for a in answers],
+                     }
+                 ))
+
+             return BenchmarkDataset(
+                 name="NarrativeQA",
+                 description="Very long document QA (books/scripts)",
+                 questions=questions,
+                 metrics=["answer_f1", "rouge_l"],
+                 source_url="https://github.com/deepmind/narrativeqa",
+             )
+
+         except ImportError:
+             logger.warning("datasets library not installed")
+             return BenchmarkDataset(
+                 name="NarrativeQA",
+                 description="NarrativeQA (not loaded - install 'datasets')",
+                 questions=[],
+                 metrics=["answer_f1", "rouge_l"],
+                 source_url="https://github.com/deepmind/narrativeqa",
+             )
+
+
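# Hypothetical usage sketches (illustrative only, not part of the packaged file):
# every loader above returns the same BenchmarkDataset shape, with loader-specific
# details tucked into BenchmarkQuestion.metadata. `retrieve` and `pick_option` are
# assumed stand-ins for a retriever returning document ids and a multiple-choice
# model; a loader whose optional dependency is missing yields zero questions.
def _example_hotpotqa_sample() -> None:
    dataset = BenchmarkLoader.load_hotpotqa(max_samples=50)
    print(dataset.name, len(dataset), dataset.metrics)
    for q in dataset.sample(3):
        print(q.reasoning_type, q.question, "->", q.answer)


def _example_beir_hit_rate(retrieve, k: int = 10) -> float:
    # BEIR items carry graded relevance judgements in metadata["relevance_scores"],
    # so a retriever can be sanity-checked with a plain hit rate before running the
    # full ndcg@10 / recall@100 evaluation.
    dataset = BenchmarkLoader.load_beir_dataset("nfcorpus", max_samples=100)
    if not dataset.questions:
        return 0.0
    hits = 0
    for q in dataset.questions:
        relevant = set(q.metadata.get("relevance_scores", {}))
        retrieved = set(retrieve(q.question)[:k])
        hits += bool(relevant & retrieved)
    return hits / len(dataset.questions)


def _example_quality_accuracy(pick_option) -> float:
    # QuALITY keeps its options and gold label in metadata, so multiple-choice
    # accuracy reduces to comparing the chosen index against metadata["gold_label"].
    dataset = BenchmarkLoader.load_quality(max_samples=50)
    if not dataset.questions:
        return 0.0
    correct = sum(
        int(pick_option(q.question, q.context[0], q.metadata["options"]) == q.metadata["gold_label"])
        for q in dataset.questions
    )
    return correct / len(dataset.questions)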
+ # =============================================================================
+ # RAGAS Metrics Integration
+ # =============================================================================
+
+ @dataclass
+ class RAGASMetrics:
+     """Standard RAGAS evaluation metrics."""
+
+     faithfulness: float = 0.0
+     answer_relevancy: float = 0.0
+     context_precision: float = 0.0
+     context_recall: float = 0.0
+     context_relevancy: float = 0.0
+     answer_correctness: float = 0.0
+
+     def overall_score(self) -> float:
+         """Compute weighted overall score."""
+         weights = {
+             "faithfulness": 0.2,
+             "answer_relevancy": 0.2,
+             "context_precision": 0.15,
+             "context_recall": 0.15,
+             "context_relevancy": 0.15,
+             "answer_correctness": 0.15,
+         }
+
+         total = 0.0
+         for metric, weight in weights.items():
+             total += getattr(self, metric) * weight
+
+         return total
+
+     def to_dict(self) -> dict[str, float]:
+         return {
+             "faithfulness": self.faithfulness,
+             "answer_relevancy": self.answer_relevancy,
+             "context_precision": self.context_precision,
+             "context_recall": self.context_recall,
+             "context_relevancy": self.context_relevancy,
+             "answer_correctness": self.answer_correctness,
+             "overall": self.overall_score(),
+         }
+
+
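# Hypothetical worked example (illustrative only, not part of the packaged file):
# with the weights above, a run scoring 0.9 faithfulness, 0.8 answer relevancy,
# and 0.7 on the remaining four metrics comes out at
# 0.9*0.2 + 0.8*0.2 + 0.7*(0.15*4) = 0.76.
def _example_overall_score() -> float:
    metrics = RAGASMetrics(
        faithfulness=0.9,
        answer_relevancy=0.8,
        context_precision=0.7,
        context_recall=0.7,
        context_relevancy=0.7,
        answer_correctness=0.7,
    )
    return metrics.overall_score()  # ~0.76, also surfaced as to_dict()["overall"]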
+ class RAGASEvaluator:
+     """
+     Evaluate RAG systems using RAGAS metrics.
+
+     RAGAS (Retrieval Augmented Generation Assessment) provides
+     standard metrics for evaluating RAG pipelines:
+
+     - Faithfulness: Is the answer grounded in the context?
+     - Answer Relevancy: Does the answer address the question?
+     - Context Precision: Are retrieved contexts relevant?
+     - Context Recall: Are all relevant contexts retrieved?
+
+     See: https://github.com/explodinggradients/ragas
+     """
+
+     def __init__(
+         self,
+         llm_provider: str = "gemini",
+         llm_model: str = "gemini-2.5-flash",
+     ):
+         self.llm_provider = llm_provider
+         self.llm_model = llm_model
+
+     def evaluate(
+         self,
+         question: str,
+         answer: str,
+         contexts: list[str],
+         ground_truth: str | None = None,
+     ) -> RAGASMetrics:
+         """
+         Evaluate a single RAG response using RAGAS metrics.
+         """
+         try:
+             from ragas import evaluate  # type: ignore[import-not-found]
+             from ragas.metrics import (  # type: ignore[import-not-found]
+                 faithfulness,
+                 answer_relevancy,
+                 context_precision,
+                 context_recall,
+             )
+             from datasets import Dataset  # type: ignore[import-not-found]
+
+             # Prepare data
+             data = {
+                 "question": [question],
+                 "answer": [answer],
+                 "contexts": [contexts],
+             }
+             if ground_truth:
+                 data["ground_truth"] = [ground_truth]
+
+             dataset = Dataset.from_dict(data)
+
+             # Run evaluation
+             metrics = [faithfulness, answer_relevancy, context_precision]
+             if ground_truth:
+                 metrics.append(context_recall)
+
+             result = evaluate(dataset, metrics=metrics)
+
+             return RAGASMetrics(
+                 faithfulness=result.get("faithfulness", 0.0),
+                 answer_relevancy=result.get("answer_relevancy", 0.0),
+                 context_precision=result.get("context_precision", 0.0),
+                 context_recall=result.get("context_recall", 0.0) if ground_truth else 0.0,
+             )
+
+         except ImportError:
+             logger.warning("ragas library not installed, returning zero metrics")
+             return RAGASMetrics()
+
+     def evaluate_batch(
+         self,
+         questions: list[str],
+         answers: list[str],
+         contexts: list[list[str]],
+         ground_truths: list[str] | None = None,
+     ) -> RAGASMetrics:
+         """Evaluate a batch of responses and return aggregated metrics."""
+         all_metrics = []
+
+         for i in range(len(questions)):
+             gt = ground_truths[i] if ground_truths else None
+             metrics = self.evaluate(
+                 questions[i],
+                 answers[i],
+                 contexts[i],
+                 gt,
+             )
+             all_metrics.append(metrics)
+
+         # Aggregate
+         if not all_metrics:
+             return RAGASMetrics()
+
+         return RAGASMetrics(
+             faithfulness=sum(m.faithfulness for m in all_metrics) / len(all_metrics),
+             answer_relevancy=sum(m.answer_relevancy for m in all_metrics) / len(all_metrics),
+             context_precision=sum(m.context_precision for m in all_metrics) / len(all_metrics),
+             context_recall=sum(m.context_recall for m in all_metrics) / len(all_metrics),
+             context_relevancy=sum(m.context_relevancy for m in all_metrics) / len(all_metrics),
+             answer_correctness=sum(m.answer_correctness for m in all_metrics) / len(all_metrics),
+         )
+
+
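# Hypothetical usage sketch (illustrative only, not part of the packaged file):
# score one answer against its retrieved contexts. Requires the optional 'ragas'
# and 'datasets' packages; without them the evaluator logs a warning and returns
# all-zero metrics, so the call degrades gracefully.
def _example_ragas_single() -> dict[str, float]:
    evaluator = RAGASEvaluator()
    metrics = evaluator.evaluate(
        question="Who founded the company?",
        answer="The company was founded by Jane Doe in 1998.",
        contexts=["Jane Doe founded the company in 1998 in Austin."],
        ground_truth="Jane Doe",
    )
    return metrics.to_dict()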
+ # =============================================================================
+ # Multi-Hop Reasoning Metrics (for HotpotQA/MuSiQue)
+ # =============================================================================
+
+ @dataclass
+ class MultiHopMetrics:
+     """Metrics for multi-hop reasoning evaluation."""
+
+     answer_em: float = 0.0  # Exact match
+     answer_f1: float = 0.0  # Token-level F1
+     support_em: float = 0.0  # Supporting fact EM
+     support_f1: float = 0.0  # Supporting fact F1
+     joint_em: float = 0.0  # Joint answer + support EM
+     joint_f1: float = 0.0  # Joint answer + support F1
+
+     def to_dict(self) -> dict[str, float]:
+         return {
+             "answer_em": self.answer_em,
+             "answer_f1": self.answer_f1,
+             "support_em": self.support_em,
+             "support_f1": self.support_f1,
+             "joint_em": self.joint_em,
+             "joint_f1": self.joint_f1,
+         }
+
+
+ def normalize_answer(s: str) -> str:
+     """Normalize answer for comparison."""
+     import re
+     import string
+
+     def remove_articles(text):
+         return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+     def white_space_fix(text):
+         return ' '.join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return ''.join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def compute_em(prediction: str, ground_truth: str) -> float:
+     """Compute exact match score."""
+     return float(normalize_answer(prediction) == normalize_answer(ground_truth))
+
+
+ def compute_f1(prediction: str, ground_truth: str) -> float:
+     """Compute token-level F1 score."""
+     pred_tokens = normalize_answer(prediction).split()
+     gold_tokens = normalize_answer(ground_truth).split()
+
+     common = set(pred_tokens) & set(gold_tokens)
+
+     if len(common) == 0:
+         return 0.0
+
+     precision = len(common) / len(pred_tokens) if pred_tokens else 0
+     recall = len(common) / len(gold_tokens) if gold_tokens else 0
+
+     if precision + recall == 0:
+         return 0.0
+
+     return 2 * precision * recall / (precision + recall)
+
+
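# Hypothetical worked example (illustrative only, not part of the packaged file):
# normalization lowercases, strips punctuation and articles, and collapses
# whitespace, so "The Eiffel Tower!" is an exact match for "eiffel tower", while
# a partial answer is only rewarded by token-level F1.
def _example_em_vs_f1() -> None:
    assert compute_em("The Eiffel Tower!", "eiffel tower") == 1.0
    assert compute_em("tower", "eiffel tower") == 0.0
    # precision 1/1, recall 1/2 -> F1 = 2 * 1.0 * 0.5 / 1.5 ≈ 0.667
    assert abs(compute_f1("tower", "eiffel tower") - 2 / 3) < 1e-9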
+ def evaluate_multihop(
+     predictions: list[dict[str, Any]],
+     ground_truths: list[BenchmarkQuestion],
+ ) -> MultiHopMetrics:
+     """
+     Evaluate multi-hop QA predictions against ground truth.
+
+     Args:
+         predictions: List of {"answer": str, "supporting_facts": list[str]}
+         ground_truths: List of BenchmarkQuestion with answers and supporting facts
+     """
+     answer_ems = []
+     answer_f1s = []
+     support_ems = []
+     support_f1s = []
+
+     for pred, gold in zip(predictions, ground_truths):
+         # Answer metrics
+         answer_ems.append(compute_em(pred.get("answer", ""), gold.answer))
+         answer_f1s.append(compute_f1(pred.get("answer", ""), gold.answer))
+
+         # Supporting facts metrics
+         pred_facts = set(pred.get("supporting_facts", []))
+         gold_facts = set(gold.supporting_facts)
+
+         if gold_facts:
+             support_em = float(pred_facts == gold_facts)
+
+             common = pred_facts & gold_facts
+             prec = len(common) / len(pred_facts) if pred_facts else 0
+             rec = len(common) / len(gold_facts) if gold_facts else 0
+             support_f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
+
+             support_ems.append(support_em)
+             support_f1s.append(support_f1)
+
+     n = len(predictions)
+
+     # Note: joint_em/joint_f1 are products of the averaged answer and support
+     # scores, not the per-example joint metric used by the official HotpotQA eval.
+     return MultiHopMetrics(
+         answer_em=sum(answer_ems) / n if n else 0,
+         answer_f1=sum(answer_f1s) / n if n else 0,
+         support_em=sum(support_ems) / len(support_ems) if support_ems else 0,
+         support_f1=sum(support_f1s) / len(support_f1s) if support_f1s else 0,
+         joint_em=(sum(answer_ems) / n) * (sum(support_ems) / len(support_ems)) if n and support_ems else 0,
+         joint_f1=(sum(answer_f1s) / n) * (sum(support_f1s) / len(support_f1s)) if n and support_f1s else 0,
+     )
+
+
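# Hypothetical usage sketch (illustrative only, not part of the packaged file):
# predictions pair an answer string with the supporting-fact titles the system
# cites, and are scored against BenchmarkQuestion ground truths (for example a
# HotpotQA slice). Normalization makes "france" match "France" exactly.
def _example_evaluate_multihop() -> dict[str, float]:
    gold = [
        BenchmarkQuestion(
            id="q1",
            question="In which country was the director of Film X born?",
            answer="France",
            supporting_facts=["Film X", "Director Y"],
        )
    ]
    predictions = [{"answer": "france", "supporting_facts": ["Film X", "Director Y"]}]
    return evaluate_multihop(predictions, gold).to_dict()  # answer_em == support_em == 1.0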
+ # =============================================================================
+ # RNSR vs Baseline Comparison
+ # =============================================================================
+
+ @dataclass
+ class ComparisonResult:
+     """Result of comparing RNSR against a baseline."""
+
+     dataset_name: str
+     rnsr_metrics: dict[str, float]
+     baseline_metrics: dict[str, float]
+     baseline_name: str
+     improvement: dict[str, float]  # RNSR - baseline for each metric
+     relative_improvement: dict[str, float]  # (RNSR - baseline) / baseline
+
+     def summary(self) -> str:
+         """Generate human-readable summary."""
+         lines = [
+             f"\n{'='*60}",
+             f"Comparison: RNSR vs {self.baseline_name}",
+             f"Dataset: {self.dataset_name}",
+             f"{'='*60}",
+             "",
+             f"{'Metric':<25} {'RNSR':>10} {'Baseline':>10} {'Δ':>10} {'%':>10}",
+             "-" * 65,
+         ]
+
+         for metric in self.rnsr_metrics:
+             rnsr_val = self.rnsr_metrics.get(metric, 0)
+             base_val = self.baseline_metrics.get(metric, 0)
+             delta = self.improvement.get(metric, 0)
+             rel = self.relative_improvement.get(metric, 0) * 100
+
+             lines.append(f"{metric:<25} {rnsr_val:>10.3f} {base_val:>10.3f} {delta:>+10.3f} {rel:>+9.1f}%")
+
+         lines.append("=" * 65)
+         return "\n".join(lines)
+
+
+ def compare_rnsr_vs_baseline(
+     rnsr_results: dict[str, float],
+     baseline_results: dict[str, float],
+     dataset_name: str,
+     baseline_name: str,
+ ) -> ComparisonResult:
+     """Compare RNSR results against a baseline."""
+     improvement = {}
+     relative_improvement = {}
+
+     for metric in rnsr_results:
+         rnsr_val = rnsr_results.get(metric, 0)
+         base_val = baseline_results.get(metric, 0)
+
+         improvement[metric] = rnsr_val - base_val
+         if base_val > 0:
+             relative_improvement[metric] = (rnsr_val - base_val) / base_val
+         else:
+             relative_improvement[metric] = 0.0
+
+     return ComparisonResult(
+         dataset_name=dataset_name,
+         rnsr_metrics=rnsr_results,
+         baseline_metrics=baseline_results,
+         baseline_name=baseline_name,
+         improvement=improvement,
+         relative_improvement=relative_improvement,
+     )
+
+
+ # =============================================================================
+ # Exports
+ # =============================================================================
+
+ __all__ = [
+     # Baselines
+     "BaselineRAG",
+     "BaselineResult",
+     "NaiveChunkRAG",
+     "SemanticChunkRAG",
+
+     # Benchmarks
+     "BenchmarkQuestion",
+     "BenchmarkDataset",
+     "BenchmarkLoader",
+
+     # RAGAS
+     "RAGASMetrics",
+     "RAGASEvaluator",
+
+     # Multi-hop
+     "MultiHopMetrics",
+     "evaluate_multihop",
+     "compute_em",
+     "compute_f1",
+
+     # Comparison
+     "ComparisonResult",
+     "compare_rnsr_vs_baseline",
+ ]
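# Hypothetical end-to-end sketch (illustrative only, not part of the packaged
# file): wire the pieces above together for one dataset. `rnsr_answer_fn` and
# `baseline_answer_fn` are assumed stand-ins for whatever produces RNSR answers
# and baseline answers (e.g. NaiveChunkRAG or any other BaselineRAG wrapper).
def _example_end_to_end(rnsr_answer_fn, baseline_answer_fn, max_samples: int = 25) -> None:
    dataset = BenchmarkLoader.load_hotpotqa(max_samples=max_samples)
    if not dataset.questions:
        return

    def score(answer_fn) -> dict[str, float]:
        predictions = [
            {"answer": answer_fn(q.question, q.context), "supporting_facts": []}
            for q in dataset.questions
        ]
        return evaluate_multihop(predictions, dataset.questions).to_dict()

    comparison = compare_rnsr_vs_baseline(
        score(rnsr_answer_fn),
        score(baseline_answer_fn),
        dataset_name=dataset.name,
        baseline_name="naive_chunk_512",
    )
    print(comparison.summary())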