rnsr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rnsr/__init__.py +118 -0
  2. rnsr/__main__.py +242 -0
  3. rnsr/agent/__init__.py +218 -0
  4. rnsr/agent/cross_doc_navigator.py +767 -0
  5. rnsr/agent/graph.py +1557 -0
  6. rnsr/agent/llm_cache.py +575 -0
  7. rnsr/agent/navigator_api.py +497 -0
  8. rnsr/agent/provenance.py +772 -0
  9. rnsr/agent/query_clarifier.py +617 -0
  10. rnsr/agent/reasoning_memory.py +736 -0
  11. rnsr/agent/repl_env.py +709 -0
  12. rnsr/agent/rlm_navigator.py +2108 -0
  13. rnsr/agent/self_reflection.py +602 -0
  14. rnsr/agent/variable_store.py +308 -0
  15. rnsr/benchmarks/__init__.py +118 -0
  16. rnsr/benchmarks/comprehensive_benchmark.py +733 -0
  17. rnsr/benchmarks/evaluation_suite.py +1210 -0
  18. rnsr/benchmarks/finance_bench.py +147 -0
  19. rnsr/benchmarks/pdf_merger.py +178 -0
  20. rnsr/benchmarks/performance.py +321 -0
  21. rnsr/benchmarks/quality.py +321 -0
  22. rnsr/benchmarks/runner.py +298 -0
  23. rnsr/benchmarks/standard_benchmarks.py +995 -0
  24. rnsr/client.py +560 -0
  25. rnsr/document_store.py +394 -0
  26. rnsr/exceptions.py +74 -0
  27. rnsr/extraction/__init__.py +172 -0
  28. rnsr/extraction/candidate_extractor.py +357 -0
  29. rnsr/extraction/entity_extractor.py +581 -0
  30. rnsr/extraction/entity_linker.py +825 -0
  31. rnsr/extraction/grounded_extractor.py +722 -0
  32. rnsr/extraction/learned_types.py +599 -0
  33. rnsr/extraction/models.py +232 -0
  34. rnsr/extraction/relationship_extractor.py +600 -0
  35. rnsr/extraction/relationship_patterns.py +511 -0
  36. rnsr/extraction/relationship_validator.py +392 -0
  37. rnsr/extraction/rlm_extractor.py +589 -0
  38. rnsr/extraction/rlm_unified_extractor.py +990 -0
  39. rnsr/extraction/tot_validator.py +610 -0
  40. rnsr/extraction/unified_extractor.py +342 -0
  41. rnsr/indexing/__init__.py +60 -0
  42. rnsr/indexing/knowledge_graph.py +1128 -0
  43. rnsr/indexing/kv_store.py +313 -0
  44. rnsr/indexing/persistence.py +323 -0
  45. rnsr/indexing/semantic_retriever.py +237 -0
  46. rnsr/indexing/semantic_search.py +320 -0
  47. rnsr/indexing/skeleton_index.py +395 -0
  48. rnsr/ingestion/__init__.py +161 -0
  49. rnsr/ingestion/chart_parser.py +569 -0
  50. rnsr/ingestion/document_boundary.py +662 -0
  51. rnsr/ingestion/font_histogram.py +334 -0
  52. rnsr/ingestion/header_classifier.py +595 -0
  53. rnsr/ingestion/hierarchical_cluster.py +515 -0
  54. rnsr/ingestion/layout_detector.py +356 -0
  55. rnsr/ingestion/layout_model.py +379 -0
  56. rnsr/ingestion/ocr_fallback.py +177 -0
  57. rnsr/ingestion/pipeline.py +936 -0
  58. rnsr/ingestion/semantic_fallback.py +417 -0
  59. rnsr/ingestion/table_parser.py +799 -0
  60. rnsr/ingestion/text_builder.py +460 -0
  61. rnsr/ingestion/tree_builder.py +402 -0
  62. rnsr/ingestion/vision_retrieval.py +965 -0
  63. rnsr/ingestion/xy_cut.py +555 -0
  64. rnsr/llm.py +733 -0
  65. rnsr/models.py +167 -0
  66. rnsr/py.typed +2 -0
  67. rnsr-0.1.0.dist-info/METADATA +592 -0
  68. rnsr-0.1.0.dist-info/RECORD +72 -0
  69. rnsr-0.1.0.dist-info/WHEEL +5 -0
  70. rnsr-0.1.0.dist-info/entry_points.txt +2 -0
  71. rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
  72. rnsr-0.1.0.dist-info/top_level.txt +1 -0
rnsr/agent/graph.py ADDED
@@ -0,0 +1,1557 @@
1
+ """
2
+ Agent Graph - LangGraph State Machine for Document Navigation
3
+
4
+ Implements the Navigator Agent with full RLM (Recursive Language Model) support:
5
+
6
+ 1. Decomposes queries into sub-questions (Section 2.2 - Recursive Loop)
7
+ 2. Navigates the document tree (expand/traverse decisions)
8
+ 3. Stores findings as pointers (Variable Stitching - Section 2.2)
9
+ 4. Synthesizes final answer from stored pointers
10
+ 5. Tree of Thoughts prompting (Section 7.2)
11
+ 6. Recursive sub-LLM invocation for complex queries
12
+
13
+ Agent State follows Appendix C specification:
14
+ - question: Current question being answered
15
+ - sub_questions: Decomposed sub-questions (via LLM)
16
+ - current_node_id: Where we are in the document tree
17
+ - visited_nodes: Navigation history
18
+ - variables: Stored findings as $POINTER -> content
19
+ - pending_questions: Sub-questions not yet answered
20
+ - context: Accumulated context (pointers only!)
21
+ - answer: Final synthesized answer
22
+ - trace: Full retrieval trace for transparency
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import re
28
+ from datetime import datetime, timezone
29
+ from typing import Any, Literal, TypedDict, cast
30
+
31
+ import structlog
32
+
33
+ from rnsr.agent.variable_store import VariableStore, generate_pointer_name
34
+ from rnsr.indexing.kv_store import KVStore
35
+ from rnsr.indexing.semantic_search import SemanticSearcher
36
+ from rnsr.models import SkeletonNode, TraceEntry
37
+
38
+ logger = structlog.get_logger(__name__)
39
+
40
+
41
+ # =============================================================================
42
+ # Agent State Definition (Appendix C)
43
+ # =============================================================================
44
+
45
+
46
+ class AgentState(TypedDict):
47
+ """
48
+ State for the RNSR Navigator Agent.
49
+
50
+ All fields use pointer-based Variable Stitching:
51
+ - Full content stored in VariableStore
52
+ - Only $POINTER names in state fields
53
+ """
54
+
55
+ # Query processing
56
+ question: str
57
+ sub_questions: list[str]
58
+ pending_questions: list[str]
59
+ current_sub_question: str | None
60
+
61
+ # Navigation state
62
+ current_node_id: str | None
63
+ visited_nodes: list[str]
64
+ navigation_path: list[str]
65
+
66
+ # Tree of Thoughts (ToT) state - Section 7.2
67
+ nodes_to_visit: list[str] # Queue for parallel exploration
68
+ scored_candidates: list[dict[str, Any]] # [{node_id, score, reasoning}]
69
+ backtrack_stack: list[str] # Stack of parent node IDs for backtracking
70
+ dead_ends: list[str] # Nodes marked as dead ends
71
+ top_k: int # Number of top candidates to explore
72
+ tot_selection_threshold: float # Minimum probability for selection
73
+ tot_dead_end_threshold: float # Probability threshold for dead end
74
+
75
+ # Variable stitching (pointers only!)
76
+ variables: list[str] # List of $POINTER names
77
+ context: str # Contains pointers, not full content
78
+
79
+ # Output
80
+ answer: str | None
81
+ confidence: float
82
+
83
+ # Question metadata (e.g., multiple-choice options)
84
+ metadata: dict[str, Any]
85
+
86
+ # Traceability
87
+ trace: list[dict[str, Any]]
88
+ iteration: int
89
+ max_iterations: int
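For intuition, Variable Stitching keeps this state deliberately small: `variables` and `context` hold only pointer names, never section text. An illustrative (hypothetical) snapshot mid-run:

# variables = ["$LIABILITY_CLAUSE", "$PAYMENT_TERMS"]
# context   = "Found: $LIABILITY_CLAUSE (from Limitation of Liability)"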
90
+
91
+
92
+ # =============================================================================
93
+ # Navigator Tools
94
+ # =============================================================================
95
+
96
+
97
+ def create_navigator_tools(
98
+ skeleton: dict[str, SkeletonNode],
99
+ kv_store: KVStore,
100
+ variable_store: VariableStore,
101
+ ) -> dict[str, Any]:
102
+ """
103
+ Create tools for the Navigator Agent.
104
+
105
+ Args:
106
+ skeleton: Skeleton index (node_id -> SkeletonNode).
107
+ kv_store: KV store with full content.
108
+ variable_store: Variable store for findings.
109
+
110
+ Returns:
111
+ Dictionary of tool functions.
112
+ """
113
+
114
+ def get_node_summary(node_id: str) -> str:
115
+ """Get the summary of a skeleton node."""
116
+ node = skeleton.get(node_id)
117
+ if node is None:
118
+ return f"Node {node_id} not found"
119
+
120
+ children_info = ""
121
+ if node.child_ids:
122
+ children = [skeleton.get(cid) for cid in node.child_ids]
123
+ children_info = "\nChildren:\n" + "\n".join(
124
+ f" - {c.node_id}: {c.header}" for c in children if c
125
+ )
126
+
127
+ return f"""
128
+ Node: {node.node_id}
129
+ Header: {node.header}
130
+ Level: {node.level}
131
+ Summary: {node.summary}
132
+ {children_info}
133
+ """
134
+
135
+ def navigate_to_child(node_id: str, child_id: str) -> str:
136
+ """Navigate to a child node."""
137
+ node = skeleton.get(node_id)
138
+ if node is None:
139
+ return f"Node {node_id} not found"
140
+
141
+ if child_id not in node.child_ids:
142
+ return f"{child_id} is not a child of {node_id}"
143
+
144
+ child = skeleton.get(child_id)
145
+ if child is None:
146
+ return f"Child {child_id} not found"
147
+
148
+ return get_node_summary(child_id)
149
+
150
+ def expand_node(node_id: str) -> str:
151
+ """
152
+ EXPAND: Fetch full content and store as variable.
153
+
154
+ Use when the summary suggests this node contains the answer.
155
+ """
156
+ node = skeleton.get(node_id)
157
+ if node is None:
158
+ return f"Node {node_id} not found"
159
+
160
+ # Fetch full content from KV store
161
+ content = kv_store.get(node_id)
162
+ if content is None:
163
+ return f"No content found for {node_id}"
164
+
165
+ # Generate pointer name
166
+ pointer = generate_pointer_name(node.header)
167
+
168
+ # Store as variable
169
+ variable_store.assign(pointer, content, node_id)
170
+
171
+ logger.info(
172
+ "node_expanded",
173
+ node_id=node_id,
174
+ pointer=pointer,
175
+ chars=len(content),
176
+ )
177
+
178
+ return f"Stored content as {pointer} ({len(content)} chars)"
179
+
180
+ def store_finding(
181
+ pointer_name: str,
182
+ content: str,
183
+ source_node_id: str = "",
184
+ ) -> str:
185
+ """
186
+ Store a finding as a pointer variable.
187
+
188
+ Args:
189
+ pointer_name: Name like $LIABILITY_CLAUSE (must start with $)
190
+ content: Full text content to store
191
+ source_node_id: Source node for traceability
192
+ """
193
+ if not pointer_name.startswith("$"):
194
+ pointer_name = "$" + pointer_name.upper()
195
+
196
+ try:
197
+ meta = variable_store.assign(pointer_name, content, source_node_id)
198
+ return f"Stored as {pointer_name} ({meta.char_count} chars)"
199
+ except Exception as e:
200
+ return f"Error storing: {e}"
201
+
202
+ def compare_variables(*pointers: str) -> str:
203
+ """
204
+ Compare multiple stored variables.
205
+
206
+ Resolves pointers and returns content for comparison.
207
+ Use during synthesis phase only.
208
+ """
209
+ results = []
210
+ for pointer in pointers:
211
+ content = variable_store.resolve(pointer)
212
+ if content:
213
+ results.append(f"=== {pointer} ===\n{content}")
214
+ else:
215
+ results.append(f"=== {pointer} ===\n[Not found]")
216
+
217
+ return "\n\n".join(results)
218
+
219
+ def list_stored_variables() -> str:
220
+ """List all stored variables with metadata."""
221
+ variables = variable_store.list_variables()
222
+ if not variables:
223
+ return "No variables stored yet."
224
+
225
+ lines = ["Stored Variables:"]
226
+ for v in variables:
227
+ lines.append(f" {v.pointer}: {v.char_count} chars from {v.source_node_id}")
228
+
229
+ return "\n".join(lines)
230
+
231
+ def synthesize_from_variables() -> str:
232
+ """
233
+ Get all stored content for final synthesis.
234
+
235
+ Call this at the end to get full content for answer generation.
236
+ """
237
+ variables = variable_store.list_variables()
238
+ if not variables:
239
+ return "No variables to synthesize from."
240
+
241
+ parts = []
242
+ for v in variables:
243
+ content = variable_store.resolve(v.pointer)
244
+ if content:
245
+ parts.append(f"=== {v.pointer} ===\n{content}")
246
+
247
+ return "\n\n".join(parts)
248
+
249
+ return {
250
+ "get_node_summary": get_node_summary,
251
+ "navigate_to_child": navigate_to_child,
252
+ "expand_node": expand_node,
253
+ "store_finding": store_finding,
254
+ "compare_variables": compare_variables,
255
+ "list_stored_variables": list_stored_variables,
256
+ "synthesize_from_variables": synthesize_from_variables,
257
+ }
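A minimal usage sketch, assuming a skeleton index, KV store, and variable store already exist from the ingestion pipeline (the node IDs below are hypothetical):

tools = create_navigator_tools(skeleton, kv_store, variable_store)
print(tools["get_node_summary"]("root"))   # orient at the root
print(tools["expand_node"]("sec_2_1"))     # fetch full content, store as a $POINTER
print(tools["list_stored_variables"]())    # audit stored findings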
258
+
259
+
260
+ # =============================================================================
261
+ # State Management Functions
262
+ # =============================================================================
263
+
264
+
265
+ def create_initial_state(
266
+ question: str,
267
+ root_node_id: str,
268
+ max_iterations: int = 20,
269
+ top_k: int = 3,
270
+ metadata: dict[str, Any] | None = None,
271
+ tot_selection_threshold: float = 0.4,
272
+ tot_dead_end_threshold: float = 0.1,
273
+ ) -> AgentState:
274
+ """Create the initial agent state with ToT support."""
275
+ return AgentState(
276
+ question=question,
277
+ sub_questions=[],
278
+ pending_questions=[],
279
+ current_sub_question=None,
280
+ current_node_id=root_node_id,
281
+ visited_nodes=[],
282
+ navigation_path=[root_node_id],
283
+ # Tree of Thoughts state
284
+ nodes_to_visit=[],
285
+ scored_candidates=[],
286
+ backtrack_stack=[],
287
+ dead_ends=[],
288
+ top_k=top_k,
289
+ tot_selection_threshold=tot_selection_threshold,
290
+ tot_dead_end_threshold=tot_dead_end_threshold,
291
+ # Variable stitching
292
+ variables=[],
293
+ context="",
294
+ answer=None,
295
+ confidence=0.0,
296
+ metadata=metadata or {},
297
+ trace=[],
298
+ iteration=0,
299
+ max_iterations=max_iterations,
300
+ )
301
+
302
+
303
+ def add_trace_entry(
304
+ state: AgentState,
305
+ node_type: Literal["decomposition", "navigation", "variable_stitching", "synthesis"],
306
+ action: str,
307
+ details: dict | None = None,
308
+ ) -> None:
309
+ """Add a trace entry to the state."""
310
+ entry = TraceEntry(
311
+ timestamp=datetime.now(timezone.utc).isoformat(),
312
+ node_type=node_type,
313
+ action=action,
314
+ details=details or {},
315
+ )
316
+ state["trace"].append(entry.model_dump())
317
+
318
+
319
+ # =============================================================================
320
+ # Tree of Thoughts (ToT) Prompting - Section 7.2
321
+ # =============================================================================
322
+
323
+ # ToT System Prompt as specified in the research paper
324
+ TOT_SYSTEM_PROMPT = """You are a Deep Research Agent navigating a document tree.
325
+
326
+ You are currently at Node: {current_node_summary}
327
+ Children Nodes: {children_summaries}
328
+ Your Goal: {query}
329
+
330
+ EVALUATION TASK:
331
+ For each child node, estimate the probability (0.0 to 1.0) that it contains relevant information for the goal.
332
+
333
+ INSTRUCTIONS:
334
+ 1. Evaluate: For each child node, analyze its summary and estimate relevance probability.
335
+ 2. Be OPEN-MINDED: Select nodes with probability > {selection_threshold} (moderate evidence of relevance).
336
+ 3. Look for matches: Prefer nodes with facts/entities mentioned in the query, but also consider broad thematic matches.
337
+ 4. Balance PRECISION and RECALL: Do not prune branches too early. If unsure, include the node.
338
+ 5. Plan: Select the top-{top_k} most promising nodes.
339
+ 6. Reasoning: Explain briefly what SPECIFIC content in the summary makes it relevant.
340
+ 7. Backtrack Signal: If NO child seems relevant (all probabilities < {dead_end_threshold}), report "DEAD_END".
341
+
342
+ OUTPUT FORMAT (JSON):
343
+ {{
344
+ "evaluations": [
345
+ {{"node_id": "...", "probability": 0.85, "reasoning": "Summary mentions X which directly relates to query about Y"}},
346
+ {{"node_id": "...", "probability": 0.60, "reasoning": "Contains information about Z"}},
347
+ ...
348
+ ],
349
+ "selected_nodes": ["node_id_1", "node_id_2", ...],
350
+ "is_dead_end": false,
351
+ "backtrack_reason": null
352
+ }}
353
+
354
+ If this is a dead end:
355
+ {{
356
+ "evaluations": [...],
357
+ "selected_nodes": [],
358
+ "is_dead_end": true,
359
+ "backtrack_reason": "None of the children appear to contain information about X."
360
+ }}
361
+
362
+ Respond ONLY with the JSON, no other text."""
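Note the doubled braces: `str.format` fills the single-brace fields and leaves `{{...}}` as literal JSON braces. A quick sanity check with hypothetical values:

preview = TOT_SYSTEM_PROMPT.format(
    current_node_summary="[n1] Intro: overview",
    children_summaries=" - [n2] Terms: payment schedule",
    query="What are the payment terms?",
    top_k=3,
    selection_threshold=0.4,
    dead_end_threshold=0.1,
)
assert '"probability": 0.85' in preview  # the literal example row survives formatting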
363
+
364
+
365
+ def _format_node_summary(node: SkeletonNode) -> str:
366
+ """Format a node's summary for the ToT prompt."""
367
+ return f"[{node.node_id}] {node.header}: {node.summary or '(no summary)'}"
368
+
369
+
370
+ def _format_children_summaries(
371
+ skeleton: dict[str, SkeletonNode],
372
+ child_ids: list[str],
373
+ ) -> str:
374
+ """Format all children summaries for the ToT prompt."""
375
+ if not child_ids:
376
+ return "(no children - this is a leaf node)"
377
+
378
+ lines = []
379
+ for child_id in child_ids:
380
+ child = skeleton.get(child_id)
381
+ if child:
382
+ lines.append(f" - {_format_node_summary(child)}")
383
+
384
+ return "\n".join(lines) if lines else "(no children found)"
385
+
386
+
387
+ def evaluate_children_with_tot(
388
+ state: AgentState,
389
+ skeleton: dict[str, SkeletonNode],
390
+ top_k_override: int | None = None,
391
+ ) -> dict[str, Any]:
392
+ """
393
+ Use Tree of Thoughts prompting to evaluate child nodes.
394
+
395
+ This implements Section 7.2 of the research paper:
396
+ "We use Tree of Thoughts (ToT) which explicitly encourages the model
397
+ to explore multiple branches of the index before committing to a path."
398
+
399
+ Args:
400
+ state: Current agent state.
401
+ skeleton: Skeleton index.
402
+ top_k_override: Optional override for adaptive exploration.
403
+
404
+ Returns:
405
+ Dictionary with evaluations, selected_nodes, is_dead_end, and backtrack_reason.
406
+ """
407
+ node_id = state.get("current_node_id")
408
+ if node_id is None:
409
+ return {"evaluations": [], "selected_nodes": [], "is_dead_end": True, "backtrack_reason": "No current node"}
410
+
411
+ node = skeleton.get(node_id)
412
+ if node is None:
413
+ return {"evaluations": [], "selected_nodes": [], "is_dead_end": True, "backtrack_reason": "Node not found"}
414
+
415
+ # If no children, this is a leaf - expand instead of traverse
416
+ if not node.child_ids:
417
+ return {"evaluations": [], "selected_nodes": [], "is_dead_end": False, "backtrack_reason": None, "is_leaf": True}
418
+
419
+ # Format the ToT prompt
420
+ current_summary = _format_node_summary(node)
421
+ children_summaries = _format_children_summaries(skeleton, node.child_ids)
422
+ query = state.get("current_sub_question") or state.get("question", "")
423
+ top_k = top_k_override if top_k_override is not None else state.get("top_k", 3)
424
+ selection_threshold = state.get("tot_selection_threshold", 0.4)
425
+ dead_end_threshold = state.get("tot_dead_end_threshold", 0.1)
426
+
427
+ prompt = TOT_SYSTEM_PROMPT.format(
428
+ current_node_summary=current_summary,
429
+ children_summaries=children_summaries,
430
+ query=query,
431
+ top_k=top_k,
432
+ selection_threshold=selection_threshold,
433
+ dead_end_threshold=dead_end_threshold,
434
+ )
435
+
436
+ # Call the LLM
437
+ try:
438
+ from rnsr.llm import get_llm
439
+ import json
440
+
441
+ llm = get_llm()
442
+ response = llm.complete(prompt)
443
+ response_text = str(response).strip()
444
+
445
+ # Parse JSON response with robust error handling
446
+ try:
447
+ # Handle potential markdown code blocks
448
+ if "```" in response_text:
449
+ match = re.search(r'\{[\s\S]*\}', response_text)
450
+ if match:
451
+ response_text = match.group(0)
452
+
453
+ result = json.loads(response_text)
454
+ except json.JSONDecodeError:
455
+ logger.warning("tot_json_repair_attempt", original_response=response_text)
456
+ # Fallback to asking the LLM to fix the JSON
457
+ repair_prompt = f"""The following text is a malformed JSON object. Please fix it and return ONLY the corrected JSON. Do not add any commentary.
458
+
459
+ Malformed JSON:
460
+ {response_text}"""
461
+ # Reuse the LLM client initialized above for the repair call
462
+ repaired_response_text = str(llm.complete(repair_prompt)).strip()
465
+
466
+ # Final attempt to parse the repaired JSON
467
+ if "```" in repaired_response_text:
468
+ match = re.search(r'\{[\s\S]*\}', repaired_response_text)
469
+ if match:
470
+ repaired_response_text = match.group(0)
471
+
472
+ result = json.loads(repaired_response_text)
473
+
474
+ logger.debug(
475
+ "tot_evaluation_complete",
476
+ node_id=node_id,
477
+ evaluations=len(result.get("evaluations", [])),
478
+ selected=result.get("selected_nodes", []),
479
+ is_dead_end=result.get("is_dead_end", False),
480
+ )
481
+
482
+ return result
483
+
484
+ except Exception as e:
485
+ logger.warning("tot_evaluation_failed", error=str(e), node_id=node_id)
486
+
487
+ # Fallback: use simple heuristic (select first unvisited children)
488
+ visited = state.get("visited_nodes", [])
489
+ dead_ends = state.get("dead_ends", [])
490
+
491
+ unvisited = [
492
+ cid for cid in node.child_ids
493
+ if cid not in visited and cid not in dead_ends
494
+ ]
495
+
496
+ return {
497
+ "evaluations": [{"node_id": cid, "probability": 0.5, "reasoning": "Fallback selection"} for cid in unvisited[:top_k]],
498
+ "selected_nodes": unvisited[:top_k],
499
+ "is_dead_end": len(unvisited) == 0,
500
+ "backtrack_reason": "No unvisited children" if len(unvisited) == 0 else None,
501
+ }
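A sketch of how a caller consumes the result dict; the keys follow the prompt contract above:

result = evaluate_children_with_tot(state, skeleton)
if result.get("is_leaf"):
    pass  # leaf node: expand instead of traversing
elif result.get("is_dead_end"):
    state = backtrack_to_parent(state, skeleton)  # defined below
else:
    state["nodes_to_visit"].extend(result.get("selected_nodes", []))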
502
+
503
+
504
+ def backtrack_to_parent(
505
+ state: AgentState,
506
+ skeleton: dict[str, SkeletonNode],
507
+ ) -> AgentState:
508
+ """
509
+ Backtrack to parent node when current path is a dead end.
510
+
511
+ Implements the backtracking logic from Section 7.2:
512
+ "Backtrack: If a node yields no useful info, report 'Dead End' and return to parent."
513
+ """
514
+ new_state = cast(AgentState, dict(state))
515
+
516
+ current_id = new_state["current_node_id"]
517
+
518
+ # Mark current node as dead end
519
+ if current_id and current_id not in new_state["dead_ends"]:
520
+ new_state["dead_ends"].append(current_id)
521
+
522
+ # Try to backtrack using the backtrack stack
523
+ if new_state["backtrack_stack"]:
524
+ parent_id = new_state["backtrack_stack"].pop()
525
+ new_state["current_node_id"] = parent_id
526
+ new_state["navigation_path"].append(parent_id)
527
+
528
+ add_trace_entry(
529
+ new_state,
530
+ "navigation",
531
+ f"Backtracked to {parent_id} (dead end at {current_id})",
532
+ {"from": current_id, "to": parent_id, "reason": "dead_end"},
533
+ )
534
+
535
+ logger.debug("backtrack_success", from_node=current_id, to_node=parent_id)
536
+ else:
537
+ # No parent to backtrack to - we're done exploring
538
+ add_trace_entry(
539
+ new_state,
540
+ "navigation",
541
+ "Cannot backtrack - at root or stack empty",
542
+ {"current": current_id},
543
+ )
544
+ logger.debug("backtrack_failed", reason="empty_stack")
545
+
546
+ new_state["iteration"] = new_state["iteration"] + 1
547
+ return new_state
548
+
549
+
550
+ # =============================================================================
551
+ # LangGraph Node Functions
552
+ # =============================================================================
553
+
554
+
555
+ # Decomposition prompt for RLM recursive sub-task generation
556
+ DECOMPOSITION_PROMPT = """You are analyzing a complex query to decompose it into sub-tasks.
557
+
558
+ Query: {query}
559
+
560
+ Document Structure (top-level sections):
561
+ {structure}
562
+
563
+ TASK: Decompose this query into specific sub-tasks that can be executed independently.
564
+
565
+ RULES:
566
+ 1. Each sub-task should target a specific section or piece of information
567
+ 2. Sub-tasks should be answerable by reading specific document sections
568
+ 3. For comparison queries, create one sub-task per item being compared
569
+ 4. For multi-hop queries, create sequential sub-tasks (find A, then use A to find B)
570
+ 5. Maximum 5 sub-tasks to maintain efficiency
571
+
572
+ OUTPUT FORMAT (JSON):
573
+ {{
574
+ "sub_tasks": [
575
+ {{"id": 1, "task": "Find X in section Y", "target_section": "section_hint"}},
576
+ {{"id": 2, "task": "Extract Z from W", "target_section": "section_hint"}}
577
+ ],
578
+ "synthesis_plan": "How to combine sub-task results into final answer"
579
+ }}
580
+
581
+ Return ONLY valid JSON."""
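For a comparison query, a well-formed response would look like this (content hypothetical):

# {
#   "sub_tasks": [
#     {"id": 1, "task": "Find the liability cap in the 2022 agreement", "target_section": "Limitation of Liability"},
#     {"id": 2, "task": "Find the liability cap in the 2023 amendment", "target_section": "Amendments"}
#   ],
#   "synthesis_plan": "Compare the two caps and state which agreement is more restrictive"
# }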
582
+
583
+
584
+ def decompose_query(state: AgentState) -> AgentState:
585
+ """
586
+ Decompose the main question into sub-questions using LLM.
587
+
588
+ Implements Section 2.2 "The Recursive Loop":
589
+ "The model's capability to divide a complex reasoning task into
590
+ smaller, manageable sub-tasks and invoke instances of itself to solve them."
591
+ """
592
+ new_state = cast(AgentState, dict(state)) # Copy
593
+ question = new_state["question"]
594
+
595
+ # Try LLM-based decomposition
596
+ try:
597
+ from rnsr.llm import get_llm
598
+ import json
599
+
600
+ llm = get_llm()
601
+
602
+ # Get document structure for context
603
+ # Note: skeleton is not available here, so we'll do basic decomposition
604
+ structure_hint = "Document sections available for navigation"
605
+
606
+ prompt = DECOMPOSITION_PROMPT.format(
607
+ query=question,
608
+ structure=structure_hint,
609
+ )
610
+
611
+ response = llm.complete(prompt)
612
+ response_text = str(response).strip()
613
+
614
+ # Extract JSON from response
615
+ json_match = re.search(r'\{[\s\S]*\}', response_text)
616
+ if json_match:
617
+ decomposition = json.loads(json_match.group())
618
+ sub_tasks = decomposition.get("sub_tasks", [])
619
+
620
+ if sub_tasks:
621
+ sub_questions = [t["task"] for t in sub_tasks]
622
+ new_state["sub_questions"] = sub_questions
623
+ new_state["pending_questions"] = sub_questions.copy()
624
+ new_state["current_sub_question"] = sub_questions[0]
625
+
626
+ add_trace_entry(
627
+ new_state,
628
+ "decomposition",
629
+ f"LLM decomposed into {len(sub_questions)} sub-tasks",
630
+ {
631
+ "sub_questions": sub_questions,
632
+ "synthesis_plan": decomposition.get("synthesis_plan", ""),
633
+ },
634
+ )
635
+
636
+ new_state["iteration"] = new_state["iteration"] + 1
637
+ return new_state
638
+
639
+ except Exception as e:
640
+ logger.warning("llm_decomposition_failed", error=str(e))
641
+
642
+ # Fallback: simple decomposition patterns
643
+ sub_questions = _simple_decompose(question)
644
+ new_state["sub_questions"] = sub_questions
645
+ new_state["pending_questions"] = sub_questions.copy()
646
+ new_state["current_sub_question"] = sub_questions[0] if sub_questions else question
647
+
648
+ add_trace_entry(
649
+ new_state,
650
+ "decomposition",
651
+ f"Decomposed into {len(sub_questions)} sub-questions (fallback)",
652
+ {"sub_questions": sub_questions},
653
+ )
654
+
655
+ new_state["iteration"] = new_state["iteration"] + 1
656
+ return new_state
657
+
658
+
659
+ def _simple_decompose(question: str) -> list[str]:
660
+ """
661
+ Simple pattern-based decomposition fallback.
662
+
663
+ Handles common query patterns without LLM.
664
+ """
665
+ question_lower = question.lower()
666
+
667
+ # Pattern: "Compare X and Y" or "X vs Y"
668
+ compare_patterns = [
669
+ r"compare\s+(.+?)\s+(?:and|with|to|vs\.?)\s+(.+)",
670
+ r"difference(?:s)?\s+between\s+(.+?)\s+and\s+(.+)",
671
+ r"(.+?)\s+vs\.?\s+(.+)",
672
+ ]
673
+
674
+ for pattern in compare_patterns:
675
+ match = re.search(pattern, question_lower)
676
+ if match:
677
+ item1, item2 = match.groups()
678
+ return [
679
+ f"Find information about {item1.strip()}",
680
+ f"Find information about {item2.strip()}",
681
+ ]
682
+
683
+ # Pattern: "What are the X in sections A, B, and C"
684
+ multi_section = re.search(
685
+ r"in\s+(?:sections?\s+)?(.+?,\s*.+?(?:,\s*.+)*)",
686
+ question_lower,
687
+ )
688
+ if multi_section:
689
+ sections = [s.strip() for s in multi_section.group(1).split(",")]
690
+ return [f"Find relevant information in {s}" for s in sections[:5]]
691
+
692
+ # Pattern: "List all X" or "Find all X" - may need iteration
693
+ if re.search(r"(list|find|show)\s+all", question_lower):
694
+ return [
695
+ f"Search for all relevant sections",
696
+ question, # Original as synthesis query
697
+ ]
698
+
699
+ # Default: single question
700
+ return [question]
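Behavior sketch of the fallback patterns above (inputs hypothetical):

# _simple_decompose("Compare revenue and expenses")
#   -> ["Find information about revenue", "Find information about expenses"]
# _simple_decompose("List all indemnification clauses")
#   -> ["Search for all relevant sections", "List all indemnification clauses"]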
701
+
702
+
703
+ def navigate_tree(
704
+ state: AgentState,
705
+ skeleton: dict[str, SkeletonNode],
706
+ ) -> AgentState:
707
+ """
708
+ Navigate the document tree based on current sub-question.
709
+
710
+ This function also manages the node visitation queue used for multi-path exploration.
711
+ """
712
+ new_state = cast(AgentState, dict(state))
713
+
714
+ # Add detailed logging to debug navigation loops
715
+ logger.debug(
716
+ "navigate_step_start",
717
+ iteration=new_state["iteration"],
718
+ current_node=new_state["current_node_id"],
719
+ queue=new_state["nodes_to_visit"],
720
+ visited=new_state["visited_nodes"],
721
+ )
722
+
723
+ # If current node is None, try to pop from the visit queue
724
+ if new_state["current_node_id"] is None and new_state["nodes_to_visit"]:
725
+ next_node_id = new_state["nodes_to_visit"].pop(0)
726
+ new_state["current_node_id"] = next_node_id
727
+
728
+ # Also add to navigation path
729
+ if next_node_id not in new_state["navigation_path"]:
730
+ new_state["navigation_path"].append(next_node_id)
731
+
732
+ add_trace_entry(
733
+ new_state,
734
+ "navigation",
735
+ f"Visiting next node from queue: {next_node_id}",
736
+ {"queue_size": len(new_state["nodes_to_visit"])},
737
+ )
738
+
739
+ node_id = new_state["current_node_id"]
740
+ if node_id is None:
741
+ add_trace_entry(new_state, "navigation", "No current node to navigate to and queue is empty")
742
+ return new_state
743
+
744
+ node = skeleton.get(node_id)
745
+
746
+ if node is None:
747
+ add_trace_entry(new_state, "navigation", f"Node {node_id} not found")
748
+ return new_state
749
+
750
+ # Don't add to visited here - let expand/traverse handle that
751
+ # to avoid preventing expansion of newly queued nodes
752
+
753
+ add_trace_entry(
754
+ new_state,
755
+ "navigation",
756
+ f"At node {node_id}: {node.header}",
757
+ {"summary": node.summary, "children": node.child_ids},
758
+ )
759
+
760
+ new_state["iteration"] = new_state["iteration"] + 1
761
+ return new_state
762
+
763
+
764
+ def should_expand(
765
+ state: AgentState,
766
+ skeleton: dict[str, SkeletonNode],
767
+ ) -> Literal["expand", "traverse", "backtrack", "done"]:
768
+ """
769
+ Decide whether to expand (fetch content), traverse (go to children),
770
+ backtrack (dead end), or finish.
771
+
772
+ Uses Tree of Thoughts evaluation to make intelligent decisions.
773
+ """
774
+ node_id = state.get("current_node_id")
775
+ if node_id is None:
776
+ return "done"
777
+
778
+ node = skeleton.get(node_id)
779
+ if node is None:
780
+ return "done"
781
+
782
+ # Check iteration limit
783
+ if state.get("iteration", 0) >= state.get("max_iterations", 20):
784
+ return "done"
785
+
786
+ # Removed variable count limit - iteration limit is sufficient
787
+
788
+ # If no children, must expand (leaf node)
789
+ # But only if not already visited
790
+ visited = state.get("visited_nodes", [])
791
+ if not node.child_ids:
792
+ if node_id in visited:
793
+ # Already expanded this leaf, done with it
794
+ return "done"
795
+ return "expand"
796
+
797
+ # Check if all children are visited or dead ends
799
+ dead_ends = state.get("dead_ends", [])
800
+ unvisited_children = [
801
+ cid for cid in node.child_ids
802
+ if cid not in visited and cid not in dead_ends
803
+ ]
804
+
805
+ if not unvisited_children:
806
+ # All children explored - backtrack or done
807
+ if state.get("backtrack_stack"):
808
+ return "backtrack"
809
+ return "done"
810
+
811
+ # Adaptive exploration: increase top_k if we haven't found much information yet
812
+ base_top_k = state.get("top_k", 3)
813
+ variables_found = len(state.get("variables", []))
814
+ iteration = state.get("iteration", 0)
815
+
816
+ if iteration > 5 and variables_found == 0:
817
+ adaptive_top_k = min(len(node.child_ids), base_top_k * 3)
818
+ elif iteration > 3 and variables_found < 2:
819
+ adaptive_top_k = min(len(node.child_ids), base_top_k * 2)
820
+ else:
821
+ adaptive_top_k = base_top_k
822
+
823
+ # Use ToT evaluation to decide
824
+ tot_result = evaluate_children_with_tot(state, skeleton, top_k_override=adaptive_top_k)
825
+
826
+ # Check for dead end signal from ToT
827
+ if tot_result.get("is_dead_end", False):
828
+ if state.get("backtrack_stack"):
829
+ return "backtrack"
830
+ return "done"
831
+
832
+ # Check if ToT says this is a leaf (should expand)
833
+ if tot_result.get("is_leaf", False):
834
+ return "expand"
835
+
836
+ # If ToT selected nodes, traverse
837
+ if tot_result.get("selected_nodes"):
838
+ return "traverse"
839
+
840
+ # Fallback: expand if at deep level (H3+)
841
+ if node.level >= 3:
842
+ return "expand"
843
+
844
+ # Otherwise traverse to children
845
+ return "traverse"
846
+
847
+
848
+ def expand_current_node(
849
+ state: AgentState,
850
+ skeleton: dict[str, SkeletonNode],
851
+ kv_store: KVStore,
852
+ variable_store: VariableStore,
853
+ ) -> AgentState:
854
+ """
855
+ EXPAND: Fetch full content and store as variable.
856
+ """
857
+ new_state = cast(AgentState, dict(state))
858
+
859
+ node_id = new_state["current_node_id"]
860
+ if node_id is None:
861
+ return new_state
862
+
863
+ # Check if already visited to prevent infinite loops
864
+ if node_id in new_state.get("visited_nodes", []):
865
+ add_trace_entry(
866
+ new_state,
867
+ "navigation",
868
+ f"Skipping already-visited node {node_id}",
869
+ {"node_id": node_id},
870
+ )
871
+ new_state["current_node_id"] = None
872
+ new_state["iteration"] = new_state["iteration"] + 1
873
+ return new_state
874
+
875
+ node = skeleton.get(node_id)
876
+
877
+ if node is None:
878
+ new_state["current_node_id"] = None
879
+ new_state["iteration"] = new_state["iteration"] + 1
880
+ return new_state
881
+
882
+ # Fetch full content
883
+ content = kv_store.get(node_id)
884
+ if content is None:
885
+ add_trace_entry(new_state, "variable_stitching", f"No content for {node_id}")
886
+ new_state["current_node_id"] = None
887
+ new_state["iteration"] = new_state["iteration"] + 1
888
+ return new_state
889
+
890
+ # Generate and store as variable
891
+ pointer = generate_pointer_name(node.header)
892
+ variable_store.assign(pointer, content, node_id)
893
+
894
+ new_state["variables"].append(pointer)
895
+ new_state["context"] += f"\nFound: {pointer} (from {node.header})"
896
+
897
+ # Mark node as visited and clear current node
898
+ if node_id not in new_state["visited_nodes"]:
899
+ new_state["visited_nodes"].append(node_id)
900
+ new_state["current_node_id"] = None # Process queue next
901
+
902
+ add_trace_entry(
903
+ new_state,
904
+ "variable_stitching",
905
+ f"Stored {pointer}",
906
+ {"node_id": node_id, "header": node.header, "chars": len(content)},
907
+ )
908
+
909
+ new_state["iteration"] = new_state["iteration"] + 1
910
+ return new_state
911
+
912
+
913
+ def traverse_to_children(
914
+ state: AgentState,
915
+ skeleton: dict[str, SkeletonNode],
916
+ semantic_searcher: SemanticSearcher | None = None,
917
+ ) -> AgentState:
918
+ """
919
+ TRAVERSE: Navigate to child nodes using Tree of Thoughts reasoning.
920
+
921
+ Primary Strategy: Tree of Thoughts (ToT) with adaptive exploration.
922
+ - Uses LLM reasoning to evaluate child nodes based on summaries
923
+ - Dynamically increases exploration when insufficient variables found
924
+ - Preserves document structure and context (per research paper Section 7.2)
925
+
926
+ Note: the semantic_searcher argument is an optional shortcut for finding entry points (Section 9.1); the current ToT traversal does not consult it.
927
+ """
928
+ new_state = cast(AgentState, dict(state))
929
+
930
+ node_id = new_state["current_node_id"]
931
+ if node_id is None:
932
+ return new_state
933
+
934
+ node = skeleton.get(node_id)
935
+
936
+ if node is None or not node.child_ids:
937
+ return new_state
938
+
939
+ question = new_state["question"]
940
+ base_top_k = new_state.get("top_k", 3)
941
+
942
+ # Adaptive exploration: increase top_k if we haven't found much information yet
943
+ # This addresses the original issue: "agent should be able to explore all nodes if it needs to"
944
+ variables_found = len(new_state.get("variables", []))
945
+ iteration = new_state.get("iteration", 0)
946
+
947
+ # Dynamic top_k based on progress
948
+ if iteration > 5 and variables_found == 0:
949
+ # Not finding anything - expand search significantly
950
+ adaptive_top_k = min(len(node.child_ids), base_top_k * 3)
951
+ add_trace_entry(
952
+ new_state,
953
+ "navigation",
954
+ f"Expanding search: {variables_found} variables after {iteration} iterations",
955
+ {"base_top_k": base_top_k, "adaptive_top_k": adaptive_top_k},
956
+ )
957
+ elif iteration > 3 and variables_found < 2:
958
+ # Finding very little - expand moderately
959
+ adaptive_top_k = min(len(node.child_ids), base_top_k * 2)
960
+ else:
961
+ # Normal exploration
962
+ adaptive_top_k = base_top_k
963
+
964
+ # Use Tree of Thoughts as primary navigation method (per research paper)
965
+ tot_result = evaluate_children_with_tot(new_state, skeleton, top_k_override=adaptive_top_k)
966
+
967
+ # Check for dead end
968
+ if tot_result.get("is_dead_end", False):
969
+ add_trace_entry(
970
+ new_state,
971
+ "navigation",
972
+ f"Dead end detected at {node_id}: {tot_result.get('backtrack_reason', 'Unknown')}",
973
+ {"node_id": node_id, "backtrack_reason": tot_result.get("backtrack_reason")},
974
+ )
975
+ # Mark as dead end and don't change current node
976
+ if node_id not in new_state["dead_ends"]:
977
+ new_state["dead_ends"].append(node_id)
978
+ new_state["iteration"] = new_state["iteration"] + 1
979
+ return new_state
980
+
981
+ selected_nodes = tot_result.get("selected_nodes", [])
982
+ evaluations = tot_result.get("evaluations", [])
983
+ new_state["scored_candidates"] = evaluations
984
+
985
+ # Queue selected children for visitation
986
+ if selected_nodes:
987
+ new_state["nodes_to_visit"].extend(selected_nodes)
988
+ # Mark parent as visited since we've traversed its children
989
+ if node_id not in new_state["visited_nodes"]:
990
+ new_state["visited_nodes"].append(node_id)
+ # Push the parent so backtrack_to_parent can return here from a dead-end child
+ new_state["backtrack_stack"].append(node_id)
991
+ # Unset current node to force the next navigate step to pull from the queue
992
+ new_state["current_node_id"] = None
993
+ add_trace_entry(
994
+ new_state,
995
+ "navigation",
996
+ f"Queued {len(selected_nodes)} children from {node_id}",
997
+ {"nodes": selected_nodes, "parent": node_id},
998
+ )
999
+ else:
1000
+ # No candidates available - this should trigger backtracking
1001
+ add_trace_entry(
1002
+ new_state,
1003
+ "navigation",
1004
+ f"No unvisited candidates at {node_id}",
1005
+ {"node_id": node_id},
1006
+ )
1007
+
1008
+ new_state["iteration"] = new_state["iteration"] + 1
1009
+ return new_state
1010
+
1011
+
1012
+ def synthesize_answer(
1013
+ state: AgentState,
1014
+ variable_store: VariableStore,
1015
+ ) -> AgentState:
1016
+ """
1017
+ Synthesize final answer from stored variables using an LLM.
1018
+
1019
+ Uses the configured LLM to generate a concise answer
1020
+ from the resolved variable content.
1021
+ """
1022
+ new_state = cast(AgentState, dict(state))
1023
+
1024
+ # Resolve all variables
1025
+ pointers = new_state["variables"]
1026
+
1027
+ if not pointers:
1028
+ new_state["answer"] = "No relevant content found."
1029
+ new_state["confidence"] = 0.0
1030
+ else:
1031
+ # Collect all content
1032
+ contents = []
1033
+ for pointer in pointers:
1034
+ content = variable_store.resolve(pointer)
1035
+ if content:
1036
+ contents.append(content)
1037
+
1038
+ context_text = "\n\n---\n\n".join(contents)
1039
+ question = new_state["question"]
1040
+
1041
+ # Use LLM to synthesize answer
1042
+ try:
1043
+ from rnsr.llm import get_llm
1044
+
1045
+ llm = get_llm()
1046
+
1047
+ metadata = new_state.get("metadata", {})
1048
+ options = metadata.get("options")
1049
+
1050
+ if options and isinstance(options, list) and len(options) > 0:
1051
+ # QuALITY multiple-choice question
1052
+ # Ensure each option is a string and format with letters
1053
+ options_text = "\n".join([f"{chr(65+i)}. {str(opt)}" for i, opt in enumerate(options)])
1054
+ prompt = f"""Based on the provided context, answer this multiple-choice question.
1055
+
1056
+ Question: {question}
1057
+
1058
+ Options:
1059
+ {options_text}
1060
+
1061
+ Context:
1062
+ {context_text}
1063
+
1064
+ Instructions:
1065
+ 1. Read the context carefully
1066
+ 2. Determine which option is best supported by the evidence in the context
1067
+ 3. Respond with ONLY the letter and full text of the correct option
1068
+ 4. Format: "X. [complete option text]" where X is A, B, C, or D
1069
+
1070
+ Your answer:"""
1071
+ else:
1072
+ # Standard open-ended question
1073
+ prompt = f"""Based on the following context, answer the question concisely.
1074
+
1075
+ IMPORTANT INSTRUCTIONS:
1076
+ - If the context contains information that DIRECTLY answers the question, provide that answer
1077
+ - If the context contains information that allows you to INFER the answer, provide your best inference
1078
+ - Use evidence from the context to support your answer
1079
+ - Only say "Cannot determine from available context" if the context is completely unrelated or missing critical information
1080
+ - It's better to provide a reasonable answer based on available evidence than to say "cannot determine"
1081
+
1082
+ Question: {question}
1083
+
1084
+ Context:
1085
+ {context_text}
1086
+
1087
+ Answer (be concise, direct, and confident):"""
1088
+
1089
+ response = llm.complete(prompt)
1090
+ answer = str(response).strip()
1091
+
1092
+ # Normalize multiple-choice answers
1093
+ if options:
1094
+ answer_lower = answer.lower().strip()
1095
+ matched = False
1096
+
1097
+ for i, opt in enumerate(options):
1098
+ letter = chr(65 + i) # A, B, C, D
1099
+ opt_lower = str(opt).lower()
1100
+
1101
+ # Match patterns: "A.", "A)", "(A)", "A. answer text"
1102
+ if (answer_lower.startswith(f"{letter.lower()}.") or
1103
+ answer_lower.startswith(f"{letter.lower()})") or
1104
+ answer_lower.startswith(f"({letter.lower()})")):
1105
+ answer = opt
1106
+ matched = True
1107
+ break
1108
+
1109
+ # Exact match (case insensitive)
1110
+ if answer_lower == opt_lower:
1111
+ answer = opt
1112
+ matched = True
1113
+ break
1114
+
1115
+ # Check if option text is contained in answer
1116
+ if opt_lower in answer_lower or answer_lower in opt_lower:
1117
+ # Use similarity to pick best match if multiple partial matches
1118
+ answer = opt
1119
+ matched = True
1120
+ break
1121
+
1122
+ new_state["confidence"] = 0.8 if matched else 0.1
1123
+ else:
1124
+ new_state["confidence"] = min(1.0, len(pointers) * 0.3)
1125
+
1126
+ new_state["answer"] = answer
1127
+
1128
+ except Exception as e:
1129
+ logger.warning("llm_synthesis_failed", error=str(e))
1130
+ # Fallback to context concatenation
1131
+ new_state["answer"] = f"Context found:\n{context_text}"
1132
+ new_state["confidence"] = min(1.0, len(pointers) * 0.2)
1133
+
1134
+ add_trace_entry(
1135
+ new_state,
1136
+ "synthesis",
1137
+ "Generated answer from variables",
1138
+ {"variables_used": pointers, "confidence": new_state["confidence"]},
1139
+ )
1140
+
1141
+ return new_state
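To make the multiple-choice normalization above concrete (options hypothetical):

# options = ["Revenue grew", "Net income fell", "Margins were flat", "No change"]
# "B. Net income fell" -> prefix "b." matches  -> answer = "Net income fell", confidence 0.8
# "(b)"                -> prefix "(b)" matches -> answer = "Net income fell", confidence 0.8
# "unrelated reply"    -> no option aligns     -> raw answer kept, confidence 0.1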
1142
+
1143
+
1144
+ # =============================================================================
1145
+ # RLM Recursive Execution Functions (Section 2.2)
1146
+ # =============================================================================
1147
+
1148
+
1149
+ def execute_sub_task_with_llm(
1150
+ sub_task: str,
1151
+ context: str,
1152
+ llm_fn: Any = None,
1153
+ ) -> str:
1154
+ """
1155
+ Execute a single sub-task using an LLM.
1156
+
1157
+ Implements the recursive LLM pattern from Section 2.2:
1158
+ "For each contract, call the LLM API (invoking itself) with a
1159
+ specific sub-prompt: 'Extract liability clause from this text.'"
1160
+
1161
+ Args:
1162
+ sub_task: The sub-task description/prompt.
1163
+ context: The context to process.
1164
+ llm_fn: Optional LLM function. If None, uses default.
1165
+
1166
+ Returns:
1167
+ LLM response as string.
1168
+ """
1169
+ if llm_fn is None:
1170
+ try:
1171
+ from rnsr.llm import get_llm
1172
+ llm = get_llm()
1173
+ llm_fn = lambda p: str(llm.complete(p))
1174
+ except Exception as e:
1175
+ logger.warning("llm_not_available", error=str(e))
1176
+ return f"[Error: LLM not available - {str(e)}]"
1177
+
1178
+ prompt = f"""{sub_task}
1179
+
1180
+ Context:
1181
+ {context}
1182
+
1183
+ Response:"""
1184
+
1185
+ try:
1186
+ return llm_fn(prompt)
1187
+ except Exception as e:
1188
+ logger.error("sub_task_execution_failed", error=str(e))
1189
+ return f"[Error: {str(e)}]"
1190
+
1191
+
1192
+ def batch_execute_sub_tasks(
1193
+ sub_tasks: list[str],
1194
+ contexts: list[str],
1195
+ batch_size: int = 5,
1196
+ max_parallel: int = 4,
1197
+ ) -> list[str]:
1198
+ """
1199
+ Execute multiple sub-tasks in parallel batches.
1200
+
1201
+ Implements Section 2.3 "Optimization via Batching":
1202
+ "Instead of making 1,000 individual calls to summarize 1,000 paragraphs,
1203
+ the RLM writes code to group paragraphs into chunks of 5 and processes
1204
+ them in parallel threads."
1205
+
1206
+ Args:
1207
+ sub_tasks: List of sub-task prompts.
1208
+ contexts: List of contexts for each sub-task.
1209
+ batch_size: Items per batch (default 5). Currently advisory: tasks are submitted individually, with parallelism bounded by max_parallel.
1210
+ max_parallel: Max parallel threads (default 4).
1211
+
1212
+ Returns:
1213
+ List of results for each sub-task.
1214
+ """
1215
+ from concurrent.futures import ThreadPoolExecutor, as_completed
1216
+
1217
+ if len(sub_tasks) != len(contexts):
1218
+ raise ValueError("sub_tasks and contexts must have same length")
1219
+
1220
+ if not sub_tasks:
1221
+ return []
1222
+
1223
+ logger.info(
1224
+ "batch_execution_start",
1225
+ num_tasks=len(sub_tasks),
1226
+ batch_size=batch_size,
1227
+ )
1228
+
1229
+ # Get LLM function
1230
+ try:
1231
+ from rnsr.llm import get_llm
1232
+ llm = get_llm()
1233
+ llm_fn = lambda p: str(llm.complete(p))
1234
+ except Exception as e:
1235
+ logger.error("batch_llm_failed", error=str(e))
1236
+ return [f"[Error: {str(e)}]"] * len(sub_tasks)
1237
+
1238
+ results: list[str] = [""] * len(sub_tasks)
1239
+
1240
+ with ThreadPoolExecutor(max_workers=max_parallel) as executor:
1241
+ future_to_idx: dict[Any, int] = {}
1242
+
1243
+ for idx, (task, ctx) in enumerate(zip(sub_tasks, contexts)):
1244
+ future = executor.submit(execute_sub_task_with_llm, task, ctx, llm_fn)
1245
+ future_to_idx[future] = idx
1246
+
1247
+ for future in as_completed(future_to_idx):
1248
+ idx = future_to_idx[future]
1249
+ try:
1250
+ results[idx] = future.result(timeout=120)
1251
+ except Exception as e:
1252
+ results[idx] = f"[Error: {str(e)}]"
1253
+
1254
+ logger.info("batch_execution_complete", num_results=len(results))
1255
+ return results
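A usage sketch, assuming the context strings already exist (names hypothetical):

summaries = batch_execute_sub_tasks(
    sub_tasks=["Summarize this section."] * 3,
    contexts=[intro_text, terms_text, annex_text],
    max_parallel=4,
)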
1256
+
1257
+
1258
+ def process_pending_questions(
1259
+ state: AgentState,
1260
+ skeleton: dict[str, SkeletonNode],
1261
+ kv_store: KVStore,
1262
+ variable_store: VariableStore,
1263
+ ) -> AgentState:
1264
+ """
1265
+ Process all pending sub-questions using recursive LLM calls.
1266
+
1267
+ This implements the full RLM recursive loop:
1268
+ 1. For each pending question, find relevant context
1269
+ 2. Invoke sub-LLM to extract answer
1270
+ 3. Store result as variable
1271
+ 4. Move to next question
1272
+ """
1273
+ new_state = cast(AgentState, dict(state))
1274
+ pending = new_state.get("pending_questions", [])
1275
+
1276
+ if not pending:
1277
+ return new_state
1278
+
1279
+ current_question = pending[0]
1280
+
1281
+ # Find relevant content for this question
1282
+ # Use current node and its children
1283
+ node_id = new_state.get("current_node_id") or "root"
1284
+ node = skeleton.get(node_id)
1285
+
1286
+ if node is None:
1287
+ # Pop and continue
1288
+ new_state["pending_questions"] = pending[1:]
1289
+ return new_state
1290
+
1291
+ # Get context from current node and children
1292
+ context_parts = []
1293
+
1294
+ # Add current node content
1295
+ current_content = kv_store.get(node_id) if node_id else None
1296
+ if current_content:
1297
+ context_parts.append(f"[{node.header}]\n{current_content}")
1298
+
1299
+ # Add children summaries
1300
+ for child_id in node.child_ids[:5]: # Limit to 5 children
1301
+ child = skeleton.get(child_id)
1302
+ if child:
1303
+ child_content = kv_store.get(child_id)
1304
+ if child_content:
1305
+ context_parts.append(f"[{child.header}]\n{child_content[:2000]}")
1306
+
1307
+ context = "\n\n---\n\n".join(context_parts)
1308
+
1309
+ # Execute sub-task with LLM
1310
+ result = execute_sub_task_with_llm(
1311
+ sub_task=f"Answer this question: {current_question}",
1312
+ context=context,
1313
+ )
1314
+
1315
+ # Store as variable
1316
+ pointer = generate_pointer_name(current_question[:30])
1317
+ variable_store.assign(pointer, result, node_id or "root")
1318
+ new_state["variables"].append(pointer)
1319
+
1320
+ add_trace_entry(
1321
+ new_state,
1322
+ "decomposition",
1323
+ f"Processed sub-question via recursive LLM",
1324
+ {
1325
+ "question": current_question,
1326
+ "pointer": pointer,
1327
+ "context_length": len(context),
1328
+ },
1329
+ )
1330
+
1331
+ # Pop processed question
1332
+ new_state["pending_questions"] = pending[1:]
1333
+ if pending[1:]:
1334
+ new_state["current_sub_question"] = pending[1]
1335
+
1336
+ new_state["iteration"] = new_state["iteration"] + 1
1337
+ return new_state
1338
+
1339
+
1340
+ # =============================================================================
1341
+ # Graph Builder (LangGraph)
1342
+ # =============================================================================
1343
+
1344
+
1345
+ def build_navigator_graph(
1346
+ skeleton: dict[str, SkeletonNode],
1347
+ kv_store: KVStore,
1348
+ semantic_searcher: SemanticSearcher | None = None,
1349
+ ) -> Any:
1350
+ """
1351
+ Build the LangGraph state machine for document navigation.
1352
+
1353
+ Implements Tree of Thoughts (ToT) navigation with backtracking support.
1354
+
1355
+ Returns a compiled graph that can be invoked with a question.
1356
+
1357
+ Usage:
1358
+ graph = build_navigator_graph(skeleton, kv_store)
1359
+ result = graph.invoke({"question": "What are the payment terms?"})
1360
+ """
1361
+ try:
1362
+ from langgraph.graph import END, StateGraph
1363
+ except ImportError:
1364
+ raise ImportError(
1365
+ "LangGraph not installed. Install with: pip install langgraph"
1366
+ )
1367
+
1368
+ variable_store = VariableStore()
1369
+
1370
+ # Create graph
1371
+ graph = StateGraph(AgentState)
1372
+
1373
+ # Add nodes
1374
+ graph.add_node("decompose", decompose_query)
1375
+
1376
+ graph.add_node(
1377
+ "navigate",
1378
+ lambda state: navigate_tree(cast(AgentState, state), skeleton),
1379
+ )
1380
+
1381
+ graph.add_node(
1382
+ "expand",
1383
+ lambda state: expand_current_node(cast(AgentState, state), skeleton, kv_store, variable_store),
1384
+ )
1385
+
1386
+ graph.add_node(
1387
+ "traverse",
1388
+ lambda state: traverse_to_children(
1389
+ cast(AgentState, state),
1390
+ skeleton,
1391
+ semantic_searcher,
1392
+ ),
1393
+ )
1394
+
1395
+ # Add backtrack node for ToT dead-end handling
1396
+ graph.add_node(
1397
+ "backtrack",
1398
+ lambda state: backtrack_to_parent(cast(AgentState, state), skeleton),
1399
+ )
1400
+
1401
+ graph.add_node(
1402
+ "synthesize",
1403
+ lambda state: synthesize_answer(cast(AgentState, state), variable_store),
1404
+ )
1405
+
1406
+ # Add edges
1407
+ graph.add_edge("decompose", "navigate")
1408
+
1409
+ # Conditional edge based on expand/traverse/backtrack decision
1410
+ graph.add_conditional_edges(
1411
+ "navigate",
1412
+ lambda s: should_expand(s, skeleton),
1413
+ {
1414
+ "expand": "expand",
1415
+ "traverse": "traverse",
1416
+ "backtrack": "backtrack",
1417
+ "done": "synthesize",
1418
+ },
1419
+ )
1420
+
1421
+ # After expand, traverse, or backtrack, always return to the main navigation
1422
+ # handler, which will decide what to do next (e.g., pull from queue)
1423
+ graph.add_edge("expand", "navigate")
1424
+ graph.add_edge("traverse", "navigate")
1425
+ graph.add_edge("backtrack", "navigate")
1426
+
1427
+ graph.add_edge("synthesize", END)
1428
+
1429
+ # Set entry point
1430
+ graph.set_entry_point("decompose")
1431
+
1432
+ logger.info("navigator_graph_built", features=["ToT", "backtracking"])
1433
+
1434
+ return graph.compile()
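The compiled machine loops through navigate until the conditional edge routes to synthesis:

# decompose -> navigate -> {expand | traverse | backtrack} -> navigate -> ... -> synthesize -> END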
1435
+
1436
+
1437
+ # =============================================================================
1438
+ # High-Level API
1439
+ # =============================================================================
1440
+
1441
+
1442
+ def run_navigator(
1443
+ question: str,
1444
+ skeleton: dict[str, SkeletonNode],
1445
+ kv_store: KVStore,
1446
+ max_iterations: int = 20,
1447
+ top_k: int | None = None,
1448
+ use_semantic_search: bool = True,
1449
+ semantic_searcher: SemanticSearcher | None = None,
1450
+ metadata: dict[str, Any] | None = None,
1451
+ tot_selection_threshold: float = 0.4,
1452
+ tot_dead_end_threshold: float = 0.1,
1453
+ ) -> dict[str, Any]:
1454
+ """
1455
+ Run the navigator agent on a question.
1456
+
1457
+ Args:
1458
+ question: User's question.
1459
+ skeleton: Skeleton index.
1460
+ kv_store: KV store with full content.
1461
+ max_iterations: Maximum navigation iterations.
1462
+ top_k: Number of top children to explore (default: auto-detect based on tree depth).
1463
+ use_semantic_search: Build a semantic searcher as an optional shortcut tool for finding
1464
+ entry points; Tree of Thoughts traversal remains the primary navigation method.
1465
+ semantic_searcher: Optional pre-built semantic searcher. If None and use_semantic_search=True, creates one.
1466
+ tot_selection_threshold: Minimum probability for ToT node selection (0.0-1.0).
1467
+ tot_dead_end_threshold: Probability threshold for declaring a dead end (0.0-1.0).
1468
+
1469
+ Returns:
1470
+ Dictionary with answer, confidence, trace.
1471
+
1472
+ Example:
1473
+ result = run_navigator(
1474
+ "What are the liability terms?",
1475
+ skeleton,
1476
+ kv_store,
1477
+ use_semantic_search=True, # Enable semantic search
1478
+ )
1479
+ print(result["answer"])
1480
+ """
1481
+ # Get root node
1482
+ root_id = None
1483
+ root_node = None
1484
+ for node in skeleton.values():
1485
+ if node.level == 0:
1486
+ root_id = node.node_id
1487
+ root_node = node
1488
+ break
1489
+
1490
+ if root_id is None:
1491
+ return {
1492
+ "answer": "Error: No root node found in skeleton index.",
1493
+ "confidence": 0.0,
1494
+ "trace": [],
1495
+ }
1496
+
1497
+ # Auto-detect top_k based on tree structure
1498
+ if top_k is None:
1499
+ num_root_children = len(root_node.child_ids) if root_node else 0
1500
+ if num_root_children > 10:
1501
+ # Flat hierarchy (e.g., QuALITY): explore more children
1502
+ top_k = min(10, num_root_children)
1503
+ else:
1504
+ # Deep hierarchy (e.g., PDFs): explore fewer
1505
+ top_k = 3
1506
+
1507
+ # Semantic search is kept as a shortcut only, per research paper Section 9.1:
1508
+ # "Hybridize Search: Give the agent a vector_search tool as a SHORTCUT.
1509
+ # The agent can use vector search to find a starting node and then
1510
+ # SWITCH TO TREE TRAVERSAL for local exploration."
1511
+ #
1512
+ # Primary navigation uses ToT reasoning-based retrieval.
1513
+ if use_semantic_search and semantic_searcher is None:
1514
+ try:
1515
+ from rnsr.indexing.semantic_search import create_semantic_searcher
1516
+ semantic_searcher = create_semantic_searcher(skeleton, kv_store)
1517
+ logger.info(
1518
+ "semantic_search_optional_tool",
1519
+ nodes=len(skeleton),
1520
+ note="Available as shortcut for entry points only",
1521
+ )
1522
+ except Exception as e:
1523
+ logger.warning(
1524
+ "semantic_search_unavailable",
1525
+ error=str(e),
1526
+ )
1527
+ semantic_searcher = None
1528
+
1529
+ logger.info(
1530
+ "using_tot_reasoning_navigation",
1531
+ method="Tree of Thoughts",
1532
+ adaptive_exploration=True,
1533
+ note="LLM reasons about document structure to navigate",
1534
+ )
1535
+
1536
+ # Build and run graph
1537
+ graph = build_navigator_graph(skeleton, kv_store, semantic_searcher)
1538
+
1539
+ initial_state = create_initial_state(
1540
+ question=question,
1541
+ root_node_id=root_id,
1542
+ max_iterations=max_iterations,
1543
+ top_k=top_k,
1544
+ metadata=metadata,
1545
+ tot_selection_threshold=tot_selection_threshold,
1546
+ tot_dead_end_threshold=tot_dead_end_threshold,
1547
+ )
1548
+
1549
+ final_state = graph.invoke(initial_state)
1550
+
1551
+ return {
1552
+ "answer": final_state.get("answer", ""),
1553
+ "confidence": final_state.get("confidence", 0.0),
1554
+ "trace": final_state.get("trace", []),
1555
+ "variables_used": final_state.get("variables", []),
1556
+ "nodes_visited": final_state.get("visited_nodes", []),
1557
+ }