rnsr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnsr/__init__.py +118 -0
- rnsr/__main__.py +242 -0
- rnsr/agent/__init__.py +218 -0
- rnsr/agent/cross_doc_navigator.py +767 -0
- rnsr/agent/graph.py +1557 -0
- rnsr/agent/llm_cache.py +575 -0
- rnsr/agent/navigator_api.py +497 -0
- rnsr/agent/provenance.py +772 -0
- rnsr/agent/query_clarifier.py +617 -0
- rnsr/agent/reasoning_memory.py +736 -0
- rnsr/agent/repl_env.py +709 -0
- rnsr/agent/rlm_navigator.py +2108 -0
- rnsr/agent/self_reflection.py +602 -0
- rnsr/agent/variable_store.py +308 -0
- rnsr/benchmarks/__init__.py +118 -0
- rnsr/benchmarks/comprehensive_benchmark.py +733 -0
- rnsr/benchmarks/evaluation_suite.py +1210 -0
- rnsr/benchmarks/finance_bench.py +147 -0
- rnsr/benchmarks/pdf_merger.py +178 -0
- rnsr/benchmarks/performance.py +321 -0
- rnsr/benchmarks/quality.py +321 -0
- rnsr/benchmarks/runner.py +298 -0
- rnsr/benchmarks/standard_benchmarks.py +995 -0
- rnsr/client.py +560 -0
- rnsr/document_store.py +394 -0
- rnsr/exceptions.py +74 -0
- rnsr/extraction/__init__.py +172 -0
- rnsr/extraction/candidate_extractor.py +357 -0
- rnsr/extraction/entity_extractor.py +581 -0
- rnsr/extraction/entity_linker.py +825 -0
- rnsr/extraction/grounded_extractor.py +722 -0
- rnsr/extraction/learned_types.py +599 -0
- rnsr/extraction/models.py +232 -0
- rnsr/extraction/relationship_extractor.py +600 -0
- rnsr/extraction/relationship_patterns.py +511 -0
- rnsr/extraction/relationship_validator.py +392 -0
- rnsr/extraction/rlm_extractor.py +589 -0
- rnsr/extraction/rlm_unified_extractor.py +990 -0
- rnsr/extraction/tot_validator.py +610 -0
- rnsr/extraction/unified_extractor.py +342 -0
- rnsr/indexing/__init__.py +60 -0
- rnsr/indexing/knowledge_graph.py +1128 -0
- rnsr/indexing/kv_store.py +313 -0
- rnsr/indexing/persistence.py +323 -0
- rnsr/indexing/semantic_retriever.py +237 -0
- rnsr/indexing/semantic_search.py +320 -0
- rnsr/indexing/skeleton_index.py +395 -0
- rnsr/ingestion/__init__.py +161 -0
- rnsr/ingestion/chart_parser.py +569 -0
- rnsr/ingestion/document_boundary.py +662 -0
- rnsr/ingestion/font_histogram.py +334 -0
- rnsr/ingestion/header_classifier.py +595 -0
- rnsr/ingestion/hierarchical_cluster.py +515 -0
- rnsr/ingestion/layout_detector.py +356 -0
- rnsr/ingestion/layout_model.py +379 -0
- rnsr/ingestion/ocr_fallback.py +177 -0
- rnsr/ingestion/pipeline.py +936 -0
- rnsr/ingestion/semantic_fallback.py +417 -0
- rnsr/ingestion/table_parser.py +799 -0
- rnsr/ingestion/text_builder.py +460 -0
- rnsr/ingestion/tree_builder.py +402 -0
- rnsr/ingestion/vision_retrieval.py +965 -0
- rnsr/ingestion/xy_cut.py +555 -0
- rnsr/llm.py +733 -0
- rnsr/models.py +167 -0
- rnsr/py.typed +2 -0
- rnsr-0.1.0.dist-info/METADATA +592 -0
- rnsr-0.1.0.dist-info/RECORD +72 -0
- rnsr-0.1.0.dist-info/WHEEL +5 -0
- rnsr-0.1.0.dist-info/entry_points.txt +2 -0
- rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
- rnsr-0.1.0.dist-info/top_level.txt +1 -0
rnsr/benchmarks/evaluation_suite.py
@@ -0,0 +1,1210 @@
"""
RNSR Benchmark Suite - Comprehensive Evaluation Against Standard Baselines

This module runs RNSR against standard RAG benchmarks to validate
the claims in the research paper:

1. Tree traversal is more efficient than flat chunk retrieval
2. Hierarchical indexing preserves context better
3. Multi-hop reasoning benefits from structural navigation
4. Latent TOC extraction improves document understanding

Usage:
    python -m rnsr.benchmarks.evaluation_suite --dataset hotpotqa --samples 100
"""

from __future__ import annotations

import argparse
import json
import re
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal

import structlog

from rnsr.benchmarks.standard_benchmarks import (
    BenchmarkLoader,
    BenchmarkDataset,
    BenchmarkQuestion,
    NaiveChunkRAG,
    SemanticChunkRAG,
    RAGASEvaluator,
    RAGASMetrics,
    MultiHopMetrics,
    evaluate_multihop,
    compare_rnsr_vs_baseline,
    ComparisonResult,
)
from rnsr.llm import LLMProvider

logger = structlog.get_logger(__name__)


# =============================================================================
# RNSR Wrapper for Benchmark Evaluation
# =============================================================================

@dataclass
class RNSRResult:
    """Result from RNSR system."""

    answer: str
    supporting_facts: list[str]
    nodes_visited: list[str]
    traversal_path: list[str]
    retrieval_time_s: float
    generation_time_s: float
    total_time_s: float
    tree_depth_reached: int
    metadata: dict[str, Any] = field(default_factory=dict)

    # RLM-specific metrics (Section 2)
    rlm_metrics: dict[str, Any] = field(default_factory=dict)

    # Full execution trace
    trace: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class RLMMetrics:
    """
    Metrics specific to RLM (Recursive Language Model) execution.

    Tracks Section 2 implementation effectiveness:
    - Decomposition quality (Section 2.2)
    - Variable stitching usage (Section 2.2)
    - Batch processing efficiency (Section 2.3)
    - REPL interaction patterns
    """

    # Query decomposition
    sub_questions_generated: int = 0
    sub_questions_answered: int = 0
    decomposition_method: str = "none"  # "llm", "pattern", "none"

    # Variable stitching
    variables_stored: int = 0
    variables_resolved: int = 0
    total_variable_chars: int = 0

    # Recursive execution
    sub_llm_calls: int = 0
    batch_calls: int = 0
    batch_efficiency: float = 0.0  # items_processed / api_calls

    # REPL execution
    repl_commands_executed: int = 0
    repl_errors: int = 0

    # Timing breakdown
    decomposition_time_s: float = 0.0
    navigation_time_s: float = 0.0
    synthesis_time_s: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "sub_questions_generated": self.sub_questions_generated,
            "sub_questions_answered": self.sub_questions_answered,
            "decomposition_method": self.decomposition_method,
            "variables_stored": self.variables_stored,
            "variables_resolved": self.variables_resolved,
            "total_variable_chars": self.total_variable_chars,
            "sub_llm_calls": self.sub_llm_calls,
            "batch_calls": self.batch_calls,
            "batch_efficiency": self.batch_efficiency,
            "repl_commands_executed": self.repl_commands_executed,
            "repl_errors": self.repl_errors,
            "decomposition_time_s": self.decomposition_time_s,
            "navigation_time_s": self.navigation_time_s,
            "synthesis_time_s": self.synthesis_time_s,
        }


class RNSRBenchmarkAdapter:
    """
    Adapter to run RNSR on benchmark datasets.

    IMPORTANT: This adapter ALWAYS uses the full RLM (Recursive Language Model)
    pipeline. RLM is not optional - it IS RNSR. The key principles:

    1. Document stored as variable (DOC_VAR), not stuffed into prompt
    2. LLM navigates via summaries (skeleton index)
    3. Query decomposition for complex questions
    4. Variable stitching prevents context pollution

    For benchmarks that provide raw text (not PDFs), we build an ephemeral
    tree structure to enable full RLM processing.
    """

    def __init__(
        self,
        llm_provider: str = "gemini",
        llm_model: str = "gemini-2.5-flash",
        max_iterations: int = 20,
        tot_selection_threshold: float = 0.4,
        tot_dead_end_threshold: float = 0.1,
    ):
        self.llm_provider = llm_provider
        self.llm_model = llm_model
        self.max_iterations = max_iterations
        self.tot_selection_threshold = tot_selection_threshold
        self.tot_dead_end_threshold = tot_dead_end_threshold
        # Ingestion cache to avoid re-processing PDFs (thread-safe)
        self._ingestion_cache: dict[Path, tuple] = {}
        self._cache_lock = threading.Lock()

    def answer_from_pdf(
        self,
        question: str,
        pdf_path: Path,
        metadata: dict | None = None,
    ) -> RNSRResult:
        """
        Answer a question using RNSR's full RLM pipeline on a PDF.

        Uses:
        1. Font histogram ingestion → Document tree
        2. Skeleton index (summaries + KV store)
        3. Navigator agent with decomposition + variable stitching
        """
        from rnsr.ingestion import ingest_document
        from rnsr.indexing import build_skeleton_index
        from rnsr.agent import run_navigator

        start_total = time.perf_counter()

        # Check cache first (thread-safe)
        start_index = time.perf_counter()
        with self._cache_lock:
            cached = self._ingestion_cache.get(pdf_path)

        if cached:
            skeleton, kv_store = cached
            index_time = time.perf_counter() - start_index
            logger.debug("using_cached_ingestion", pdf=str(pdf_path))
        else:
            # Ingest and index
            result = ingest_document(pdf_path)
            skeleton, kv_store = build_skeleton_index(result.tree)
            index_time = time.perf_counter() - start_index
            # Store in cache
            with self._cache_lock:
                self._ingestion_cache[pdf_path] = (skeleton, kv_store)
            logger.debug("cached_ingestion", pdf=str(pdf_path))

        # Full RLM: Query with navigator (always uses decomposition + variable stitching)
        start_query = time.perf_counter()
        answer_result = run_navigator(
            question=question,
            skeleton=skeleton,
            kv_store=kv_store,
            max_iterations=self.max_iterations,
            metadata=metadata,
            tot_selection_threshold=self.tot_selection_threshold,
            tot_dead_end_threshold=self.tot_dead_end_threshold,
        )
        query_time = time.perf_counter() - start_query

        total_time = time.perf_counter() - start_total

        # Extract supporting facts from trace
        supporting_facts = []
        traversal_path = []
        max_depth = 0

        for entry in answer_result.get("trace", []):
            if entry.get("action") == "read_node":
                node_id = entry.get("node_id", "")
                supporting_facts.append(node_id)
                traversal_path.append(entry.get("node_type", "unknown"))
                max_depth = max(max_depth, entry.get("depth", 0))

        return RNSRResult(
            answer=answer_result.get("answer", ""),
            supporting_facts=supporting_facts,
            nodes_visited=answer_result.get("nodes_visited", []),
            traversal_path=traversal_path,
            retrieval_time_s=index_time,
            generation_time_s=query_time,
            total_time_s=total_time,
            tree_depth_reached=max_depth,
            metadata={
                "confidence": answer_result.get("confidence", 0),
                "variables_used": len(answer_result.get("variables_used", [])),
            }
        )

    def answer_from_context(
        self,
        question: str,
        contexts: list[str],
        metadata: dict | None = None,
    ) -> RNSRResult:
        """
        Answer using pre-provided contexts (for benchmark datasets).

        This uses the FULL RLM pipeline:
        1. Build ephemeral tree from text contexts
        2. Create skeleton index with summaries
        3. Run navigator with decomposition + variable stitching

        This is NOT traditional RAG (stuffing context into prompt).
        The document is stored as DOC_VAR and navigated structurally.
        """
        from rnsr.ingestion.text_builder import build_tree_from_contexts
        from rnsr.indexing import build_skeleton_index
        from rnsr.agent import run_navigator

        start_total = time.perf_counter()
        metadata = metadata or {}

        # Check if this is a PDF-based benchmark (e.g. FinanceBench)
        if metadata and "pdf_path" in metadata and metadata["pdf_path"]:
            pdf_path_str = metadata["pdf_path"]
            if Path(pdf_path_str).exists():
                return self.answer_from_pdf(question, Path(pdf_path_str), metadata)
            # Fallback if PDF missing but context provided (rare for FB)

        # Step 1: Build ephemeral tree from benchmark contexts
        start_index = time.perf_counter()
        tree = build_tree_from_contexts(contexts, question)
        skeleton, kv_store = build_skeleton_index(tree)
        index_time = time.perf_counter() - start_index

        # Step 2: Run full RLM navigator (decomposition + variable stitching)
        start_query = time.perf_counter()
        try:
            # Use Tree of Thoughts reasoning-based navigation (research paper Section 7.2)
            # Semantic search disabled by default per Section 9.1 (optional shortcut only)
            answer_result = run_navigator(
                question=question,
                skeleton=skeleton,
                kv_store=kv_store,
                max_iterations=self.max_iterations,
                use_semantic_search=False,  # ToT primary, embeddings optional
                metadata=metadata,  # Pass options for multiple-choice questions
                tot_selection_threshold=self.tot_selection_threshold,
                tot_dead_end_threshold=self.tot_dead_end_threshold,
            )
            answer = answer_result.get("answer", "")

        except Exception as e:
            logger.warning("rnsr_navigation_failed", error=str(e))
            answer = f"Error: {str(e)}"
            answer_result = {"trace": [], "nodes_visited": [], "confidence": 0}

        query_time = time.perf_counter() - start_query
        total_time = time.perf_counter() - start_total

        # Extract trace information
        supporting_facts = []
        traversal_path = []
        max_depth = 0

        for entry in answer_result.get("trace", []):
            if entry.get("action") == "read_node":
                supporting_facts.append(entry.get("node_id", ""))
                traversal_path.append(entry.get("node_type", "unknown"))
                max_depth = max(max_depth, entry.get("depth", 0))

        return RNSRResult(
            answer=answer,
            supporting_facts=supporting_facts,
            nodes_visited=answer_result.get("nodes_visited", []),
            traversal_path=traversal_path,
            retrieval_time_s=index_time,
            generation_time_s=query_time,
            total_time_s=total_time,
            tree_depth_reached=max_depth,
            trace=answer_result.get("trace", []),
            metadata={
                "num_contexts": len(contexts),
                "confidence": answer_result.get("confidence", 0),
                "variables_used": len(answer_result.get("variables_used", [])),
                "rlm_mode": True,  # Always true now
            },
        )

    def _normalize_multiple_choice(self, answer: str, options: list[str]) -> str:
        """Normalize multiple choice answer to match option text."""
        answer_lower = answer.lower().strip()

        for i, opt in enumerate(options):
            letter = chr(65 + i)  # A, B, C, D

            # Match patterns like "A.", "A)", "(A)", "A. answer text"
            if (answer_lower.startswith(f"{letter.lower()}.") or
                answer_lower.startswith(f"{letter.lower()})") or
                answer_lower.startswith(f"({letter.lower()})")):
                return opt

            # Check if option number format
            if answer_lower.startswith(f"{i+1}.") or answer_lower.startswith(f"({i+1})"):
                return opt

            # Exact match (case insensitive)
            if answer_lower == opt.lower():
                return opt

            # Check if answer contains the option text
            if opt.lower() in answer_lower:
                return opt

        return answer  # Return original if no match


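# Illustrative usage sketch: driving the adapter above directly on raw
# benchmark-style contexts. The question, the context string, and the
# configured Gemini credentials assumed here are hypothetical; only the
# RNSRBenchmarkAdapter / RNSRResult API defined above is real.
def _example_adapter_usage() -> None:  # pragma: no cover
    adapter = RNSRBenchmarkAdapter(llm_provider="gemini", llm_model="gemini-2.5-flash")
    result = adapter.answer_from_context(
        question="In which year was the company founded?",
        contexts=["Acme Corp was founded in 1987 and went public in 1999."],
    )
    # RNSRResult carries the answer plus traversal and timing metadata.
    print(result.answer, result.tree_depth_reached, result.total_time_s)

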
# =============================================================================
# Baseline RAG Implementations (for fair comparison)
# =============================================================================

@dataclass
class BaselineResult:
    """Result from a baseline RAG approach."""
    answer: str
    retrieved_chunks: list[str]
    total_time_s: float
    method: str


class NaiveChunkBaseline:
    """
    Naive chunking baseline using the SAME LLM as RNSR for fair comparison.

    This chunks the context into fixed-size segments, retrieves top-k by
    simple keyword overlap, and generates an answer.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        top_k: int = 5,
        llm_provider: str = "gemini",
        llm_model: str = "gemini-2.5-flash",
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.top_k = top_k
        self.llm_provider = llm_provider
        self.llm_model = llm_model

    def name(self) -> str:
        return f"naive_chunk_{self.chunk_size}"

    def answer_from_context(
        self,
        question: str,
        contexts: list[str],
        metadata: dict[str, Any] | None = None,
    ) -> BaselineResult:
        """Answer using naive chunking on the provided context."""
        from rnsr.llm import get_llm

        start_total = time.perf_counter()

        # Combine all context
        full_text = "\n\n".join(contexts)

        # Chunk the text naively
        chunks = []
        for i in range(0, len(full_text), self.chunk_size - self.chunk_overlap):
            chunk = full_text[i:i + self.chunk_size]
            if chunk.strip():
                chunks.append(chunk)

        # Simple retrieval by keyword overlap
        question_words = set(question.lower().split())
        scored_chunks = []
        for chunk in chunks:
            chunk_words = set(chunk.lower().split())
            score = len(question_words & chunk_words) / max(len(question_words), 1)
            scored_chunks.append((score, chunk))

        scored_chunks.sort(reverse=True, key=lambda x: x[0])
        retrieved = [c for _, c in scored_chunks[:self.top_k]]

        # Generate answer using the SAME LLM as RNSR
        logger.info("provider_detected", provider=self.llm_provider)
        logger.debug("initializing_llm", provider=self.llm_provider, model=self.llm_model)
        llm = get_llm(provider=LLMProvider(self.llm_provider), model=self.llm_model)

        # Check if this is multiple choice
        options = metadata.get("options", []) if metadata else []
        if options:
            options_text = "\n".join([f"  {i+1}. {opt}" for i, opt in enumerate(options)])
            prompt = f"""Answer this multiple-choice question based on the provided text.

Retrieved chunks:
{chr(10).join(retrieved)}

Question: {question}

Options:
{options_text}

Respond with ONLY the text of the correct option, nothing else."""
        else:
            prompt = f"""Answer this question based on the provided text. Be concise.

Retrieved chunks:
{chr(10).join(retrieved)}

Question: {question}

Answer:"""

        try:
            response = llm.complete(prompt)
            answer = str(response).strip()

            # Match to options if needed
            if options:
                answer_lower = answer.lower()
                for opt in options:
                    if opt.lower() in answer_lower:
                        answer = opt
                        break
        except Exception as e:
            logger.warning("baseline_llm_failed", error=str(e))
            answer = f"Error: {str(e)}"

        total_time = time.perf_counter() - start_total

        return BaselineResult(
            answer=answer,
            retrieved_chunks=retrieved,
            total_time_s=total_time,
            method=self.name(),
        )


class SemanticChunkBaseline:
    """Semantic chunking baseline - splits on paragraph boundaries."""

    def __init__(
        self,
        top_k: int = 5,
        llm_provider: str = "gemini",
        llm_model: str = "gemini-2.5-flash",
    ):
        self.top_k = top_k
        self.llm_provider = llm_provider
        self.llm_model = llm_model

    def name(self) -> str:
        return "semantic_chunk"

    def answer_from_context(
        self,
        question: str,
        contexts: list[str],
        metadata: dict[str, Any] | None = None,
    ) -> BaselineResult:
        """Answer using paragraph-based chunking."""
        from rnsr.llm import get_llm

        start_total = time.perf_counter()

        # Split by paragraphs
        chunks = []
        for ctx in contexts:
            paragraphs = ctx.split("\n\n")
            for para in paragraphs:
                if para.strip() and len(para.strip()) > 50:
                    chunks.append(para.strip())

        # Retrieve by keyword overlap
        question_words = set(question.lower().split())
        scored_chunks = []
        for chunk in chunks:
            chunk_words = set(chunk.lower().split())
            score = len(question_words & chunk_words) / max(len(question_words), 1)
            scored_chunks.append((score, chunk))

        scored_chunks.sort(reverse=True, key=lambda x: x[0])
        retrieved = [c for _, c in scored_chunks[:self.top_k]]

        # Generate answer
        logger.info("provider_detected", provider=self.llm_provider)
        logger.debug("initializing_llm", provider=self.llm_provider, model=self.llm_model)
        llm = get_llm(provider=LLMProvider(self.llm_provider), model=self.llm_model)

        options = metadata.get("options", []) if metadata else []
        if options:
            options_text = "\n".join([f"  {i+1}. {opt}" for i, opt in enumerate(options)])
            prompt = f"""Answer this multiple-choice question based on the provided text.

Retrieved paragraphs:
{chr(10).join(retrieved)}

Question: {question}

Options:
{options_text}

Respond with ONLY the text of the correct option, nothing else."""
        else:
            prompt = f"""Answer this question based on the provided text. Be concise.

Retrieved paragraphs:
{chr(10).join(retrieved)}

Question: {question}

Answer:"""

        try:
            response = llm.complete(prompt)
            answer = str(response).strip()

            if options:
                answer_lower = answer.lower()
                for opt in options:
                    if opt.lower() in answer_lower:
                        answer = opt
                        break
        except Exception as e:
            logger.warning("baseline_llm_failed", error=str(e))
            answer = f"Error: {str(e)}"

        total_time = time.perf_counter() - start_total

        return BaselineResult(
            answer=answer,
            retrieved_chunks=retrieved,
            total_time_s=total_time,
            method=self.name(),
        )


# =============================================================================
# Evaluation Suite
# =============================================================================

@dataclass
class EvaluationConfig:
    """Configuration for benchmark evaluation."""

    datasets: list[str] = field(default_factory=lambda: ["hotpotqa"])
    max_samples: int = field(default=100)
    baselines: list[str] = field(default_factory=lambda: ["naive_chunk_512"])
    output_dir: Path = field(default_factory=lambda: Path("benchmark_results"))
    run_ragas: bool = field(default=True)
    save_predictions: bool = field(default=True)
    llm_provider: str = field(default="gemini")
    llm_model: str = field(default="gemini-2.5-flash")
    chaos_mode: bool = field(default=False)
    chaos_distractors: int = field(default=3)
    tot_selection_threshold: float = field(default=0.4)
    tot_dead_end_threshold: float = field(default=0.1)
    parallel_workers: int = field(default=1)  # Number of parallel workers for question processing


@dataclass
class EvaluationReport:
    """Complete evaluation report."""

    timestamp: str
    config: dict[str, Any]
    dataset_results: dict[str, dict[str, Any]]
    comparisons: list[dict[str, Any]]
    summary: dict[str, Any]

    @property
    def overall_accuracy(self) -> float:
        """Calculate overall accuracy across all datasets."""
        total_acc = 0.0
        count = 0
        for results in self.dataset_results.values():
            metrics = results.get("rnsr_metrics", {})
            if "accuracy" in metrics:
                total_acc += metrics["accuracy"]
                count += 1
            elif "exact_match" in metrics:
                total_acc += metrics["exact_match"]
                count += 1
        return total_acc / max(count, 1)

    @property
    def avg_latency_s(self) -> float:
        """Calculate average latency across all datasets."""
        total_latency = 0.0
        count = 0
        for results in self.dataset_results.values():
            metrics = results.get("rnsr_metrics", {})
            if "avg_time_s" in metrics:
                total_latency += metrics["avg_time_s"]
                count += 1
        return total_latency / max(count, 1)

    def save(self, path: Path) -> None:
        """Save report to JSON."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2, default=str)

    def print_summary(self) -> None:
        """Print human-readable summary."""
        print("\n" + "=" * 70)
        print("RNSR BENCHMARK EVALUATION REPORT")
        print("=" * 70)
        print(f"Timestamp: {self.timestamp}")
        print(f"Datasets evaluated: {list(self.dataset_results.keys())}")

        for dataset, results in self.dataset_results.items():
            print(f"\n--- {dataset} ---")

            rnsr_metrics = results.get("rnsr_metrics", {})
            for metric, value in rnsr_metrics.items():
                if isinstance(value, float):
                    print(f"  RNSR {metric}: {value:.3f}")
                else:
                    print(f"  RNSR {metric}: {value}")

            # Print RLM-specific metrics if available
            rlm_metrics = results.get("rlm_metrics", {})
            if rlm_metrics:
                print("\n  📊 RLM Metrics:")
                print(f"    Sub-questions generated: {rlm_metrics.get('total_sub_questions', 0)}")
                print(f"    Variables stored: {rlm_metrics.get('total_variables', 0)}")
                print(f"    Batch calls: {rlm_metrics.get('total_batch_calls', 0)}")
                print(f"    Avg decomposition time: {rlm_metrics.get('avg_decomposition_time', 0):.3f}s")
                print(f"    Avg sub-task time: {rlm_metrics.get('avg_sub_task_time', 0):.3f}s")

            for baseline, metrics in results.get("baseline_metrics", {}).items():
                print(f"\n  {baseline}:")
                for metric, value in metrics.items():
                    if isinstance(value, float):
                        print(f"    {metric}: {value:.3f}")
                    else:
                        print(f"    {metric}: {value}")

        print("\n" + "-" * 70)
        print("IMPROVEMENTS OVER BASELINES")
        print("-" * 70)

        for comp in self.comparisons:
            print(f"\nvs {comp['baseline_name']} on {comp['dataset_name']}:")
            for metric, delta in comp.get("improvement", {}).items():
                rel = comp.get("relative_improvement", {}).get(metric, 0) * 100
                print(f"  {metric}: {delta:+.3f} ({rel:+.1f}%)")

        print("\n" + "=" * 70)


class EvaluationSuite:
    """
    Run comprehensive RNSR evaluation against standard benchmarks.
    """

    def __init__(self, config: EvaluationConfig):
        self.config = config
        self.rnsr = RNSRBenchmarkAdapter(
            llm_provider=config.llm_provider,
            llm_model=config.llm_model,
            tot_selection_threshold=config.tot_selection_threshold,
            tot_dead_end_threshold=config.tot_dead_end_threshold,
        )
        self.ragas_evaluator = RAGASEvaluator(
            llm_provider=config.llm_provider,
            llm_model=config.llm_model,
        ) if config.run_ragas else None

        # Initialize baselines
        self.baselines = {}
        for baseline_name in config.baselines:
            if baseline_name.startswith("naive_chunk"):
                chunk_size = int(baseline_name.split("_")[-1])
                self.baselines[baseline_name] = NaiveChunkBaseline(
                    chunk_size=chunk_size,
                    llm_provider=config.llm_provider,
                    llm_model=config.llm_model,
                )
            elif baseline_name == "semantic_chunk":
                self.baselines[baseline_name] = SemanticChunkBaseline(
                    llm_provider=config.llm_provider,
                    llm_model=config.llm_model,
                )

    def load_dataset(self, name: str) -> BenchmarkDataset:
        """Load a benchmark dataset by name."""
        if name == "hotpotqa":
            return BenchmarkLoader.load_hotpotqa(max_samples=self.config.max_samples)
        elif name.startswith("musique"):
            variant = "full" if "full" in name else "ans"
            return BenchmarkLoader.load_musique(variant=variant, max_samples=self.config.max_samples)
        elif name.startswith("beir_"):
            dataset_name = name.replace("beir_", "")
            return BenchmarkLoader.load_beir_dataset(dataset_name, max_samples=self.config.max_samples)
        elif name == "qasper":
            return BenchmarkLoader.load_qasper(max_samples=self.config.max_samples)
        elif name == "quality":
            return BenchmarkLoader.load_quality(max_samples=self.config.max_samples)
        elif name == "financebench":
            return BenchmarkLoader.load_financebench(split="train", max_samples=self.config.max_samples)
        elif name == "narrativeqa":
            return BenchmarkLoader.load_narrative_qa(max_samples=self.config.max_samples)
        else:
            raise ValueError(f"Unknown dataset: {name}. Available: hotpotqa, musique_ans, musique_full, qasper, quality, narrativeqa, beir_*")

    def evaluate_rnsr_on_dataset(
        self,
        dataset: BenchmarkDataset,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """
        Evaluate RNSR on a benchmark dataset.

        Returns:
            predictions: List of prediction dicts
            metrics: Aggregated metrics dict (includes RLM metrics if enabled)
        """
        predictions = []

        logger.info(
            "evaluating_rnsr",
            dataset=dataset.name,
            questions=len(dataset.questions),
            workers=self.config.parallel_workers,
        )

        def process_question(idx_question: tuple[int, "BenchmarkQuestion"]) -> dict[str, Any]:
            """Process a single question (helper for parallel execution)."""
            i, question = idx_question
            logger.debug("processing_question", index=i, question=question.question[:50])

            try:
                result = self.rnsr.answer_from_context(
                    question=question.question,
                    contexts=question.context,
                    metadata=question.metadata,
                )

                pred_entry = {
                    "id": question.id,
                    "question": question.question,
                    "answer": result.answer,
                    "supporting_facts": result.supporting_facts,
                    "nodes_visited": result.nodes_visited,
                    "time_s": result.total_time_s,
                    "tree_depth": result.tree_depth_reached,
                    "trace": result.trace,
                    "metadata": result.metadata,
                }

                if result.metadata and result.metadata.get("rlm_mode"):
                    pred_entry["rlm_metrics"] = result.metadata.get("rlm_metrics", {})

                return pred_entry

            except Exception as e:
                logger.error("question_failed", error=str(e), question_id=question.id)
                return {
                    "id": question.id,
                    "question": question.question,
                    "answer": "",
                    "supporting_facts": [],
                    "error": str(e),
                }

        # Parallel or sequential execution based on config
        if self.config.parallel_workers > 1:
            # Parallel execution with ThreadPoolExecutor
            with ThreadPoolExecutor(max_workers=self.config.parallel_workers) as executor:
                futures = {
                    executor.submit(process_question, (i, q)): i
                    for i, q in enumerate(dataset.questions)
                }

                completed = 0
                for future in as_completed(futures):
                    pred = future.result()
                    predictions.append(pred)
                    completed += 1
                    if completed % 10 == 0:
                        logger.info("progress", completed=completed, total=len(dataset.questions))

            # Sort by original order (id may not be sequential)
            predictions.sort(key=lambda p: str(p.get("id", "")))
        else:
            # Sequential execution (original behavior)
            for i, question in enumerate(dataset.questions):
                pred = process_question((i, question))
                predictions.append(pred)

        # Compute metrics
        metrics: dict[str, Any] = {}
        if "answer_f1" in dataset.metrics or "answer_em" in dataset.metrics or "accuracy" in dataset.metrics:
            multi_hop = evaluate_multihop(predictions, dataset.questions)
            metrics = multi_hop.to_dict()
        else:
            # Retrieval metrics (for BEIR)
            metrics = self._compute_retrieval_metrics(predictions, dataset.questions)

        # Add timing metrics
        times = [p.get("time_s", 0) for p in predictions if "time_s" in p]
        if times:
            metrics["mean_time_s"] = sum(times) / len(times)
            metrics["total_time_s"] = sum(times)

        # Aggregate RLM metrics if in RLM mode
        rlm_preds = [p for p in predictions if p.get("rlm_metrics")]
        if rlm_preds:
            rlm_agg = {
                "total_sub_questions": sum(p["rlm_metrics"].get("sub_questions_generated", 0) for p in rlm_preds),
                "total_variables": sum(p["rlm_metrics"].get("variables_stored", 0) for p in rlm_preds),
                "total_batch_calls": sum(p["rlm_metrics"].get("batch_calls_made", 0) for p in rlm_preds),
                "total_llm_calls": sum(p["rlm_metrics"].get("total_llm_calls", 0) for p in rlm_preds),
                "avg_decomposition_time": sum(p["rlm_metrics"].get("decomposition_time_s", 0) for p in rlm_preds) / len(rlm_preds),
                "avg_sub_task_time": sum(p["rlm_metrics"].get("sub_task_time_s", 0) for p in rlm_preds) / len(rlm_preds),
                "avg_stitching_time": sum(p["rlm_metrics"].get("stitching_time_s", 0) for p in rlm_preds) / len(rlm_preds),
            }
            metrics["rlm_metrics"] = rlm_agg

        return predictions, metrics

    def _compute_retrieval_metrics(
        self,
        predictions: list[dict],
        questions: list[BenchmarkQuestion],
    ) -> dict[str, float]:
        """Compute retrieval metrics (precision, recall, etc.)."""
        # Simplified retrieval metrics
        precisions = []
        recalls = []

        for pred, q in zip(predictions, questions):
            retrieved = set(pred.get("nodes_visited", []))
            relevant = set(q.supporting_facts) if q.supporting_facts else set()

            if retrieved and relevant:
                prec = len(retrieved & relevant) / len(retrieved)
                rec = len(retrieved & relevant) / len(relevant)
                precisions.append(prec)
                recalls.append(rec)

        n = len(precisions) or 1
        precision = sum(precisions) / n
        recall = sum(recalls) / n
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    def evaluate_baseline_on_dataset(
        self,
        baseline: NaiveChunkBaseline | SemanticChunkBaseline,
        dataset: BenchmarkDataset,
    ) -> tuple[list[dict[str, Any]], dict[str, float]]:
        """
        Evaluate a baseline RAG approach on a benchmark dataset.

        Uses the same LLM as RNSR for fair comparison.
        """
        predictions = []

        logger.info(
            "evaluating_baseline_on_dataset",
            baseline=baseline.name(),
            dataset=dataset.name,
            questions=len(dataset.questions),
        )

        for i, question in enumerate(dataset.questions):
            logger.debug(
                "baseline_processing_question",
                baseline=baseline.name(),
                index=i,
                question=question.question[:50],
            )

            try:
                result = baseline.answer_from_context(
                    question=question.question,
                    contexts=question.context,
                    metadata=question.metadata,
                )

                predictions.append({
                    "id": question.id,
                    "question": question.question,
                    "answer": result.answer,
                    "time_s": result.total_time_s,
                    "method": result.method,
                })

            except Exception as e:
                logger.error(
                    "baseline_question_failed",
                    error=str(e),
                    question_id=question.id,
                    baseline=baseline.name(),
                )
                predictions.append({
                    "id": question.id,
                    "question": question.question,
                    "answer": "",
                    "error": str(e),
                    "method": baseline.name(),
                })

        # Compute metrics (same as RNSR)
        if "answer_f1" in dataset.metrics or "answer_em" in dataset.metrics or "accuracy" in dataset.metrics:
            multi_hop = evaluate_multihop(predictions, dataset.questions)
            metrics = multi_hop.to_dict()
        else:
            metrics = {"f1": 0.0}  # Retrieval metrics not applicable

        # Add timing metrics
        times = [p.get("time_s", 0) for p in predictions if "time_s" in p]
        if times:
            metrics["mean_time_s"] = sum(times) / len(times)
            metrics["total_time_s"] = sum(times)

        return predictions, metrics

    def run(self) -> EvaluationReport:
        """Run full evaluation suite."""
        logger.info("starting_evaluation_suite", datasets=self.config.datasets)

        dataset_results = {}
        all_comparisons = []

        for dataset_name in self.config.datasets:
            logger.info("loading_dataset", name=dataset_name)
            dataset = self.load_dataset(dataset_name)

            if not dataset.questions:
                logger.warning("empty_dataset", name=dataset_name)
                continue

            # Application of Chaos Mode
            if self.config.chaos_mode and "financebench" in dataset_name.lower():
                from rnsr.benchmarks.pdf_merger import PDFMerger
                logger.info("applying_chaos_mode", dataset=dataset_name)

                # Collect all available PDFs to use as distractors
                all_pdfs = set()
                for q in dataset.questions:
                    if q.metadata and q.metadata.get("pdf_path"):
                        p = Path(q.metadata["pdf_path"])
                        if p.exists():
                            all_pdfs.add(p)

                pool_of_pdfs = list(all_pdfs)
                if len(pool_of_pdfs) < 2:
                    logger.warning("not_enough_pdfs_for_chaos", count=len(pool_of_pdfs))
                else:
                    chaos_dir = self.config.output_dir / "chaos_data"
                    dataset.questions = PDFMerger.create_chaos_dataset(
                        dataset.questions,
                        pool_of_pdfs,
                        chaos_dir,
                        num_distractors=self.config.chaos_distractors
                    )
                    dataset.name = f"{dataset.name}-CHAOS"
                    logger.info("chaos_mode_applied", new_size=len(dataset.questions))

            # Evaluate RNSR
            predictions, rnsr_metrics = self.evaluate_rnsr_on_dataset(dataset)

            # Evaluate baselines using the same LLM for fair comparison
            baseline_metrics = {}
            for baseline_name, baseline in self.baselines.items():
                logger.info("evaluating_baseline", baseline=baseline_name, dataset=dataset_name)
                baseline_preds, base_metrics = self.evaluate_baseline_on_dataset(
                    baseline, dataset
                )
                baseline_metrics[baseline_name] = base_metrics

            dataset_results[dataset_name] = {
                "rnsr_metrics": rnsr_metrics,
                "baseline_metrics": baseline_metrics,
                "num_questions": len(dataset.questions),
                "predictions": predictions if self.config.save_predictions else [],
            }

            # Generate comparisons
            for baseline_name, base_metrics in baseline_metrics.items():
                comparison = compare_rnsr_vs_baseline(
                    rnsr_metrics,
                    base_metrics,
                    dataset_name,
                    baseline_name,
                )
                all_comparisons.append(asdict(comparison))

        # Build summary
        summary = self._build_summary(dataset_results, all_comparisons)

        report = EvaluationReport(
            timestamp=datetime.now(timezone.utc).isoformat(),
            config=asdict(self.config),
            dataset_results=dataset_results,
            comparisons=all_comparisons,
            summary=summary,
        )

        # Save report
        report_path = self.config.output_dir / f"eval_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        report.save(report_path)
        logger.info("report_saved", path=str(report_path))

        return report

    def _build_summary(
        self,
        dataset_results: dict,
        comparisons: list[dict],
    ) -> dict[str, Any]:
        """Build summary statistics."""
        # Aggregate metrics across datasets
        all_f1s = []
        all_times = []

        for results in dataset_results.values():
            rnsr = results.get("rnsr_metrics", {})
            if "answer_f1" in rnsr:
                all_f1s.append(rnsr["answer_f1"])
            if "mean_time_s" in rnsr:
                all_times.append(rnsr["mean_time_s"])

        # Average improvements over baselines
        avg_improvements = {}
        for comp in comparisons:
            for metric, delta in comp.get("improvement", {}).items():
                if metric not in avg_improvements:
                    avg_improvements[metric] = []
                avg_improvements[metric].append(delta)

        for metric in avg_improvements:
            vals = avg_improvements[metric]
            avg_improvements[metric] = sum(vals) / len(vals) if vals else 0

        return {
            "datasets_evaluated": len(dataset_results),
            "total_questions": sum(r.get("num_questions", 0) for r in dataset_results.values()),
            "mean_answer_f1": sum(all_f1s) / len(all_f1s) if all_f1s else 0,
            "mean_time_s": sum(all_times) / len(all_times) if all_times else 0,
            "avg_improvement_over_baselines": avg_improvements,
        }


# =============================================================================
# CLI
# =============================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Run RNSR against standard RAG benchmarks"
    )

    parser.add_argument(
        "--datasets", "-d",
        nargs="+",
        default=["hotpotqa"],
        help="Datasets to evaluate (hotpotqa, musique_ans, beir_nfcorpus, etc.)"
    )

    parser.add_argument(
        "--samples", "-n",
        type=int,
        default=100,
        help="Max samples per dataset"
    )

    parser.add_argument(
        "--baselines", "-b",
        nargs="+",
        default=["naive_chunk_512"],
        help="Baselines to compare against"
    )

    parser.add_argument(
        "--output", "-o",
        type=Path,
        default=Path("benchmark_results"),
        help="Output directory"
    )

    parser.add_argument(
        "--no-ragas",
        action="store_true",
        help="Skip RAGAS evaluation"
    )

    parser.add_argument(
        "--llm-provider", "-p",
        type=str,
        default="gemini",
        choices=["openai", "anthropic", "gemini"],
        help="LLM provider to use (default: gemini)"
    )

    parser.add_argument(
        "--llm-model", "-m",
        type=str,
        default="gemini-2.5-flash",
        help="LLM model name (default: gemini-2.5-flash)"
    )

    parser.add_argument(
        "--chaos",
        action="store_true",
        help="Enable chaos mode (merge PDFs with random distractors)"
    )

    parser.add_argument(
        "--tot-threshold",
        type=float,
        default=0.4,
        help="ToT selection threshold (default: 0.4)"
    )

    parser.add_argument(
        "--tot-dead-end",
        type=float,
        default=0.1,
        help="ToT dead end threshold (default: 0.1)"
    )

    parser.add_argument(
        "--workers", "-w",
        type=int,
        default=1,
        help="Number of parallel workers for processing questions (default: 1, sequential)"
    )

    args = parser.parse_args()

    # RNSR always uses the full RLM flow - no mode switching needed
    config = EvaluationConfig(
        datasets=args.datasets,
        max_samples=args.samples,
        baselines=args.baselines,
        output_dir=args.output,
        run_ragas=not args.no_ragas,
        llm_provider=args.llm_provider,
        llm_model=args.llm_model,
        chaos_mode=args.chaos,
        tot_selection_threshold=args.tot_threshold,
        tot_dead_end_threshold=args.tot_dead_end,
        parallel_workers=args.workers,
    )

    suite = EvaluationSuite(config)
    report = suite.run()
    report.print_summary()


if __name__ == "__main__":
    main()
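
# Illustrative sketch of invoking the suite programmatically rather than via
# the CLI above. The dataset choice, sample count, worker count, and the
# availability of provider credentials are assumptions; the EvaluationConfig,
# EvaluationSuite, and EvaluationReport APIs used here are the ones defined
# earlier in this module.
def _example_programmatic_run() -> None:  # pragma: no cover
    config = EvaluationConfig(
        datasets=["hotpotqa"],
        max_samples=10,
        baselines=["naive_chunk_512", "semantic_chunk"],
        llm_provider="gemini",
        llm_model="gemini-2.5-flash",
        parallel_workers=4,
    )
    # run() evaluates RNSR and each baseline, then writes a JSON report.
    report = EvaluationSuite(config).run()
    report.print_summary()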