rnsr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rnsr/__init__.py +118 -0
  2. rnsr/__main__.py +242 -0
  3. rnsr/agent/__init__.py +218 -0
  4. rnsr/agent/cross_doc_navigator.py +767 -0
  5. rnsr/agent/graph.py +1557 -0
  6. rnsr/agent/llm_cache.py +575 -0
  7. rnsr/agent/navigator_api.py +497 -0
  8. rnsr/agent/provenance.py +772 -0
  9. rnsr/agent/query_clarifier.py +617 -0
  10. rnsr/agent/reasoning_memory.py +736 -0
  11. rnsr/agent/repl_env.py +709 -0
  12. rnsr/agent/rlm_navigator.py +2108 -0
  13. rnsr/agent/self_reflection.py +602 -0
  14. rnsr/agent/variable_store.py +308 -0
  15. rnsr/benchmarks/__init__.py +118 -0
  16. rnsr/benchmarks/comprehensive_benchmark.py +733 -0
  17. rnsr/benchmarks/evaluation_suite.py +1210 -0
  18. rnsr/benchmarks/finance_bench.py +147 -0
  19. rnsr/benchmarks/pdf_merger.py +178 -0
  20. rnsr/benchmarks/performance.py +321 -0
  21. rnsr/benchmarks/quality.py +321 -0
  22. rnsr/benchmarks/runner.py +298 -0
  23. rnsr/benchmarks/standard_benchmarks.py +995 -0
  24. rnsr/client.py +560 -0
  25. rnsr/document_store.py +394 -0
  26. rnsr/exceptions.py +74 -0
  27. rnsr/extraction/__init__.py +172 -0
  28. rnsr/extraction/candidate_extractor.py +357 -0
  29. rnsr/extraction/entity_extractor.py +581 -0
  30. rnsr/extraction/entity_linker.py +825 -0
  31. rnsr/extraction/grounded_extractor.py +722 -0
  32. rnsr/extraction/learned_types.py +599 -0
  33. rnsr/extraction/models.py +232 -0
  34. rnsr/extraction/relationship_extractor.py +600 -0
  35. rnsr/extraction/relationship_patterns.py +511 -0
  36. rnsr/extraction/relationship_validator.py +392 -0
  37. rnsr/extraction/rlm_extractor.py +589 -0
  38. rnsr/extraction/rlm_unified_extractor.py +990 -0
  39. rnsr/extraction/tot_validator.py +610 -0
  40. rnsr/extraction/unified_extractor.py +342 -0
  41. rnsr/indexing/__init__.py +60 -0
  42. rnsr/indexing/knowledge_graph.py +1128 -0
  43. rnsr/indexing/kv_store.py +313 -0
  44. rnsr/indexing/persistence.py +323 -0
  45. rnsr/indexing/semantic_retriever.py +237 -0
  46. rnsr/indexing/semantic_search.py +320 -0
  47. rnsr/indexing/skeleton_index.py +395 -0
  48. rnsr/ingestion/__init__.py +161 -0
  49. rnsr/ingestion/chart_parser.py +569 -0
  50. rnsr/ingestion/document_boundary.py +662 -0
  51. rnsr/ingestion/font_histogram.py +334 -0
  52. rnsr/ingestion/header_classifier.py +595 -0
  53. rnsr/ingestion/hierarchical_cluster.py +515 -0
  54. rnsr/ingestion/layout_detector.py +356 -0
  55. rnsr/ingestion/layout_model.py +379 -0
  56. rnsr/ingestion/ocr_fallback.py +177 -0
  57. rnsr/ingestion/pipeline.py +936 -0
  58. rnsr/ingestion/semantic_fallback.py +417 -0
  59. rnsr/ingestion/table_parser.py +799 -0
  60. rnsr/ingestion/text_builder.py +460 -0
  61. rnsr/ingestion/tree_builder.py +402 -0
  62. rnsr/ingestion/vision_retrieval.py +965 -0
  63. rnsr/ingestion/xy_cut.py +555 -0
  64. rnsr/llm.py +733 -0
  65. rnsr/models.py +167 -0
  66. rnsr/py.typed +2 -0
  67. rnsr-0.1.0.dist-info/METADATA +592 -0
  68. rnsr-0.1.0.dist-info/RECORD +72 -0
  69. rnsr-0.1.0.dist-info/WHEEL +5 -0
  70. rnsr-0.1.0.dist-info/entry_points.txt +2 -0
  71. rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
  72. rnsr-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1210 @@
1
+ """
2
+ RNSR Benchmark Suite - Comprehensive Evaluation Against Standard Baselines
3
+
4
+ This module runs RNSR against standard RAG benchmarks to validate
5
+ the claims in the research paper:
6
+
7
+ 1. Tree traversal is more efficient than flat chunk retrieval
8
+ 2. Hierarchical indexing preserves context better
9
+ 3. Multi-hop reasoning benefits from structural navigation
10
+ 4. Latent TOC extraction improves document understanding
11
+
12
+ Usage:
13
+ python -m rnsr.benchmarks.evaluation_suite --dataset hotpotqa --samples 100
14
+ """
15
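The same run can also be driven programmatically through the EvaluationConfig and EvaluationSuite classes defined later in this module. A minimal sketch (it assumes API credentials for the chosen LLM provider are already configured in the environment):

from pathlib import Path

from rnsr.benchmarks.evaluation_suite import EvaluationConfig, EvaluationSuite

# Small smoke-test configuration; field names match the EvaluationConfig dataclass below.
config = EvaluationConfig(
    datasets=["hotpotqa"],
    max_samples=25,
    baselines=["naive_chunk_512"],
    output_dir=Path("benchmark_results"),
    parallel_workers=1,
)
report = EvaluationSuite(config).run()
report.print_summary()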
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import json
20
+ import re
21
+ import time
22
+ import threading
23
+ from concurrent.futures import ThreadPoolExecutor, as_completed
24
+ from dataclasses import dataclass, field, asdict
25
+ from datetime import datetime, timezone
26
+ from pathlib import Path
27
+ from typing import Any, Literal
28
+
29
+ import structlog
30
+
31
+ from rnsr.benchmarks.standard_benchmarks import (
32
+ BenchmarkLoader,
33
+ BenchmarkDataset,
34
+ BenchmarkQuestion,
35
+ NaiveChunkRAG,
36
+ SemanticChunkRAG,
37
+ RAGASEvaluator,
38
+ RAGASMetrics,
39
+ MultiHopMetrics,
40
+ evaluate_multihop,
41
+ compare_rnsr_vs_baseline,
42
+ ComparisonResult,
43
+ )
44
+ from rnsr.llm import LLMProvider
45
+
46
+ logger = structlog.get_logger(__name__)
47
+
48
+
49
+ # =============================================================================
50
+ # RNSR Wrapper for Benchmark Evaluation
51
+ # =============================================================================
52
+
53
+ @dataclass
54
+ class RNSRResult:
55
+ """Result from RNSR system."""
56
+
57
+ answer: str
58
+ supporting_facts: list[str]
59
+ nodes_visited: list[str]
60
+ traversal_path: list[str]
61
+ retrieval_time_s: float
62
+ generation_time_s: float
63
+ total_time_s: float
64
+ tree_depth_reached: int
65
+ metadata: dict[str, Any] = field(default_factory=dict)
66
+
67
+ # RLM-specific metrics (Section 2)
68
+ rlm_metrics: dict[str, Any] = field(default_factory=dict)
69
+
70
+ # Full execution trace
71
+ trace: list[dict[str, Any]] = field(default_factory=list)
72
+
73
+
74
+ @dataclass
75
+ class RLMMetrics:
76
+ """
77
+ Metrics specific to RLM (Recursive Language Model) execution.
78
+
79
+ Tracks Section 2 implementation effectiveness:
80
+ - Decomposition quality (Section 2.2)
81
+ - Variable stitching usage (Section 2.2)
82
+ - Batch processing efficiency (Section 2.3)
83
+ - REPL interaction patterns
84
+ """
85
+
86
+ # Query decomposition
87
+ sub_questions_generated: int = 0
88
+ sub_questions_answered: int = 0
89
+ decomposition_method: str = "none" # "llm", "pattern", "none"
90
+
91
+ # Variable stitching
92
+ variables_stored: int = 0
93
+ variables_resolved: int = 0
94
+ total_variable_chars: int = 0
95
+
96
+ # Recursive execution
97
+ sub_llm_calls: int = 0
98
+ batch_calls: int = 0
99
+ batch_efficiency: float = 0.0 # items_processed / api_calls
100
+
101
+ # REPL execution
102
+ repl_commands_executed: int = 0
103
+ repl_errors: int = 0
104
+
105
+ # Timing breakdown
106
+ decomposition_time_s: float = 0.0
107
+ navigation_time_s: float = 0.0
108
+ synthesis_time_s: float = 0.0
109
+
110
+ def to_dict(self) -> dict[str, Any]:
111
+ """Convert to dictionary."""
112
+ return {
113
+ "sub_questions_generated": self.sub_questions_generated,
114
+ "sub_questions_answered": self.sub_questions_answered,
115
+ "decomposition_method": self.decomposition_method,
116
+ "variables_stored": self.variables_stored,
117
+ "variables_resolved": self.variables_resolved,
118
+ "total_variable_chars": self.total_variable_chars,
119
+ "sub_llm_calls": self.sub_llm_calls,
120
+ "batch_calls": self.batch_calls,
121
+ "batch_efficiency": self.batch_efficiency,
122
+ "repl_commands_executed": self.repl_commands_executed,
123
+ "repl_errors": self.repl_errors,
124
+ "decomposition_time_s": self.decomposition_time_s,
125
+ "navigation_time_s": self.navigation_time_s,
126
+ "synthesis_time_s": self.synthesis_time_s,
127
+ }
128
+
129
+
130
+ class RNSRBenchmarkAdapter:
131
+ """
132
+ Adapter to run RNSR on benchmark datasets.
133
+
134
+ IMPORTANT: This adapter ALWAYS uses the full RLM (Recursive Language Model)
135
+ pipeline. RLM is not optional - it IS RNSR. The key principles:
136
+
137
+ 1. Document stored as variable (DOC_VAR), not stuffed into prompt
138
+ 2. LLM navigates via summaries (skeleton index)
139
+ 3. Query decomposition for complex questions
140
+ 4. Variable stitching prevents context pollution
141
+
142
+ For benchmarks that provide raw text (not PDFs), we build an ephemeral
143
+ tree structure to enable full RLM processing.
144
+ """
145
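A sketch of how the adapter is used on raw benchmark contexts (the question and passages here are made-up placeholders, and an API key for the configured provider is assumed to be available):

adapter = RNSRBenchmarkAdapter(llm_provider="gemini", llm_model="gemini-2.5-flash")
result = adapter.answer_from_context(
    question="Who founded the company described in the filing?",  # hypothetical question
    contexts=[
        "Item 1. Business. The company was founded in 1998 by ...",  # hypothetical passages
        "Item 7. Management's Discussion and Analysis ...",
    ],
)
print(result.answer)
print("nodes visited:", len(result.nodes_visited), "tree depth:", result.tree_depth_reached)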
+
146
+ def __init__(
147
+ self,
148
+ llm_provider: str = "gemini",
149
+ llm_model: str = "gemini-2.5-flash",
150
+ max_iterations: int = 20,
151
+ tot_selection_threshold: float = 0.4,
152
+ tot_dead_end_threshold: float = 0.1,
153
+ ):
154
+ self.llm_provider = llm_provider
155
+ self.llm_model = llm_model
156
+ self.max_iterations = max_iterations
157
+ self.tot_selection_threshold = tot_selection_threshold
158
+ self.tot_dead_end_threshold = tot_dead_end_threshold
159
+ # Ingestion cache to avoid re-processing PDFs (thread-safe)
160
+ self._ingestion_cache: dict[Path, tuple] = {}
161
+ self._cache_lock = threading.Lock()
162
+
163
+ def answer_from_pdf(
164
+ self,
165
+ question: str,
166
+ pdf_path: Path,
167
+ metadata: dict | None = None,
168
+ ) -> RNSRResult:
169
+ """
170
+ Answer a question using RNSR's full RLM pipeline on a PDF.
171
+
172
+ Uses:
173
+ 1. Font histogram ingestion → Document tree
174
+ 2. Skeleton index (summaries + KV store)
175
+ 3. Navigator agent with decomposition + variable stitching
176
+ """
177
+ from rnsr.ingestion import ingest_document
178
+ from rnsr.indexing import build_skeleton_index
179
+ from rnsr.agent import run_navigator
180
+
181
+ start_total = time.perf_counter()
182
+
183
+ # Check cache first (thread-safe)
184
+ start_index = time.perf_counter()
185
+ with self._cache_lock:
186
+ cached = self._ingestion_cache.get(pdf_path)
187
+
188
+ if cached:
189
+ skeleton, kv_store = cached
190
+ index_time = time.perf_counter() - start_index
191
+ logger.debug("using_cached_ingestion", pdf=str(pdf_path))
192
+ else:
193
+ # Ingest and index
194
+ result = ingest_document(pdf_path)
195
+ skeleton, kv_store = build_skeleton_index(result.tree)
196
+ index_time = time.perf_counter() - start_index
197
+ # Store in cache
198
+ with self._cache_lock:
199
+ self._ingestion_cache[pdf_path] = (skeleton, kv_store)
200
+ logger.debug("cached_ingestion", pdf=str(pdf_path))
201
+
202
+ # Full RLM: Query with navigator (always uses decomposition + variable stitching)
203
+ start_query = time.perf_counter()
204
+ answer_result = run_navigator(
205
+ question=question,
206
+ skeleton=skeleton,
207
+ kv_store=kv_store,
208
+ max_iterations=self.max_iterations,
209
+ metadata=metadata,
210
+ tot_selection_threshold=self.tot_selection_threshold,
211
+ tot_dead_end_threshold=self.tot_dead_end_threshold,
212
+ )
213
+ query_time = time.perf_counter() - start_query
214
+
215
+ total_time = time.perf_counter() - start_total
216
+
217
+ # Extract supporting facts from trace
218
+ supporting_facts = []
219
+ traversal_path = []
220
+ max_depth = 0
221
+
222
+ for entry in answer_result.get("trace", []):
223
+ if entry.get("action") == "read_node":
224
+ node_id = entry.get("node_id", "")
225
+ supporting_facts.append(node_id)
226
+ traversal_path.append(entry.get("node_type", "unknown"))
227
+ max_depth = max(max_depth, entry.get("depth", 0))
228
+
229
+ return RNSRResult(
230
+ answer=answer_result.get("answer", ""),
231
+ supporting_facts=supporting_facts,
232
+ nodes_visited=answer_result.get("nodes_visited", []),
233
+ traversal_path=traversal_path,
234
+ retrieval_time_s=index_time,
235
+ generation_time_s=query_time,
236
+ total_time_s=total_time,
237
+ tree_depth_reached=max_depth,
+ trace=answer_result.get("trace", []),
238
+ metadata={
239
+ "confidence": answer_result.get("confidence", 0),
240
+ "variables_used": len(answer_result.get("variables_used", [])),
241
+ }
242
+ )
243
+
244
+ def answer_from_context(
245
+ self,
246
+ question: str,
247
+ contexts: list[str],
248
+ metadata: dict | None = None,
249
+ ) -> RNSRResult:
250
+ """
251
+ Answer using pre-provided contexts (for benchmark datasets).
252
+
253
+ This uses the FULL RLM pipeline:
254
+ 1. Build ephemeral tree from text contexts
255
+ 2. Create skeleton index with summaries
256
+ 3. Run navigator with decomposition + variable stitching
257
+
258
+ This is NOT traditional RAG (stuffing context into prompt).
259
+ The document is stored as DOC_VAR and navigated structurally.
260
+ """
261
+ from rnsr.ingestion.text_builder import build_tree_from_contexts
262
+ from rnsr.indexing import build_skeleton_index
263
+ from rnsr.agent import run_navigator
264
+
265
+ start_total = time.perf_counter()
266
+ metadata = metadata or {}
267
+
268
+ # Check if this is a PDF-based benchmark (e.g. FinanceBench)
269
+ if metadata and "pdf_path" in metadata and metadata["pdf_path"]:
270
+ pdf_path_str = metadata["pdf_path"]
271
+ if Path(pdf_path_str).exists():
272
+ return self.answer_from_pdf(question, Path(pdf_path_str), metadata)
273
+ # Fall back to the provided context if the PDF is missing (rare for FinanceBench)
274
+
275
+ # Step 1: Build ephemeral tree from benchmark contexts
276
+ start_index = time.perf_counter()
277
+ tree = build_tree_from_contexts(contexts, question)
278
+ skeleton, kv_store = build_skeleton_index(tree)
279
+ index_time = time.perf_counter() - start_index
280
+
281
+ # Step 2: Run full RLM navigator (decomposition + variable stitching)
282
+ start_query = time.perf_counter()
283
+ try:
284
+ # Use Tree of Thoughts reasoning-based navigation (research paper Section 7.2)
285
+ # Semantic search disabled by default per Section 9.1 (optional shortcut only)
286
+ answer_result = run_navigator(
287
+ question=question,
288
+ skeleton=skeleton,
289
+ kv_store=kv_store,
290
+ max_iterations=self.max_iterations,
291
+ use_semantic_search=False, # ToT primary, embeddings optional
292
+ metadata=metadata, # Pass options for multiple-choice questions
293
+ tot_selection_threshold=self.tot_selection_threshold,
294
+ tot_dead_end_threshold=self.tot_dead_end_threshold,
295
+ )
296
+ answer = answer_result.get("answer", "")
297
+
298
+ except Exception as e:
299
+ logger.warning("rnsr_navigation_failed", error=str(e))
300
+ answer = f"Error: {str(e)}"
301
+ answer_result = {"trace": [], "nodes_visited": [], "confidence": 0}
302
+
303
+ query_time = time.perf_counter() - start_query
304
+ total_time = time.perf_counter() - start_total
305
+
306
+ # Extract trace information
307
+ supporting_facts = []
308
+ traversal_path = []
309
+ max_depth = 0
310
+
311
+ for entry in answer_result.get("trace", []):
312
+ if entry.get("action") == "read_node":
313
+ supporting_facts.append(entry.get("node_id", ""))
314
+ traversal_path.append(entry.get("node_type", "unknown"))
315
+ max_depth = max(max_depth, entry.get("depth", 0))
316
+
317
+ return RNSRResult(
318
+ answer=answer,
319
+ supporting_facts=supporting_facts,
320
+ nodes_visited=answer_result.get("nodes_visited", []),
321
+ traversal_path=traversal_path,
322
+ retrieval_time_s=index_time,
323
+ generation_time_s=query_time,
324
+ total_time_s=total_time,
325
+ tree_depth_reached=max_depth,
326
+ trace=answer_result.get("trace", []),
327
+ metadata={
328
+ "num_contexts": len(contexts),
329
+ "confidence": answer_result.get("confidence", 0),
330
+ "variables_used": len(answer_result.get("variables_used", [])),
331
+ "rlm_mode": True, # Always true now
332
+ },
333
+ )
334
+
335
+ def _normalize_multiple_choice(self, answer: str, options: list[str]) -> str:
336
+ """Normalize multiple choice answer to match option text."""
337
+ answer_lower = answer.lower().strip()
338
+
339
+ for i, opt in enumerate(options):
340
+ letter = chr(65 + i) # A, B, C, D
341
+
342
+ # Match patterns like "A.", "A)", "(A)", "A. answer text"
343
+ if (answer_lower.startswith(f"{letter.lower()}.") or
344
+ answer_lower.startswith(f"{letter.lower()})") or
345
+ answer_lower.startswith(f"({letter.lower()})")):
346
+ return opt
347
+
348
+ # Check if option number format
349
+ if answer_lower.startswith(f"{i+1}.") or answer_lower.startswith(f"({i+1})"):
350
+ return opt
351
+
352
+ # Exact match (case insensitive)
353
+ if answer_lower == opt.lower():
354
+ return opt
355
+
356
+ # Check if answer contains the option text
357
+ if opt.lower() in answer_lower:
358
+ return opt
359
+
360
+ return answer # Return original if no match
361
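For illustration, the normalization above maps letter-prefixed or embedded answers back to the canonical option text (hypothetical inputs and outputs):

opts = ["Paris", "London", "Rome"]
adapter._normalize_multiple_choice("B) London", opts)           # -> "London" (letter prefix)
adapter._normalize_multiple_choice("The answer is Rome", opts)  # -> "Rome" (option text contained)
adapter._normalize_multiple_choice("unsure", opts)              # -> "unsure" (no match, returned unchanged)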
+
362
+
363
+ # =============================================================================
364
+ # Baseline RAG Implementations (for fair comparison)
365
+ # =============================================================================
366
+
367
+ @dataclass
368
+ class BaselineResult:
369
+ """Result from a baseline RAG approach."""
370
+ answer: str
371
+ retrieved_chunks: list[str]
372
+ total_time_s: float
373
+ method: str
374
+
375
+
376
+ class NaiveChunkBaseline:
377
+ """
378
+ Naive chunking baseline using the SAME LLM as RNSR for fair comparison.
379
+
380
+ This chunks the context into fixed-size segments, retrieves top-k by
381
+ simple keyword overlap, and generates an answer.
382
+ """
383
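The retrieval step below ranks chunks by the fraction of question words they contain. A small worked example of that score, mirroring the scoring code in answer_from_context (values are illustrative):

question = "What was net revenue in 2021?"
chunk = "Net revenue in 2021 was $4.2 billion, up 12% year over year."
q_words = set(question.lower().split())   # {"what", "was", "net", "revenue", "in", "2021?"}
c_words = set(chunk.lower().split())
score = len(q_words & c_words) / max(len(q_words), 1)
# overlap = {"net", "revenue", "in", "was"} -> 4 / 6 ≈ 0.67
# note: "2021?" keeps its punctuation and does not match "2021", a limitation of split-based scoring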
+
384
+ def __init__(
385
+ self,
386
+ chunk_size: int = 512,
387
+ chunk_overlap: int = 50,
388
+ top_k: int = 5,
389
+ llm_provider: str = "gemini",
390
+ llm_model: str = "gemini-2.5-flash",
391
+ ):
392
+ self.chunk_size = chunk_size
393
+ self.chunk_overlap = chunk_overlap
394
+ self.top_k = top_k
395
+ self.llm_provider = llm_provider
396
+ self.llm_model = llm_model
397
+
398
+ def name(self) -> str:
399
+ return f"naive_chunk_{self.chunk_size}"
400
+
401
+ def answer_from_context(
402
+ self,
403
+ question: str,
404
+ contexts: list[str],
405
+ metadata: dict[str, Any] | None = None,
406
+ ) -> BaselineResult:
407
+ """Answer using naive chunking on the provided context."""
408
+ from rnsr.llm import get_llm
409
+
410
+ start_total = time.perf_counter()
411
+
412
+ # Combine all context
413
+ full_text = "\n\n".join(contexts)
414
+
415
+ # Chunk the text naively
416
+ chunks = []
417
+ for i in range(0, len(full_text), self.chunk_size - self.chunk_overlap):
418
+ chunk = full_text[i:i + self.chunk_size]
419
+ if chunk.strip():
420
+ chunks.append(chunk)
421
+
422
+ # Simple retrieval by keyword overlap
423
+ question_words = set(question.lower().split())
424
+ scored_chunks = []
425
+ for chunk in chunks:
426
+ chunk_words = set(chunk.lower().split())
427
+ score = len(question_words & chunk_words) / max(len(question_words), 1)
428
+ scored_chunks.append((score, chunk))
429
+
430
+ scored_chunks.sort(reverse=True, key=lambda x: x[0])
431
+ retrieved = [c for _, c in scored_chunks[:self.top_k]]
432
+
433
+ # Generate answer using the SAME LLM as RNSR
434
+ logger.info("provider_detected", provider=self.llm_provider)
435
+ logger.debug("initializing_llm", provider=self.llm_provider, model=self.llm_model)
436
+ llm = get_llm(provider=LLMProvider(self.llm_provider), model=self.llm_model)
437
+
438
+ # Check if this is multiple choice
439
+ options = metadata.get("options", []) if metadata else []
440
+ if options:
441
+ options_text = "\n".join([f" {i+1}. {opt}" for i, opt in enumerate(options)])
442
+ prompt = f"""Answer this multiple-choice question based on the provided text.
443
+
444
+ Retrieved chunks:
445
+ {chr(10).join(retrieved)}
446
+
447
+ Question: {question}
448
+
449
+ Options:
450
+ {options_text}
451
+
452
+ Respond with ONLY the text of the correct option, nothing else."""
453
+ else:
454
+ prompt = f"""Answer this question based on the provided text. Be concise.
455
+
456
+ Retrieved chunks:
457
+ {chr(10).join(retrieved)}
458
+
459
+ Question: {question}
460
+
461
+ Answer:"""
462
+
463
+ try:
464
+ response = llm.complete(prompt)
465
+ answer = str(response).strip()
466
+
467
+ # Match to options if needed
468
+ if options:
469
+ answer_lower = answer.lower()
470
+ for opt in options:
471
+ if opt.lower() in answer_lower:
472
+ answer = opt
473
+ break
474
+ except Exception as e:
475
+ logger.warning("baseline_llm_failed", error=str(e))
476
+ answer = f"Error: {str(e)}"
477
+
478
+ total_time = time.perf_counter() - start_total
479
+
480
+ return BaselineResult(
481
+ answer=answer,
482
+ retrieved_chunks=retrieved,
483
+ total_time_s=total_time,
484
+ method=self.name(),
485
+ )
486
+
487
+
488
+ class SemanticChunkBaseline:
489
+ """Semantic chunking baseline - splits on paragraph boundaries."""
490
+
491
+ def __init__(
492
+ self,
493
+ top_k: int = 5,
494
+ llm_provider: str = "gemini",
495
+ llm_model: str = "gemini-2.5-flash",
496
+ ):
497
+ self.top_k = top_k
498
+ self.llm_provider = llm_provider
499
+ self.llm_model = llm_model
500
+
501
+ def name(self) -> str:
502
+ return "semantic_chunk"
503
+
504
+ def answer_from_context(
505
+ self,
506
+ question: str,
507
+ contexts: list[str],
508
+ metadata: dict[str, Any] | None = None,
509
+ ) -> BaselineResult:
510
+ """Answer using paragraph-based chunking."""
511
+ from rnsr.llm import get_llm
512
+
513
+ start_total = time.perf_counter()
514
+
515
+ # Split by paragraphs
516
+ chunks = []
517
+ for ctx in contexts:
518
+ paragraphs = ctx.split("\n\n")
519
+ for para in paragraphs:
520
+ if para.strip() and len(para.strip()) > 50:
521
+ chunks.append(para.strip())
522
+
523
+ # Retrieve by keyword overlap
524
+ question_words = set(question.lower().split())
525
+ scored_chunks = []
526
+ for chunk in chunks:
527
+ chunk_words = set(chunk.lower().split())
528
+ score = len(question_words & chunk_words) / max(len(question_words), 1)
529
+ scored_chunks.append((score, chunk))
530
+
531
+ scored_chunks.sort(reverse=True, key=lambda x: x[0])
532
+ retrieved = [c for _, c in scored_chunks[:self.top_k]]
533
+
534
+ # Generate answer
535
+ logger.info("provider_detected", provider=self.llm_provider)
536
+ logger.debug("initializing_llm", provider=self.llm_provider, model=self.llm_model)
537
+ llm = get_llm(provider=LLMProvider(self.llm_provider), model=self.llm_model)
538
+
539
+ options = metadata.get("options", []) if metadata else []
540
+ if options:
541
+ options_text = "\n".join([f" {i+1}. {opt}" for i, opt in enumerate(options)])
542
+ prompt = f"""Answer this multiple-choice question based on the provided text.
543
+
544
+ Retrieved paragraphs:
545
+ {chr(10).join(retrieved)}
546
+
547
+ Question: {question}
548
+
549
+ Options:
550
+ {options_text}
551
+
552
+ Respond with ONLY the text of the correct option, nothing else."""
553
+ else:
554
+ prompt = f"""Answer this question based on the provided text. Be concise.
555
+
556
+ Retrieved paragraphs:
557
+ {chr(10).join(retrieved)}
558
+
559
+ Question: {question}
560
+
561
+ Answer:"""
562
+
563
+ try:
564
+ response = llm.complete(prompt)
565
+ answer = str(response).strip()
566
+
567
+ if options:
568
+ answer_lower = answer.lower()
569
+ for opt in options:
570
+ if opt.lower() in answer_lower:
571
+ answer = opt
572
+ break
573
+ except Exception as e:
574
+ logger.warning("baseline_llm_failed", error=str(e))
575
+ answer = f"Error: {str(e)}"
576
+
577
+ total_time = time.perf_counter() - start_total
578
+
579
+ return BaselineResult(
580
+ answer=answer,
581
+ retrieved_chunks=retrieved,
582
+ total_time_s=total_time,
583
+ method=self.name(),
584
+ )
585
+
586
+
587
+ # =============================================================================
588
+ # Evaluation Suite
589
+ # =============================================================================
590
+
591
+ @dataclass
592
+ class EvaluationConfig:
593
+ """Configuration for benchmark evaluation."""
594
+
595
+ datasets: list[str] = field(default_factory=lambda: ["hotpotqa"])
596
+ max_samples: int = field(default=100)
597
+ baselines: list[str] = field(default_factory=lambda: ["naive_chunk_512"])
598
+ output_dir: Path = field(default_factory=lambda: Path("benchmark_results"))
599
+ run_ragas: bool = field(default=True)
600
+ save_predictions: bool = field(default=True)
601
+ llm_provider: str = field(default="gemini")
602
+ llm_model: str = field(default="gemini-2.5-flash")
603
+ chaos_mode: bool = field(default=False)
604
+ chaos_distractors: int = field(default=3)
605
+ tot_selection_threshold: float = field(default=0.4)
606
+ tot_dead_end_threshold: float = field(default=0.1)
607
+ parallel_workers: int = field(default=1) # Number of parallel workers for question processing
608
+
609
+
610
+ @dataclass
611
+ class EvaluationReport:
612
+ """Complete evaluation report."""
613
+
614
+ timestamp: str
615
+ config: dict[str, Any]
616
+ dataset_results: dict[str, dict[str, Any]]
617
+ comparisons: list[dict[str, Any]]
618
+ summary: dict[str, Any]
619
+
620
+ @property
621
+ def overall_accuracy(self) -> float:
622
+ """Calculate overall accuracy across all datasets."""
623
+ total_acc = 0.0
624
+ count = 0
625
+ for results in self.dataset_results.values():
626
+ metrics = results.get("rnsr_metrics", {})
627
+ if "accuracy" in metrics:
628
+ total_acc += metrics["accuracy"]
629
+ count += 1
630
+ elif "exact_match" in metrics:
631
+ total_acc += metrics["exact_match"]
632
+ count += 1
633
+ return total_acc / max(count, 1)
634
+
635
+ @property
636
+ def avg_latency_s(self) -> float:
637
+ """Calculate average latency across all datasets."""
638
+ total_latency = 0.0
639
+ count = 0
640
+ for results in self.dataset_results.values():
641
+ metrics = results.get("rnsr_metrics", {})
642
+ if "avg_time_s" in metrics:
643
+ total_latency += metrics["avg_time_s"]
644
+ count += 1
645
+ return total_latency / max(count, 1)
646
+
647
+ def save(self, path: Path) -> None:
648
+ """Save report to JSON."""
649
+ path.parent.mkdir(parents=True, exist_ok=True)
650
+ with open(path, "w") as f:
651
+ json.dump(asdict(self), f, indent=2, default=str)
652
+
653
+ def print_summary(self) -> None:
654
+ """Print human-readable summary."""
655
+ print("\n" + "=" * 70)
656
+ print("RNSR BENCHMARK EVALUATION REPORT")
657
+ print("=" * 70)
658
+ print(f"Timestamp: {self.timestamp}")
659
+ print(f"Datasets evaluated: {list(self.dataset_results.keys())}")
660
+
661
+ for dataset, results in self.dataset_results.items():
662
+ print(f"\n--- {dataset} ---")
663
+
664
+ rnsr_metrics = results.get("rnsr_metrics", {})
665
+ for metric, value in rnsr_metrics.items():
666
+ if isinstance(value, float):
667
+ print(f" RNSR {metric}: {value:.3f}")
668
+ else:
669
+ print(f" RNSR {metric}: {value}")
670
+
671
+ # Print RLM-specific metrics if available
672
+ rlm_metrics = rnsr_metrics.get("rlm_metrics", {})
673
+ if rlm_metrics:
674
+ print("\n 📊 RLM Metrics:")
675
+ print(f" Sub-questions generated: {rlm_metrics.get('total_sub_questions', 0)}")
676
+ print(f" Variables stored: {rlm_metrics.get('total_variables', 0)}")
677
+ print(f" Batch calls: {rlm_metrics.get('total_batch_calls', 0)}")
678
+ print(f" Avg decomposition time: {rlm_metrics.get('avg_decomposition_time', 0):.3f}s")
679
+ print(f" Avg sub-task time: {rlm_metrics.get('avg_sub_task_time', 0):.3f}s")
680
+
681
+ for baseline, metrics in results.get("baseline_metrics", {}).items():
682
+ print(f"\n {baseline}:")
683
+ for metric, value in metrics.items():
684
+ if isinstance(value, float):
685
+ print(f" {metric}: {value:.3f}")
686
+ else:
687
+ print(f" {metric}: {value}")
688
+
689
+ print("\n" + "-" * 70)
690
+ print("IMPROVEMENTS OVER BASELINES")
691
+ print("-" * 70)
692
+
693
+ for comp in self.comparisons:
694
+ print(f"\nvs {comp['baseline_name']} on {comp['dataset_name']}:")
695
+ for metric, delta in comp.get("improvement", {}).items():
696
+ rel = comp.get("relative_improvement", {}).get(metric, 0) * 100
697
+ print(f" {metric}: {delta:+.3f} ({rel:+.1f}%)")
698
+
699
+ print("\n" + "=" * 70)
700
+
701
+
702
+ class EvaluationSuite:
703
+ """
704
+ Run comprehensive RNSR evaluation against standard benchmarks.
705
+ """
706
+
707
+ def __init__(self, config: EvaluationConfig):
708
+ self.config = config
709
+ self.rnsr = RNSRBenchmarkAdapter(
710
+ llm_provider=config.llm_provider,
711
+ llm_model=config.llm_model,
712
+ tot_selection_threshold=config.tot_selection_threshold,
713
+ tot_dead_end_threshold=config.tot_dead_end_threshold,
714
+ )
715
+ self.ragas_evaluator = RAGASEvaluator(
716
+ llm_provider=config.llm_provider,
717
+ llm_model=config.llm_model,
718
+ ) if config.run_ragas else None
719
+
720
+ # Initialize baselines
721
+ self.baselines = {}
722
+ for baseline_name in config.baselines:
723
+ if baseline_name.startswith("naive_chunk"):
724
+ chunk_size = int(baseline_name.split("_")[-1])
725
+ self.baselines[baseline_name] = NaiveChunkBaseline(
726
+ chunk_size=chunk_size,
727
+ llm_provider=config.llm_provider,
728
+ llm_model=config.llm_model,
729
+ )
730
+ elif baseline_name == "semantic_chunk":
731
+ self.baselines[baseline_name] = SemanticChunkBaseline(
732
+ llm_provider=config.llm_provider,
733
+ llm_model=config.llm_model,
734
+ )
735
+
736
+ def load_dataset(self, name: str) -> BenchmarkDataset:
737
+ """Load a benchmark dataset by name."""
738
+ if name == "hotpotqa":
739
+ return BenchmarkLoader.load_hotpotqa(max_samples=self.config.max_samples)
740
+ elif name.startswith("musique"):
741
+ variant = "full" if "full" in name else "ans"
742
+ return BenchmarkLoader.load_musique(variant=variant, max_samples=self.config.max_samples)
743
+ elif name.startswith("beir_"):
744
+ dataset_name = name.replace("beir_", "")
745
+ return BenchmarkLoader.load_beir_dataset(dataset_name, max_samples=self.config.max_samples)
746
+ elif name == "qasper":
747
+ return BenchmarkLoader.load_qasper(max_samples=self.config.max_samples)
748
+ elif name == "quality":
749
+ return BenchmarkLoader.load_quality(max_samples=self.config.max_samples)
750
+ elif name == "financebench":
751
+ return BenchmarkLoader.load_financebench(split="train", max_samples=self.config.max_samples)
752
+ elif name == "narrativeqa":
753
+ return BenchmarkLoader.load_narrative_qa(max_samples=self.config.max_samples)
754
+ else:
755
+ raise ValueError(f"Unknown dataset: {name}. Available: hotpotqa, musique_ans, musique_full, qasper, quality, financebench, narrativeqa, beir_*")
756
+
757
+ def evaluate_rnsr_on_dataset(
758
+ self,
759
+ dataset: BenchmarkDataset,
760
+ ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
761
+ """
762
+ Evaluate RNSR on a benchmark dataset.
763
+
764
+ Returns:
765
+ predictions: List of prediction dicts
766
+ metrics: Aggregated metrics dict (includes RLM metrics if enabled)
767
+ """
768
+ predictions = []
769
+
770
+ logger.info(
771
+ "evaluating_rnsr",
772
+ dataset=dataset.name,
773
+ questions=len(dataset.questions),
774
+ workers=self.config.parallel_workers,
775
+ )
776
+
777
+ def process_question(idx_question: tuple[int, "BenchmarkQuestion"]) -> dict[str, Any]:
778
+ """Process a single question (helper for parallel execution)."""
779
+ i, question = idx_question
780
+ logger.debug("processing_question", index=i, question=question.question[:50])
781
+
782
+ try:
783
+ result = self.rnsr.answer_from_context(
784
+ question=question.question,
785
+ contexts=question.context,
786
+ metadata=question.metadata,
787
+ )
788
+
789
+ pred_entry = {
790
+ "id": question.id,
791
+ "question": question.question,
792
+ "answer": result.answer,
793
+ "supporting_facts": result.supporting_facts,
794
+ "nodes_visited": result.nodes_visited,
795
+ "time_s": result.total_time_s,
796
+ "tree_depth": result.tree_depth_reached,
797
+ "trace": result.trace,
798
+ "metadata": result.metadata,
799
+ }
800
+
801
+ if result.metadata and result.metadata.get("rlm_mode"):
802
+ pred_entry["rlm_metrics"] = result.metadata.get("rlm_metrics", {})
803
+
804
+ return pred_entry
805
+
806
+ except Exception as e:
807
+ logger.error("question_failed", error=str(e), question_id=question.id)
808
+ return {
809
+ "id": question.id,
810
+ "question": question.question,
811
+ "answer": "",
812
+ "supporting_facts": [],
813
+ "error": str(e),
814
+ }
815
+
816
+ # Parallel or sequential execution based on config
817
+ if self.config.parallel_workers > 1:
818
+ # Parallel execution with ThreadPoolExecutor
819
+ with ThreadPoolExecutor(max_workers=self.config.parallel_workers) as executor:
820
+ futures = {
821
+ executor.submit(process_question, (i, q)): i
822
+ for i, q in enumerate(dataset.questions)
823
+ }
824
+
825
+ completed = 0
826
+ for future in as_completed(futures):
827
+ pred = future.result()
828
+ predictions.append(pred)
829
+ completed += 1
830
+ if completed % 10 == 0:
831
+ logger.info("progress", completed=completed, total=len(dataset.questions))
832
+
833
+ # Sort by question id for deterministic output order (completion order varies across workers)
834
+ predictions.sort(key=lambda p: str(p.get("id", "")))
835
+ else:
836
+ # Sequential execution (original behavior)
837
+ for i, question in enumerate(dataset.questions):
838
+ pred = process_question((i, question))
839
+ predictions.append(pred)
840
+
841
+ # Compute metrics
842
+ metrics: dict[str, Any] = {}
843
+ if "answer_f1" in dataset.metrics or "answer_em" in dataset.metrics or "accuracy" in dataset.metrics:
844
+ multi_hop = evaluate_multihop(predictions, dataset.questions)
845
+ metrics = multi_hop.to_dict()
846
+ else:
847
+ # Retrieval metrics (for BEIR)
848
+ metrics = self._compute_retrieval_metrics(predictions, dataset.questions)
849
+
850
+ # Add timing metrics
851
+ times = [p.get("time_s", 0) for p in predictions if "time_s" in p]
852
+ if times:
853
+ metrics["mean_time_s"] = sum(times) / len(times)
854
+ metrics["total_time_s"] = sum(times)
855
+
856
+ # Aggregate RLM metrics if in RLM mode
857
+ rlm_preds = [p for p in predictions if p.get("rlm_metrics")]
858
+ if rlm_preds:
859
+ rlm_agg = {
860
+ "total_sub_questions": sum(p["rlm_metrics"].get("sub_questions_generated", 0) for p in rlm_preds),
861
+ "total_variables": sum(p["rlm_metrics"].get("variables_stored", 0) for p in rlm_preds),
862
+ "total_batch_calls": sum(p["rlm_metrics"].get("batch_calls_made", 0) for p in rlm_preds),
863
+ "total_llm_calls": sum(p["rlm_metrics"].get("total_llm_calls", 0) for p in rlm_preds),
864
+ "avg_decomposition_time": sum(p["rlm_metrics"].get("decomposition_time_s", 0) for p in rlm_preds) / len(rlm_preds),
865
+ "avg_sub_task_time": sum(p["rlm_metrics"].get("sub_task_time_s", 0) for p in rlm_preds) / len(rlm_preds),
866
+ "avg_stitching_time": sum(p["rlm_metrics"].get("stitching_time_s", 0) for p in rlm_preds) / len(rlm_preds),
867
+ }
868
+ metrics["rlm_metrics"] = rlm_agg
869
+
870
+ return predictions, metrics
871
+
872
+ def _compute_retrieval_metrics(
873
+ self,
874
+ predictions: list[dict],
875
+ questions: list[BenchmarkQuestion],
876
+ ) -> dict[str, float]:
877
+ """Compute retrieval metrics (precision, recall, etc.)."""
878
+ # Simplified retrieval metrics
879
+ precisions = []
880
+ recalls = []
881
+
882
+ for pred, q in zip(predictions, questions):
883
+ retrieved = set(pred.get("nodes_visited", []))
884
+ relevant = set(q.supporting_facts) if q.supporting_facts else set()
885
+
886
+ if retrieved and relevant:
887
+ prec = len(retrieved & relevant) / len(retrieved)
888
+ rec = len(retrieved & relevant) / len(relevant)
889
+ precisions.append(prec)
890
+ recalls.append(rec)
891
+
892
+ n = len(precisions) or 1
893
+ precision = sum(precisions) / n
894
+ recall = sum(recalls) / n
895
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
896
+
897
+ return {
898
+ "precision": precision,
899
+ "recall": recall,
900
+ "f1": f1,
901
+ }
902
+
903
+ def evaluate_baseline_on_dataset(
904
+ self,
905
+ baseline: NaiveChunkBaseline | SemanticChunkBaseline,
906
+ dataset: BenchmarkDataset,
907
+ ) -> tuple[list[dict[str, Any]], dict[str, float]]:
908
+ """
909
+ Evaluate a baseline RAG approach on a benchmark dataset.
910
+
911
+ Uses the same LLM as RNSR for fair comparison.
912
+ """
913
+ predictions = []
914
+
915
+ logger.info(
916
+ "evaluating_baseline_on_dataset",
917
+ baseline=baseline.name(),
918
+ dataset=dataset.name,
919
+ questions=len(dataset.questions),
920
+ )
921
+
922
+ for i, question in enumerate(dataset.questions):
923
+ logger.debug(
924
+ "baseline_processing_question",
925
+ baseline=baseline.name(),
926
+ index=i,
927
+ question=question.question[:50],
928
+ )
929
+
930
+ try:
931
+ result = baseline.answer_from_context(
932
+ question=question.question,
933
+ contexts=question.context,
934
+ metadata=question.metadata,
935
+ )
936
+
937
+ predictions.append({
938
+ "id": question.id,
939
+ "question": question.question,
940
+ "answer": result.answer,
941
+ "time_s": result.total_time_s,
942
+ "method": result.method,
943
+ })
944
+
945
+ except Exception as e:
946
+ logger.error(
947
+ "baseline_question_failed",
948
+ error=str(e),
949
+ question_id=question.id,
950
+ baseline=baseline.name(),
951
+ )
952
+ predictions.append({
953
+ "id": question.id,
954
+ "question": question.question,
955
+ "answer": "",
956
+ "error": str(e),
957
+ "method": baseline.name(),
958
+ })
959
+
960
+ # Compute metrics (same as RNSR)
961
+ if "answer_f1" in dataset.metrics or "answer_em" in dataset.metrics or "accuracy" in dataset.metrics:
962
+ multi_hop = evaluate_multihop(predictions, dataset.questions)
963
+ metrics = multi_hop.to_dict()
964
+ else:
965
+ metrics = {"f1": 0.0} # Retrieval metrics not applicable
966
+
967
+ # Add timing metrics
968
+ times = [p.get("time_s", 0) for p in predictions if "time_s" in p]
969
+ if times:
970
+ metrics["mean_time_s"] = sum(times) / len(times)
971
+ metrics["total_time_s"] = sum(times)
972
+
973
+ return predictions, metrics
974
+
975
+ def run(self) -> EvaluationReport:
976
+ """Run full evaluation suite."""
977
+ logger.info("starting_evaluation_suite", datasets=self.config.datasets)
978
+
979
+ dataset_results = {}
980
+ all_comparisons = []
981
+
982
+ for dataset_name in self.config.datasets:
983
+ logger.info("loading_dataset", name=dataset_name)
984
+ dataset = self.load_dataset(dataset_name)
985
+
986
+ if not dataset.questions:
987
+ logger.warning("empty_dataset", name=dataset_name)
988
+ continue
989
+
990
+ # Apply chaos mode: merge each question's PDF with distractor PDFs
991
+ if self.config.chaos_mode and "financebench" in dataset_name.lower():
992
+ from rnsr.benchmarks.pdf_merger import PDFMerger
993
+ logger.info("applying_chaos_mode", dataset=dataset_name)
994
+
995
+ # Collect all available PDFs to use as distractors
996
+ all_pdfs = set()
997
+ for q in dataset.questions:
998
+ if q.metadata and q.metadata.get("pdf_path"):
999
+ p = Path(q.metadata["pdf_path"])
1000
+ if p.exists():
1001
+ all_pdfs.add(p)
1002
+
1003
+ pool_of_pdfs = list(all_pdfs)
1004
+ if len(pool_of_pdfs) < 2:
1005
+ logger.warning("not_enough_pdfs_for_chaos", count=len(pool_of_pdfs))
1006
+ else:
1007
+ chaos_dir = self.config.output_dir / "chaos_data"
1008
+ dataset.questions = PDFMerger.create_chaos_dataset(
1009
+ dataset.questions,
1010
+ pool_of_pdfs,
1011
+ chaos_dir,
1012
+ num_distractors=self.config.chaos_distractors
1013
+ )
1014
+ dataset.name = f"{dataset.name}-CHAOS"
1015
+ logger.info("chaos_mode_applied", new_size=len(dataset.questions))
1016
+
1017
+ # Evaluate RNSR
1018
+ predictions, rnsr_metrics = self.evaluate_rnsr_on_dataset(dataset)
1019
+
1020
+ # Evaluate baselines using the same LLM for fair comparison
1021
+ baseline_metrics = {}
1022
+ for baseline_name, baseline in self.baselines.items():
1023
+ logger.info("evaluating_baseline", baseline=baseline_name, dataset=dataset_name)
1024
+ baseline_preds, base_metrics = self.evaluate_baseline_on_dataset(
1025
+ baseline, dataset
1026
+ )
1027
+ baseline_metrics[baseline_name] = base_metrics
1028
+
1029
+ dataset_results[dataset_name] = {
1030
+ "rnsr_metrics": rnsr_metrics,
1031
+ "baseline_metrics": baseline_metrics,
1032
+ "num_questions": len(dataset.questions),
1033
+ "predictions": predictions if self.config.save_predictions else [],
1034
+ }
1035
+
1036
+ # Generate comparisons
1037
+ for baseline_name, base_metrics in baseline_metrics.items():
1038
+ comparison = compare_rnsr_vs_baseline(
1039
+ rnsr_metrics,
1040
+ base_metrics,
1041
+ dataset_name,
1042
+ baseline_name,
1043
+ )
1044
+ all_comparisons.append(asdict(comparison))
1045
+
1046
+ # Build summary
1047
+ summary = self._build_summary(dataset_results, all_comparisons)
1048
+
1049
+ report = EvaluationReport(
1050
+ timestamp=datetime.now(timezone.utc).isoformat(),
1051
+ config=asdict(self.config),
1052
+ dataset_results=dataset_results,
1053
+ comparisons=all_comparisons,
1054
+ summary=summary,
1055
+ )
1056
+
1057
+ # Save report
1058
+ report_path = self.config.output_dir / f"eval_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
1059
+ report.save(report_path)
1060
+ logger.info("report_saved", path=str(report_path))
1061
+
1062
+ return report
1063
+
1064
+ def _build_summary(
1065
+ self,
1066
+ dataset_results: dict,
1067
+ comparisons: list[dict],
1068
+ ) -> dict[str, Any]:
1069
+ """Build summary statistics."""
1070
+ # Aggregate metrics across datasets
1071
+ all_f1s = []
1072
+ all_times = []
1073
+
1074
+ for results in dataset_results.values():
1075
+ rnsr = results.get("rnsr_metrics", {})
1076
+ if "answer_f1" in rnsr:
1077
+ all_f1s.append(rnsr["answer_f1"])
1078
+ if "mean_time_s" in rnsr:
1079
+ all_times.append(rnsr["mean_time_s"])
1080
+
1081
+ # Average improvements over baselines
1082
+ avg_improvements = {}
1083
+ for comp in comparisons:
1084
+ for metric, delta in comp.get("improvement", {}).items():
1085
+ if metric not in avg_improvements:
1086
+ avg_improvements[metric] = []
1087
+ avg_improvements[metric].append(delta)
1088
+
1089
+ for metric in avg_improvements:
1090
+ vals = avg_improvements[metric]
1091
+ avg_improvements[metric] = sum(vals) / len(vals) if vals else 0
1092
+
1093
+ return {
1094
+ "datasets_evaluated": len(dataset_results),
1095
+ "total_questions": sum(r.get("num_questions", 0) for r in dataset_results.values()),
1096
+ "mean_answer_f1": sum(all_f1s) / len(all_f1s) if all_f1s else 0,
1097
+ "mean_time_s": sum(all_times) / len(all_times) if all_times else 0,
1098
+ "avg_improvement_over_baselines": avg_improvements,
1099
+ }
1100
+
1101
+
1102
+ # =============================================================================
1103
+ # CLI
1104
+ # =============================================================================
1105
+
1106
+ def main():
1107
+ parser = argparse.ArgumentParser(
1108
+ description="Run RNSR against standard RAG benchmarks"
1109
+ )
1110
+
1111
+ parser.add_argument(
1112
+ "--datasets", "-d",
1113
+ nargs="+",
1114
+ default=["hotpotqa"],
1115
+ help="Datasets to evaluate (hotpotqa, musique_ans, beir_nfcorpus, etc.)"
1116
+ )
1117
+
1118
+ parser.add_argument(
1119
+ "--samples", "-n",
1120
+ type=int,
1121
+ default=100,
1122
+ help="Max samples per dataset"
1123
+ )
1124
+
1125
+ parser.add_argument(
1126
+ "--baselines", "-b",
1127
+ nargs="+",
1128
+ default=["naive_chunk_512"],
1129
+ help="Baselines to compare against"
1130
+ )
1131
+
1132
+ parser.add_argument(
1133
+ "--output", "-o",
1134
+ type=Path,
1135
+ default=Path("benchmark_results"),
1136
+ help="Output directory"
1137
+ )
1138
+
1139
+ parser.add_argument(
1140
+ "--no-ragas",
1141
+ action="store_true",
1142
+ help="Skip RAGAS evaluation"
1143
+ )
1144
+
1145
+ parser.add_argument(
1146
+ "--llm-provider", "-p",
1147
+ type=str,
1148
+ default="gemini",
1149
+ choices=["openai", "anthropic", "gemini"],
1150
+ help="LLM provider to use (default: gemini)"
1151
+ )
1152
+
1153
+ parser.add_argument(
1154
+ "--llm-model", "-m",
1155
+ type=str,
1156
+ default="gemini-2.5-flash",
1157
+ help="LLM model name (default: gemini-2.5-flash)"
1158
+ )
1159
+
1160
+ parser.add_argument(
1161
+ "--chaos",
1162
+ action="store_true",
1163
+ help="Enable chaos mode (merge PDFs with random distractors)"
1164
+ )
1165
+
1166
+ parser.add_argument(
1167
+ "--tot-threshold",
1168
+ type=float,
1169
+ default=0.4,
1170
+ help="ToT selection threshold (default: 0.4)"
1171
+ )
1172
+
1173
+ parser.add_argument(
1174
+ "--tot-dead-end",
1175
+ type=float,
1176
+ default=0.1,
1177
+ help="ToT dead end threshold (default: 0.1)"
1178
+ )
1179
+
1180
+ parser.add_argument(
1181
+ "--workers", "-w",
1182
+ type=int,
1183
+ default=1,
1184
+ help="Number of parallel workers for processing questions (default: 1, sequential)"
1185
+ )
1186
+
1187
+ args = parser.parse_args()
1188
+
1189
+ # RNSR always uses the full RLM flow - no mode switching needed
1190
+ config = EvaluationConfig(
1191
+ datasets=args.datasets,
1192
+ max_samples=args.samples,
1193
+ baselines=args.baselines,
1194
+ output_dir=args.output,
1195
+ run_ragas=not args.no_ragas,
1196
+ llm_provider=args.llm_provider,
1197
+ llm_model=args.llm_model,
1198
+ chaos_mode=args.chaos,
1199
+ tot_selection_threshold=args.tot_threshold,
1200
+ tot_dead_end_threshold=args.tot_dead_end,
1201
+ parallel_workers=args.workers,
1202
+ )
1203
+
1204
+ suite = EvaluationSuite(config)
1205
+ report = suite.run()
1206
+ report.print_summary()
1207
+
1208
+
1209
+ if __name__ == "__main__":
1210
+ main()
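Example invocations of the CLI defined above (a sketch; provider API keys and dataset downloads are assumed to be set up separately):

python -m rnsr.benchmarks.evaluation_suite --datasets hotpotqa --samples 100 --workers 4
python -m rnsr.benchmarks.evaluation_suite -d financebench --chaos --llm-provider gemini --llm-model gemini-2.5-flash
python -m rnsr.benchmarks.evaluation_suite -d hotpotqa qasper -b naive_chunk_512 semantic_chunk -o benchmark_results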