rnsr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rnsr/__init__.py +118 -0
  2. rnsr/__main__.py +242 -0
  3. rnsr/agent/__init__.py +218 -0
  4. rnsr/agent/cross_doc_navigator.py +767 -0
  5. rnsr/agent/graph.py +1557 -0
  6. rnsr/agent/llm_cache.py +575 -0
  7. rnsr/agent/navigator_api.py +497 -0
  8. rnsr/agent/provenance.py +772 -0
  9. rnsr/agent/query_clarifier.py +617 -0
  10. rnsr/agent/reasoning_memory.py +736 -0
  11. rnsr/agent/repl_env.py +709 -0
  12. rnsr/agent/rlm_navigator.py +2108 -0
  13. rnsr/agent/self_reflection.py +602 -0
  14. rnsr/agent/variable_store.py +308 -0
  15. rnsr/benchmarks/__init__.py +118 -0
  16. rnsr/benchmarks/comprehensive_benchmark.py +733 -0
  17. rnsr/benchmarks/evaluation_suite.py +1210 -0
  18. rnsr/benchmarks/finance_bench.py +147 -0
  19. rnsr/benchmarks/pdf_merger.py +178 -0
  20. rnsr/benchmarks/performance.py +321 -0
  21. rnsr/benchmarks/quality.py +321 -0
  22. rnsr/benchmarks/runner.py +298 -0
  23. rnsr/benchmarks/standard_benchmarks.py +995 -0
  24. rnsr/client.py +560 -0
  25. rnsr/document_store.py +394 -0
  26. rnsr/exceptions.py +74 -0
  27. rnsr/extraction/__init__.py +172 -0
  28. rnsr/extraction/candidate_extractor.py +357 -0
  29. rnsr/extraction/entity_extractor.py +581 -0
  30. rnsr/extraction/entity_linker.py +825 -0
  31. rnsr/extraction/grounded_extractor.py +722 -0
  32. rnsr/extraction/learned_types.py +599 -0
  33. rnsr/extraction/models.py +232 -0
  34. rnsr/extraction/relationship_extractor.py +600 -0
  35. rnsr/extraction/relationship_patterns.py +511 -0
  36. rnsr/extraction/relationship_validator.py +392 -0
  37. rnsr/extraction/rlm_extractor.py +589 -0
  38. rnsr/extraction/rlm_unified_extractor.py +990 -0
  39. rnsr/extraction/tot_validator.py +610 -0
  40. rnsr/extraction/unified_extractor.py +342 -0
  41. rnsr/indexing/__init__.py +60 -0
  42. rnsr/indexing/knowledge_graph.py +1128 -0
  43. rnsr/indexing/kv_store.py +313 -0
  44. rnsr/indexing/persistence.py +323 -0
  45. rnsr/indexing/semantic_retriever.py +237 -0
  46. rnsr/indexing/semantic_search.py +320 -0
  47. rnsr/indexing/skeleton_index.py +395 -0
  48. rnsr/ingestion/__init__.py +161 -0
  49. rnsr/ingestion/chart_parser.py +569 -0
  50. rnsr/ingestion/document_boundary.py +662 -0
  51. rnsr/ingestion/font_histogram.py +334 -0
  52. rnsr/ingestion/header_classifier.py +595 -0
  53. rnsr/ingestion/hierarchical_cluster.py +515 -0
  54. rnsr/ingestion/layout_detector.py +356 -0
  55. rnsr/ingestion/layout_model.py +379 -0
  56. rnsr/ingestion/ocr_fallback.py +177 -0
  57. rnsr/ingestion/pipeline.py +936 -0
  58. rnsr/ingestion/semantic_fallback.py +417 -0
  59. rnsr/ingestion/table_parser.py +799 -0
  60. rnsr/ingestion/text_builder.py +460 -0
  61. rnsr/ingestion/tree_builder.py +402 -0
  62. rnsr/ingestion/vision_retrieval.py +965 -0
  63. rnsr/ingestion/xy_cut.py +555 -0
  64. rnsr/llm.py +733 -0
  65. rnsr/models.py +167 -0
  66. rnsr/py.typed +2 -0
  67. rnsr-0.1.0.dist-info/METADATA +592 -0
  68. rnsr-0.1.0.dist-info/RECORD +72 -0
  69. rnsr-0.1.0.dist-info/WHEEL +5 -0
  70. rnsr-0.1.0.dist-info/entry_points.txt +2 -0
  71. rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
  72. rnsr-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2108 @@
+"""
+RLM Navigator - Recursive Language Model Navigator with Full REPL Integration
+
+This module implements the full RLM (Recursive Language Model) pattern from the
+arXiv paper "Recursive Language Models" combined with RNSR's tree-based retrieval.
+
+Key Features:
+1. Full REPL environment with code execution for document filtering
+2. Pre-LLM filtering using regex/keyword search before ToT evaluation
+3. Deep recursive sub-LLM calls (configurable depth)
+4. Answer verification loops
+5. Async parallel sub-LLM processing
+6. Adaptive learning for stop words and query patterns
+
+This is the state-of-the-art combination of:
+- PageIndex: Vectorless, reasoning-based tree search
+- RLMs: REPL environment with recursive sub-LLM calls
+- RNSR: Latent hierarchy reconstruction + variable stitching
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import os
+import re
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from threading import Lock
+from typing import Any, Callable, Literal
+
+import structlog
+
+from rnsr.agent.variable_store import VariableStore, generate_pointer_name
+from rnsr.indexing.kv_store import KVStore
+from rnsr.models import SkeletonNode, TraceEntry
+
+logger = structlog.get_logger(__name__)
+
+
+# =============================================================================
+# Learned Stop Words Registry
+# =============================================================================
+
+DEFAULT_STOP_WORDS_PATH = Path.home() / ".rnsr" / "learned_stop_words.json"
+
+
+class LearnedStopWords:
+    """
+    Registry for learning domain-specific stop words.
+
+    Learns:
+    - Words that are generic in your domain (should be filtered)
+    - Words that seem generic but are important in your domain (should be kept)
+
+    Examples:
+    - Legal: "hereby", "whereas" are filler (add to stop)
+    - Legal: "party" is important (remove from stop)
+    """
+
+    # Base stop words (always included unless explicitly removed)
+    BASE_STOP_WORDS = {
+        "what", "is", "the", "a", "an", "are", "was", "were", "be", "been",
+        "being", "have", "has", "had", "do", "does", "did", "will", "would",
+        "could", "should", "may", "might", "must", "shall", "can", "need",
+        "dare", "ought", "used", "to", "of", "in", "for", "on", "with", "at",
+        "by", "from", "about", "into", "through", "during", "before", "after",
+        "above", "below", "between", "under", "again", "further", "then",
+        "once", "here", "there", "when", "where", "why", "how", "all", "each",
+        "few", "more", "most", "other", "some", "such", "no", "nor", "not",
+        "only", "own", "same", "so", "than", "too", "very", "just", "and",
+        "but", "if", "or", "because", "as", "until", "while", "this", "that",
+        "these", "those", "find", "show", "list", "describe", "explain", "tell",
+    }
+
+    def __init__(
+        self,
+        storage_path: Path | str | None = None,
+        auto_save: bool = True,
+    ):
+        """
+        Initialize the learned stop words registry.
+
+        Args:
+            storage_path: Path to JSON file for persistence.
+            auto_save: Whether to save after changes.
+        """
+        self.storage_path = Path(storage_path) if storage_path else DEFAULT_STOP_WORDS_PATH
+        self.auto_save = auto_save
+
+        self._lock = Lock()
+        self._added_stop_words: dict[str, dict[str, Any]] = {}  # Domain-specific additions
+        self._removed_stop_words: dict[str, dict[str, Any]] = {}  # Words to keep despite being in base
+        self._dirty = False
+
+        self._load()
+
+    def _load(self) -> None:
+        """Load learned stop words from storage."""
+        if not self.storage_path.exists():
+            return
+
+        try:
+            with open(self.storage_path, "r") as f:
+                data = json.load(f)
+
+            self._added_stop_words = data.get("added", {})
+            self._removed_stop_words = data.get("removed", {})
+
+            logger.info(
+                "learned_stop_words_loaded",
+                added=len(self._added_stop_words),
+                removed=len(self._removed_stop_words),
+            )
+
+        except Exception as e:
+            logger.warning("failed_to_load_stop_words", error=str(e))
+
+    def _save(self) -> None:
+        """Save to storage."""
+        if not self._dirty:
+            return
+
+        try:
+            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+
+            data = {
+                "version": "1.0",
+                "updated_at": datetime.now(timezone.utc).isoformat(),
+                "added": self._added_stop_words,
+                "removed": self._removed_stop_words,
+            }
+
+            with open(self.storage_path, "w") as f:
+                json.dump(data, f, indent=2)
+
+            self._dirty = False
+
+        except Exception as e:
+            logger.warning("failed_to_save_stop_words", error=str(e))
+
+    def add_stop_word(
+        self,
+        word: str,
+        domain: str = "general",
+        reason: str = "",
+    ) -> None:
+        """
+        Add a word to the stop word list.
+
+        Args:
+            word: Word to add.
+            domain: Domain category.
+            reason: Why this should be a stop word.
+        """
+        word = word.lower().strip()
+
+        if not word or word in self.BASE_STOP_WORDS:
+            return
+
+        with self._lock:
+            now = datetime.now(timezone.utc).isoformat()
+
+            if word not in self._added_stop_words:
+                self._added_stop_words[word] = {
+                    "count": 0,
+                    "domain": domain,
+                    "reason": reason,
+                    "first_seen": now,
+                    "last_seen": now,
+                }
+                logger.info("stop_word_added", word=word)
+
+            self._added_stop_words[word]["count"] += 1
+            self._added_stop_words[word]["last_seen"] = now
+
+            self._dirty = True
+
+        if self.auto_save:
+            self._save()
+
+    def remove_stop_word(
+        self,
+        word: str,
+        domain: str = "general",
+        reason: str = "",
+    ) -> None:
+        """
+        Mark a base stop word as important (should not be filtered).
+
+        Args:
+            word: Word to keep.
+            domain: Domain where this is important.
+            reason: Why this should be kept.
+        """
+        word = word.lower().strip()
+
+        if not word or word not in self.BASE_STOP_WORDS:
+            return
+
+        with self._lock:
+            now = datetime.now(timezone.utc).isoformat()
+
+            if word not in self._removed_stop_words:
+                self._removed_stop_words[word] = {
+                    "count": 0,
+                    "domain": domain,
+                    "reason": reason,
+                    "first_seen": now,
+                    "last_seen": now,
+                }
+                logger.info("stop_word_marked_important", word=word)
+
+            self._removed_stop_words[word]["count"] += 1
+            self._removed_stop_words[word]["last_seen"] = now
+
+            self._dirty = True
+
+        if self.auto_save:
+            self._save()
+
+    def get_stop_words(self, min_count: int = 1) -> set[str]:
+        """
+        Get the effective stop word set.
+
+        Returns:
+            Set of words to filter (base + added - removed).
+        """
+        with self._lock:
+            # Start with base
+            result = set(self.BASE_STOP_WORDS)
+
+            # Add learned additions
+            for word, data in self._added_stop_words.items():
+                if data["count"] >= min_count:
+                    result.add(word)
+
+            # Remove marked-important words
+            for word, data in self._removed_stop_words.items():
+                if data["count"] >= min_count:
+                    result.discard(word)
+
+            return result
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get statistics about stop words."""
+        return {
+            "base_count": len(self.BASE_STOP_WORDS),
+            "added_count": len(self._added_stop_words),
+            "removed_count": len(self._removed_stop_words),
+            "effective_count": len(self.get_stop_words()),
+        }
+
+
+# Global stop words registry
+_global_stop_words: LearnedStopWords | None = None
+
+
+def get_learned_stop_words() -> LearnedStopWords:
+    """Get the global learned stop words registry."""
+    global _global_stop_words
+
+    if _global_stop_words is None:
+        custom_path = os.getenv("RNSR_STOP_WORDS_PATH")
+        _global_stop_words = LearnedStopWords(
+            storage_path=custom_path if custom_path else None
+        )
+
+    return _global_stop_words
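For orientation, a minimal usage sketch of the registry above. The legal-domain words are illustrative rather than shipped defaults, and with `auto_save` left on, changes persist to `~/.rnsr/learned_stop_words.json`:

    registry = get_learned_stop_words()

    # "hereby" is domain filler, so filter it; "will" looks generic but names a
    # document type in estate law, so keep it despite being a base stop word.
    registry.add_stop_word("hereby", domain="legal", reason="contract filler")
    registry.remove_stop_word("will", domain="legal", reason="testament document type")

    effective = registry.get_stop_words()
    assert "hereby" in effective
    assert "will" not in effective
    print(registry.get_stats())  # base/added/removed/effective counts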
+
+
+# =============================================================================
+# Learned Query Patterns Registry
+# =============================================================================
+
+DEFAULT_QUERY_PATTERNS_PATH = Path.home() / ".rnsr" / "learned_query_patterns.json"
+
+
+class LearnedQueryPatterns:
+    """
+    Registry for learning successful query patterns.
+
+    Tracks:
+    - Query patterns that lead to high-confidence answers
+    - Patterns that need decomposition vs. direct retrieval
+    - Entity-focused vs. section-focused queries
+
+    Used to:
+    - Inform decomposition strategy
+    - Adjust confidence thresholds
+    - Route to specialized handlers
+    """
+
+    def __init__(
+        self,
+        storage_path: Path | str | None = None,
+        auto_save: bool = True,
+    ):
+        """
+        Initialize the query patterns registry.
+
+        Args:
+            storage_path: Path to JSON file for persistence.
+            auto_save: Whether to save after changes.
+        """
+        self.storage_path = Path(storage_path) if storage_path else DEFAULT_QUERY_PATTERNS_PATH
+        self.auto_save = auto_save
+
+        self._lock = Lock()
+        self._patterns: dict[str, dict[str, Any]] = {}
+        self._dirty = False
+
+        self._load()
+
+    def _load(self) -> None:
+        """Load learned patterns from storage."""
+        if not self.storage_path.exists():
+            return
+
+        try:
+            with open(self.storage_path, "r") as f:
+                data = json.load(f)
+
+            self._patterns = data.get("patterns", {})
+
+            logger.info(
+                "query_patterns_loaded",
+                patterns=len(self._patterns),
+            )
+
+        except Exception as e:
+            logger.warning("failed_to_load_query_patterns", error=str(e))
+
+    def _save(self) -> None:
+        """Save to storage."""
+        if not self._dirty:
+            return
+
+        try:
+            self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+
+            data = {
+                "version": "1.0",
+                "updated_at": datetime.now(timezone.utc).isoformat(),
+                "patterns": self._patterns,
+            }
+
+            with open(self.storage_path, "w") as f:
+                json.dump(data, f, indent=2)
+
+            self._dirty = False
+
+        except Exception as e:
+            logger.warning("failed_to_save_query_patterns", error=str(e))
+
+    def record_query(
+        self,
+        query: str,
+        pattern_type: str,
+        success: bool,
+        confidence: float,
+        needed_decomposition: bool,
+        sub_questions_count: int = 0,
+        entities_involved: list[str] | None = None,
+    ) -> None:
+        """
+        Record a query and its outcome.
+
+        Args:
+            query: The original query.
+            pattern_type: Detected pattern type (entity_lookup, comparison, etc.)
+            success: Whether the query was answered successfully.
+            confidence: Answer confidence score.
+            needed_decomposition: Whether decomposition was required.
+            sub_questions_count: Number of sub-questions generated.
+            entities_involved: Entity types involved in the query.
+        """
+        pattern_type = pattern_type.lower().strip()
+
+        with self._lock:
+            now = datetime.now(timezone.utc).isoformat()
+
+            if pattern_type not in self._patterns:
+                self._patterns[pattern_type] = {
+                    "total_queries": 0,
+                    "successful_queries": 0,
+                    "total_confidence": 0.0,
+                    "decomposition_count": 0,
+                    "total_sub_questions": 0,
+                    "entity_types": {},
+                    "first_seen": now,
+                    "last_seen": now,
+                    "example_queries": [],
+                }
+                logger.info("new_query_pattern_discovered", pattern_type=pattern_type)
+
+            pt = self._patterns[pattern_type]
+            pt["total_queries"] += 1
+            pt["total_confidence"] += confidence
+            pt["last_seen"] = now
+
+            if success:
+                pt["successful_queries"] += 1
+
+            if needed_decomposition:
+                pt["decomposition_count"] += 1
+                pt["total_sub_questions"] += sub_questions_count
+
+            if entities_involved:
+                for entity_type in entities_involved:
+                    pt["entity_types"][entity_type] = pt["entity_types"].get(entity_type, 0) + 1
+
+            if len(pt["example_queries"]) < 5:
+                pt["example_queries"].append({
+                    "query": query[:200],
+                    "success": success,
+                    "confidence": confidence,
+                    "timestamp": now,
+                })
+
+            self._dirty = True
+
+        if self.auto_save:
+            self._save()
+
+    def detect_pattern_type(self, query: str) -> str:
+        """
+        Detect the pattern type of a query.
+
+        Args:
+            query: The query to analyze.
+
+        Returns:
+            Detected pattern type.
+        """
+        query_lower = query.lower()
+
+        # Pattern detection heuristics
+        if any(word in query_lower for word in ["compare", "difference", "versus", "vs"]):
+            return "comparison"
+
+        if any(word in query_lower for word in ["list", "all", "every", "enumerate"]):
+            return "enumeration"
+
+        if any(word in query_lower for word in ["when", "date", "time", "timeline"]):
+            return "temporal"
+
+        if any(word in query_lower for word in ["who", "person", "name"]):
+            return "entity_person"
+
+        if any(word in query_lower for word in ["company", "organization", "entity"]):
+            return "entity_organization"
+
+        if any(word in query_lower for word in ["how much", "amount", "price", "cost", "$"]):
+            return "monetary"
+
+        if any(word in query_lower for word in ["section", "clause", "paragraph", "article"]):
+            return "section_lookup"
+
+        if any(word in query_lower for word in ["what is", "define", "explain", "describe"]):
+            return "definition"
+
+        if any(word in query_lower for word in ["why", "reason", "cause"]):
+            return "causal"
+
+        return "general"
+
+    def get_pattern_stats(self, pattern_type: str) -> dict[str, Any] | None:
+        """
+        Get statistics for a pattern type.
+
+        Args:
+            pattern_type: The pattern type to look up.
+
+        Returns:
+            Pattern statistics or None if not found.
+        """
+        pattern_type = pattern_type.lower().strip()
+
+        with self._lock:
+            if pattern_type not in self._patterns:
+                return None
+
+            pt = self._patterns[pattern_type]
+            total = pt["total_queries"]
+
+            return {
+                "pattern_type": pattern_type,
+                "total_queries": total,
+                "success_rate": pt["successful_queries"] / total if total > 0 else 0,
+                "avg_confidence": pt["total_confidence"] / total if total > 0 else 0,
+                "decomposition_rate": pt["decomposition_count"] / total if total > 0 else 0,
+                "avg_sub_questions": (
+                    pt["total_sub_questions"] / pt["decomposition_count"]
+                    if pt["decomposition_count"] > 0 else 0
+                ),
+                "top_entity_types": sorted(
+                    pt["entity_types"].items(),
+                    key=lambda x: -x[1]
+                )[:5],
+            }
+
+    def should_decompose(self, pattern_type: str) -> bool:
+        """
+        Determine if a pattern type typically needs decomposition.
+
+        Args:
+            pattern_type: The pattern type.
+
+        Returns:
+            True if decomposition is recommended.
+        """
+        stats = self.get_pattern_stats(pattern_type)
+
+        if not stats:
+            # Default recommendations for unknown patterns
+            always_decompose = {"comparison", "enumeration", "temporal"}
+            return pattern_type.lower() in always_decompose
+
+        # Recommend decomposition if historically needed > 50% of the time
+        return stats["decomposition_rate"] > 0.5
+
+    def get_confidence_threshold(self, pattern_type: str) -> float:
+        """
+        Get recommended confidence threshold for a pattern type.
+
+        Args:
+            pattern_type: The pattern type.
+
+        Returns:
+            Recommended confidence threshold.
+        """
+        stats = self.get_pattern_stats(pattern_type)
+
+        if not stats or stats["total_queries"] < 5:
+            return 0.7  # Default threshold
+
+        # Use the average confidence minus a fixed 0.1 margin, clamped to [0.5, 0.9]
+        avg_conf = stats["avg_confidence"]
+        return max(0.5, min(0.9, avg_conf - 0.1))
+
+    def get_all_patterns(self) -> list[dict[str, Any]]:
+        """Get statistics for all known patterns."""
+        # Snapshot the keys first: get_pattern_stats() acquires the same
+        # (non-reentrant) lock, so it must not be called while holding it.
+        with self._lock:
+            pattern_types = list(self._patterns)
+
+        results = []
+        for pattern_type in pattern_types:
+            stats = self.get_pattern_stats(pattern_type)
+            if stats:
+                results.append(stats)
+
+        return sorted(results, key=lambda x: -x["total_queries"])
+
+
+# Global query patterns registry
+_global_query_patterns: LearnedQueryPatterns | None = None
+
+
+def get_learned_query_patterns() -> LearnedQueryPatterns:
+    """Get the global learned query patterns registry."""
+    global _global_query_patterns
+
+    if _global_query_patterns is None:
+        custom_path = os.getenv("RNSR_QUERY_PATTERNS_PATH")
+        _global_query_patterns = LearnedQueryPatterns(
+            storage_path=custom_path if custom_path else None
+        )
+
+    return _global_query_patterns
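A sketch of the intended feedback loop (the query and entity type are illustrative; records persist to `~/.rnsr` by default):

    patterns = get_learned_query_patterns()
    query = "Compare the 2023 and 2024 liability caps"
    ptype = patterns.detect_pattern_type(query)  # "comparison"
    patterns.record_query(
        query=query,
        pattern_type=ptype,
        success=True,
        confidence=0.82,
        needed_decomposition=True,
        sub_questions_count=2,
        entities_involved=["CONTRACT"],
    )
    patterns.should_decompose(ptype)          # True once decomposition_rate > 0.5
    patterns.get_confidence_threshold(ptype)  # 0.7 until 5+ queries recorded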
+
+
+# =============================================================================
+# RLM Configuration
+# =============================================================================
+
+
+@dataclass
+class RLMConfig:
+    """Configuration for the RLM Navigator."""
+
+    # Recursion control
+    max_recursion_depth: int = 3  # Max depth for recursive sub-LLM calls
+    max_iterations: int = 30  # Max navigation iterations
+
+    # Tree of Thoughts parameters
+    top_k: int = 3  # Base children to explore
+    selection_threshold: float = 0.4  # Min probability for selection
+    dead_end_threshold: float = 0.1  # Threshold for dead end
+
+    # Pre-filtering
+    enable_pre_filtering: bool = True  # Use regex/keyword filtering before ToT
+    pre_filter_min_matches: int = 1  # Min keyword matches to include node
+
+    # REPL execution
+    enable_code_execution: bool = True  # Allow LLM to write/execute code
+    max_code_execution_time: int = 30  # Seconds
+
+    # Answer verification
+    enable_verification: bool = True  # Verify answers with sub-LLM
+    verification_retries: int = 2  # Max verification attempts
+
+    # Async processing
+    enable_async: bool = True  # Use async for parallel sub-LLM calls
+    max_concurrent_calls: int = 5  # Max parallel LLM calls
+
+    # Vision mode
+    enable_vision: bool = False  # Use vision LLM for page images
+    vision_model: str = "gemini-2.5-flash"  # Vision model to use
+
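RLMConfig is a plain dataclass, so tuning is ordinary keyword construction; a conservative sketch:

    config = RLMConfig(
        max_recursion_depth=2,      # shallower sub-LLM recursion
        max_iterations=20,
        enable_pre_filtering=True,  # cheap keyword pass before any LLM call
        enable_verification=True,
        max_concurrent_calls=3,
    )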
+
+# =============================================================================
+# Pre-Filtering Engine (Before ToT Evaluation)
+# =============================================================================
+
+
+class PreFilterEngine:
+    """
+    Pre-filters nodes before expensive ToT LLM evaluation.
+
+    Implements the key RLM insight: use code (regex, keywords) to filter
+    before sending to LLM. This dramatically reduces LLM calls.
+
+    Uses adaptive stop words that learn from domain-specific usage.
+
+    Example:
+        # Query: "What is the liability clause?"
+        # Instead of evaluating all 50 children with LLM:
+        # 1. Extract keywords: ["liability", "clause", "indemnification"]
+        # 2. Regex search children summaries
+        # 3. Only send matching children to ToT evaluation
+    """
+
+    def __init__(self, config: RLMConfig, enable_stop_word_learning: bool = True):
+        self.config = config
+        self._keyword_cache: dict[str, list[str]] = {}
+        self._stop_word_registry = get_learned_stop_words() if enable_stop_word_learning else None
+
+    def extract_keywords(self, query: str) -> list[str]:
+        """Extract searchable keywords from a query."""
+        if query in self._keyword_cache:
+            return self._keyword_cache[query]
+
+        # Get stop words (base + learned)
+        if self._stop_word_registry:
+            stop_words = self._stop_word_registry.get_stop_words()
+        else:
+            stop_words = LearnedStopWords.BASE_STOP_WORDS
+
+        # Tokenize and filter
+        words = re.findall(r'\b[a-zA-Z]{3,}\b', query.lower())
+        keywords = [w for w in words if w not in stop_words]
+
+        # Add quoted phrases as single keywords
+        quoted = re.findall(r'"([^"]+)"', query)
+        keywords.extend(quoted)
+
+        # Add capitalized words (likely proper nouns)
+        proper_nouns = re.findall(r'\b[A-Z][a-z]+\b', query)
+        keywords.extend([pn.lower() for pn in proper_nouns])
+
+        # Deduplicate while preserving order
+        seen = set()
+        unique_keywords = []
+        for kw in keywords:
+            if kw not in seen:
+                seen.add(kw)
+                unique_keywords.append(kw)
+
+        self._keyword_cache[query] = unique_keywords
+        logger.debug("keywords_extracted", query=query[:50], keywords=unique_keywords)
+        return unique_keywords
+
+    def filter_nodes_by_keywords(
+        self,
+        nodes: list[SkeletonNode],
+        keywords: list[str],
+        min_matches: int | None = None,
+    ) -> tuple[list[SkeletonNode], list[SkeletonNode]]:
+        """
+        Filter nodes by keyword matching.
+
+        Returns:
+            Tuple of (matching_nodes, remaining_nodes)
+        """
+        if not self.config.enable_pre_filtering:
+            return nodes, []
+
+        if not keywords:
+            return nodes, []
+
+        min_matches = min_matches or self.config.pre_filter_min_matches
+
+        matching = []
+        remaining = []
+
+        for node in nodes:
+            # Search in header and summary
+            search_text = f"{node.header} {node.summary}".lower()
+
+            matches = sum(1 for kw in keywords if kw in search_text)
+
+            if matches >= min_matches:
+                matching.append(node)
+            else:
+                remaining.append(node)
+
+        logger.debug(
+            "pre_filter_complete",
+            total=len(nodes),
+            matching=len(matching),
+            remaining=len(remaining),
+            keywords=keywords[:5],
+        )
+
+        return matching, remaining
+
+    def regex_search_nodes(
+        self,
+        nodes: list[SkeletonNode],
+        pattern: str,
+    ) -> list[tuple[SkeletonNode, list[str]]]:
+        """
+        Search nodes using a regex pattern.
+
+        Returns:
+            List of (node, matches) tuples.
+        """
+        results = []
+
+        try:
+            regex = re.compile(pattern, re.IGNORECASE)
+        except re.error as e:
+            logger.warning("invalid_regex_pattern", pattern=pattern, error=str(e))
+            return results
+
+        for node in nodes:
+            search_text = f"{node.header}\n{node.summary}"
+            matches = regex.findall(search_text)
+            if matches:
+                results.append((node, matches))
+
+        return results
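A sketch of the pre-filter pass in isolation. `nodes` stands in for any `list[SkeletonNode]`; note the quirk that the capitalized-word heuristic re-adds a sentence-initial stop word like "What":

    engine = PreFilterEngine(RLMConfig(), enable_stop_word_learning=False)

    keywords = engine.extract_keywords('What is the "limitation of liability" clause?')
    # -> ['limitation', 'liability', 'clause', 'limitation of liability', 'what']
    #    (the trailing 'what' sneaks back in via the proper-noun heuristic)

    nodes: list[SkeletonNode] = []  # placeholder; normally the skeleton's values
    matching, remaining = engine.filter_nodes_by_keywords(nodes, keywords)
    hits = engine.regex_search_nodes(nodes, r"liabilit(?:y|ies)")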
+
+
+# =============================================================================
+# Deep Recursive Sub-LLM Engine
+# =============================================================================
+
+
+class RecursiveSubLLMEngine:
+    """
+    Enables true multi-level recursive sub-LLM calls.
+
+    Unlike single-level decomposition, this allows sub-LLMs to spawn
+    their own sub-LLMs up to a configurable depth.
+
+    Example:
+        Query: "Compare the liability clauses in 2023 vs 2024 contracts"
+
+        Depth 0 (Root): Decompose into sub-tasks
+        ├── Depth 1: "Find 2023 liability clause"
+        │   └── Depth 2: "Extract specific terms"
+        └── Depth 1: "Find 2024 liability clause"
+            └── Depth 2: "Extract specific terms"
+    """
+
+    def __init__(
+        self,
+        config: RLMConfig,
+        llm_fn: Callable[[str], str] | None = None,
+    ):
+        self.config = config
+        self._llm_fn = llm_fn
+        self._call_count = 0
+        self._depth_stats: dict[int, int] = {}
+
+    def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
+        """Set the LLM function for sub-calls."""
+        self._llm_fn = llm_fn
+
+    def recursive_call(
+        self,
+        prompt: str,
+        context: str,
+        depth: int = 0,
+        allow_sub_calls: bool = True,
+    ) -> str:
+        """
+        Execute a recursive LLM call.
+
+        Args:
+            prompt: The task/question for the LLM.
+            context: Context to process.
+            depth: Current recursion depth.
+            allow_sub_calls: Whether this call can spawn sub-calls.
+
+        Returns:
+            LLM response.
+        """
+        if self._llm_fn is None:
+            return "[ERROR: LLM function not configured]"
+
+        if depth >= self.config.max_recursion_depth:
+            allow_sub_calls = False
+            logger.debug("max_recursion_depth_reached", depth=depth)
+
+        # Track stats
+        self._call_count += 1
+        self._depth_stats[depth] = self._depth_stats.get(depth, 0) + 1
+
+        # Build the prompt with recursion capability
+        if allow_sub_calls:
+            system_instruction = f"""You are a sub-LLM at recursion depth {depth}.
+You can decompose complex tasks into sub-tasks.
+If you need to process multiple items independently, list them as:
+SUB_TASK[1]: <task description>
+SUB_TASK[2]: <task description>
+...
+These will be processed by sub-LLMs and results aggregated.
+"""
+        else:
+            system_instruction = f"""You are a sub-LLM at max recursion depth {depth}.
+Provide a direct answer without further decomposition."""
+
+        full_prompt = f"""{system_instruction}
+
+Task: {prompt}
+
+Context:
+{context}
+
+Response:"""
+
+        try:
+            response = self._llm_fn(full_prompt)
+
+            # Check for sub-task declarations and process them
+            if allow_sub_calls and "SUB_TASK[" in response:
+                response = self._process_sub_tasks(response, depth + 1)
+
+            return response
+
+        except Exception as e:
+            logger.error("recursive_call_failed", depth=depth, error=str(e))
+            return f"[ERROR: {str(e)}]"
+
+    def _process_sub_tasks(self, response: str, depth: int) -> str:
+        """Process SUB_TASK declarations in the response."""
+        # Extract sub-tasks
+        sub_tasks = re.findall(r'SUB_TASK\[(\d+)\]:\s*(.+?)(?=SUB_TASK\[|$)', response, re.DOTALL)
+
+        if not sub_tasks:
+            return response
+
+        logger.debug("processing_sub_tasks", count=len(sub_tasks), depth=depth)
+
+        # Process each sub-task recursively
+        results = []
+        for task_num, task_desc in sub_tasks:
+            result = self.recursive_call(
+                prompt=task_desc.strip(),
+                context="(inherited from parent)",
+                depth=depth,
+                allow_sub_calls=(depth < self.config.max_recursion_depth),
+            )
+            results.append(f"Result[{task_num}]: {result}")
+
+        # Synthesize results
+        synthesis_prompt = f"""Synthesize the following sub-task results into a coherent answer:
+
+{chr(10).join(results)}
+
+Original task: {response.split('SUB_TASK[')[0].strip()}
+
+Synthesized answer:"""
+
+        return self._llm_fn(synthesis_prompt) if self._llm_fn else "\n".join(results)
+
+    async def async_recursive_call(
+        self,
+        prompt: str,
+        context: str,
+        depth: int = 0,
+    ) -> str:
+        """Async version of recursive_call for parallel processing."""
+        # Run in a worker thread so the event loop is not blocked
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(
+            None,
+            lambda: self.recursive_call(prompt, context, depth),
+        )
+
+    def batch_recursive_calls(
+        self,
+        prompts: list[str],
+        contexts: list[str],
+        depth: int = 0,
+    ) -> list[str]:
+        """
+        Execute multiple recursive calls in parallel.
+
+        Uses ThreadPoolExecutor for parallel processing.
+        """
+        if len(prompts) != len(contexts):
+            raise ValueError("prompts and contexts must have same length")
+
+        if not prompts:
+            return []
+
+        results: list[str] = [""] * len(prompts)
+
+        with ThreadPoolExecutor(max_workers=self.config.max_concurrent_calls) as executor:
+            futures = {}
+            for idx, (prompt, context) in enumerate(zip(prompts, contexts)):
+                future = executor.submit(
+                    self.recursive_call,
+                    prompt,
+                    context,
+                    depth,
+                )
+                futures[future] = idx
+
+            for future in futures:
+                idx = futures[future]
+                try:
+                    results[idx] = future.result(timeout=60)
+                except Exception as e:
+                    results[idx] = f"[ERROR: {str(e)}]"
+
+        return results
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get call statistics."""
+        return {
+            "total_calls": self._call_count,
+            "calls_by_depth": dict(self._depth_stats),
+        }
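Wiring sketch: any `Callable[[str], str]` serves as the sub-LLM, so the engine is easy to exercise with a stub (`my_llm` and `contract_excerpts` below are hypothetical, not package names):

    engine = RecursiveSubLLMEngine(RLMConfig(max_recursion_depth=2))
    engine.set_llm_function(lambda prompt: my_llm.complete(prompt))  # hypothetical client

    answer = engine.recursive_call(
        prompt="Compare the liability clauses in the 2023 vs 2024 contracts",
        context=contract_excerpts,  # assumed pre-fetched document text
    )
    print(engine.get_stats())  # e.g. {'total_calls': 5, 'calls_by_depth': {0: 1, 1: 4}}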
+
+
+# =============================================================================
+# Answer Verification Engine
+# =============================================================================
+
+
+class AnswerVerificationEngine:
+    """
+    Verifies answers using sub-LLM calls.
+
+    Implements the RLM pattern of using sub-LLMs to verify answers
+    before returning, ensuring higher accuracy.
+    """
+
+    def __init__(
+        self,
+        config: RLMConfig,
+        llm_fn: Callable[[str], str] | None = None,
+    ):
+        self.config = config
+        self._llm_fn = llm_fn
+
+    def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
+        """Set the LLM function."""
+        self._llm_fn = llm_fn
+
+    def verify_answer(
+        self,
+        question: str,
+        proposed_answer: str,
+        evidence: list[str],
+        attempt: int = 0,
+    ) -> dict[str, Any]:
+        """
+        Verify an answer using sub-LLM evaluation.
+
+        Returns:
+            Dict with 'is_valid', 'confidence', 'issues', 'improved_answer'.
+        """
+        if not self.config.enable_verification:
+            return {
+                "is_valid": True,
+                "confidence": 0.7,
+                "issues": [],
+                "improved_answer": proposed_answer,
+            }
+
+        if self._llm_fn is None:
+            return {
+                "is_valid": True,
+                "confidence": 0.5,
+                "issues": ["LLM not configured for verification"],
+                "improved_answer": proposed_answer,
+            }
+
+        evidence_text = "\n---\n".join(evidence) if evidence else "(no evidence provided)"
+
+        verification_prompt = f"""You are a fact-checker verifying an answer.
+
+Question: {question}
+
+Proposed Answer: {proposed_answer}
+
+Evidence:
+{evidence_text}
+
+VERIFICATION TASK:
+1. Check if the answer is supported by the evidence
+2. Check if the answer directly addresses the question
+3. Check for any factual errors or unsupported claims
+
+OUTPUT FORMAT (JSON):
+{{
+    "is_valid": true/false,
+    "confidence": 0.0-1.0,
+    "issues": ["list of issues if any"],
+    "improved_answer": "corrected answer if needed, or null if valid"
+}}
+
+Respond ONLY with JSON:"""
+
+        try:
+            response = self._llm_fn(verification_prompt)
+
+            # Parse JSON response (json is imported at module level)
+            json_match = re.search(r'\{[\s\S]*\}', response)
+            if json_match:
+                result = json.loads(json_match.group())
+
+                # If not valid and we have retries left, try to improve
+                if not result.get("is_valid", True) and attempt < self.config.verification_retries:
+                    logger.debug(
+                        "answer_verification_failed",
+                        attempt=attempt,
+                        issues=result.get("issues", []),
+                    )
+
+                    # Try with improved answer
+                    if result.get("improved_answer"):
+                        return self.verify_answer(
+                            question,
+                            result["improved_answer"],
+                            evidence,
+                            attempt + 1,
+                        )
+
+                return result
+            else:
+                logger.warning("verification_json_parse_failed", response=response[:200])
+                return {
+                    "is_valid": True,
+                    "confidence": 0.5,
+                    "issues": ["Could not parse verification response"],
+                    "improved_answer": proposed_answer,
+                }
+
+        except Exception as e:
+            logger.error("verification_failed", error=str(e))
+            return {
+                "is_valid": True,
+                "confidence": 0.3,
+                "issues": [str(e)],
+                "improved_answer": proposed_answer,
+            }
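A verification sketch; `llm_fn` is assumed to be configured as elsewhere in this module, and the question and evidence strings are illustrative:

    verifier = AnswerVerificationEngine(RLMConfig(verification_retries=1), llm_fn=llm_fn)
    result = verifier.verify_answer(
        question="What is the liability cap?",
        proposed_answer="$2M per incident",
        evidence=["Section 9.2: liability is capped at $2,000,000 per incident."],
    )
    final_answer = result.get("improved_answer") or "$2M per incident"
    if not result["is_valid"]:
        print(result["issues"])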
+
+
+# =============================================================================
+# Enhanced RLM Navigator Agent State
+# =============================================================================
+
+
+class RLMAgentState:
+    """
+    State for the RLM Navigator Agent.
+
+    Extends the base AgentState with RLM-specific fields.
+    """
+
+    def __init__(
+        self,
+        question: str,
+        root_node_id: str,
+        config: RLMConfig | None = None,
+        metadata: dict[str, Any] | None = None,
+    ):
+        self.question = question
+        self.config = config or RLMConfig()
+        self.metadata = metadata or {}
+
+        # Navigation state
+        self.current_node_id: str | None = root_node_id
+        self.visited_nodes: list[str] = []
+        self.navigation_path: list[str] = [root_node_id]
+        self.nodes_to_visit: list[str] = []
+        self.dead_ends: list[str] = []
+        self.backtrack_stack: list[str] = []
+
+        # Variable stitching
+        self.variables: list[str] = []
+        self.context: str = ""
+
+        # Sub-questions (RLM decomposition)
+        self.sub_questions: list[str] = []
+        self.pending_questions: list[str] = []
+        self.current_sub_question: str | None = None
+
+        # Pre-filtering state
+        self.extracted_keywords: list[str] = []
+        self.pre_filtered_nodes: dict[str, list[str]] = {}  # node_id -> matched keywords
+
+        # Recursion tracking
+        self.current_recursion_depth: int = 0
+        self.recursion_call_count: int = 0
+
+        # Output
+        self.answer: str | None = None
+        self.confidence: float = 0.0
+        self.verification_result: dict[str, Any] | None = None
+
+        # Traceability
+        self.trace: list[dict[str, Any]] = []
+        self.iteration: int = 0
+
+    def add_trace(
+        self,
+        node_type: str,
+        action: str,
+        details: dict | None = None,
+    ) -> None:
+        """Add a trace entry."""
+        entry = {
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "node_type": node_type,
+            "action": action,
+            "details": details or {},
+            "iteration": self.iteration,
+        }
+        self.trace.append(entry)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert state to dictionary."""
+        return {
+            "question": self.question,
+            "answer": self.answer,
+            "confidence": self.confidence,
+            "variables": self.variables,
+            "visited_nodes": self.visited_nodes,
+            "iteration": self.iteration,
+            "recursion_call_count": self.recursion_call_count,
+            "verification_result": self.verification_result,
+            "trace": self.trace,
+        }
+
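The state object is plain Python, which makes the phase methods below easy to unit-test in isolation; a small sketch:

    state = RLMAgentState(question="What is the liability cap?", root_node_id="root")
    state.add_trace("pre_filter", "Extracting keywords", {"keywords": ["liability", "cap"]})
    assert state.to_dict()["trace"][0]["action"] == "Extracting keywords"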
+
+# =============================================================================
+# RLM Navigator - Main Class
+# =============================================================================
+
+
+class RLMNavigator:
+    """
+    The RLM Navigator combines:
+    1. PageIndex-style tree search with reasoning
+    2. RLM-style REPL environment with code execution
+    3. RNSR-style variable stitching and skeleton indexing
+    4. Entity-aware query decomposition (when knowledge graph available)
+
+    This is the unified, state-of-the-art document retrieval agent.
+    """
+
+    def __init__(
+        self,
+        skeleton: dict[str, SkeletonNode],
+        kv_store: KVStore,
+        config: RLMConfig | None = None,
+        knowledge_graph=None,
+    ):
+        self.skeleton = skeleton
+        self.kv_store = kv_store
+        self.config = config or RLMConfig()
+        self.knowledge_graph = knowledge_graph
+
+        # Initialize components
+        self.variable_store = VariableStore()
+        self.pre_filter = PreFilterEngine(self.config)
+        self.recursive_engine = RecursiveSubLLMEngine(self.config)
+        self.verification_engine = AnswerVerificationEngine(self.config)
+        self.entity_decomposer = EntityAwareDecomposer(knowledge_graph)
+
+        # LLM function
+        self._llm_fn: Callable[[str], str] | None = None
+
+        # Find root node
+        self.root_id = self._find_root_id()
+
+    def _find_root_id(self) -> str:
+        """Find the root node ID."""
+        for node in self.skeleton.values():
+            if node.level == 0:
+                return node.node_id
+        raise ValueError("No root node found in skeleton")
+
+    def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
+        """Configure the LLM function for all components."""
+        self._llm_fn = llm_fn
+        self.recursive_engine.set_llm_function(llm_fn)
+        self.verification_engine.set_llm_function(llm_fn)
+        self.entity_decomposer.set_llm_function(llm_fn)
+
+    def set_knowledge_graph(self, kg) -> None:
+        """Set the knowledge graph for entity-aware decomposition."""
+        self.knowledge_graph = kg
+        self.entity_decomposer.set_knowledge_graph(kg)
+
+    def navigate(
+        self,
+        question: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """
+        Navigate the document tree to answer a question.
+
+        This is the main entry point for the RLM Navigator.
+
+        Args:
+            question: The user's question.
+            metadata: Optional metadata (e.g., multiple choice options).
+
+        Returns:
+            Dict with answer, confidence, trace, etc.
+        """
+        # Initialize state
+        state = RLMAgentState(
+            question=question,
+            root_node_id=self.root_id,
+            config=self.config,
+            metadata=metadata,
+        )
+
+        # Ensure LLM is configured
+        if self._llm_fn is None:
+            self._configure_default_llm()
+
+        logger.info("rlm_navigation_started", question=question[:100])
+
+        try:
+            # Phase 1: Pre-filtering with keyword extraction
+            state = self._phase_pre_filter(state)
+
+            # Phase 2: Query decomposition
+            state = self._phase_decompose(state)
+
+            # Phase 3: Tree navigation with ToT
+            state = self._phase_navigate(state)
+
+            # Phase 4: Synthesis
+            state = self._phase_synthesize(state)
+
+            # Phase 5: Verification (if enabled)
+            if self.config.enable_verification:
+                state = self._phase_verify(state)
+
+            logger.info(
+                "rlm_navigation_complete",
+                confidence=state.confidence,
+                variables=len(state.variables),
+                iterations=state.iteration,
+            )
+
+            return state.to_dict()
+
+        except Exception as e:
+            logger.error("rlm_navigation_failed", error=str(e))
+            state.answer = f"Error during navigation: {str(e)}"
+            state.confidence = 0.0
+            return state.to_dict()
+
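An end-to-end sketch of the five phases above, assuming a `skeleton` and `kv_store` produced by the ingestion pipeline elsewhere in the package (`llm` is a hypothetical client):

    navigator = RLMNavigator(skeleton, kv_store, config=RLMConfig())
    navigator.set_llm_function(lambda p: str(llm.complete(p)))

    result = navigator.navigate("What is the indemnification cap?")
    print(result["answer"], result["confidence"])
    for entry in result["trace"]:
        print(entry["node_type"], entry["action"])  # pre_filter -> ... -> verification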
1277
+ def _configure_default_llm(self) -> None:
1278
+ """Configure the default LLM if none set."""
1279
+ try:
1280
+ from rnsr.llm import get_llm
1281
+ llm = get_llm()
1282
+ self.set_llm_function(lambda p: str(llm.complete(p)))
1283
+ except Exception as e:
1284
+ logger.warning("default_llm_config_failed", error=str(e))
1285
+
1286
+ def _phase_pre_filter(self, state: RLMAgentState) -> RLMAgentState:
1287
+ """Phase 1: Extract keywords and pre-filter nodes."""
1288
+ state.add_trace("pre_filter", "Extracting keywords from query")
1289
+
1290
+ # Extract keywords
1291
+ keywords = self.pre_filter.extract_keywords(state.question)
1292
+ state.extracted_keywords = keywords
1293
+
1294
+ if not keywords:
1295
+ state.add_trace("pre_filter", "No keywords extracted, skipping pre-filter")
1296
+ return state
1297
+
1298
+ # Pre-filter all leaf nodes
1299
+ all_nodes = list(self.skeleton.values())
1300
+ matching, remaining = self.pre_filter.filter_nodes_by_keywords(all_nodes, keywords)
1301
+
1302
+ # Store which nodes matched which keywords
1303
+ for node in matching:
1304
+ search_text = f"{node.header} {node.summary}".lower()
1305
+ matched_keywords = [kw for kw in keywords if kw in search_text]
1306
+ state.pre_filtered_nodes[node.node_id] = matched_keywords
1307
+
1308
+ state.add_trace(
1309
+ "pre_filter",
1310
+ f"Pre-filtered {len(matching)}/{len(all_nodes)} nodes",
1311
+ {"keywords": keywords, "matching_nodes": len(matching)},
1312
+ )
1313
+
1314
+ return state
1315
+
1316
+ def _phase_decompose(self, state: RLMAgentState) -> RLMAgentState:
1317
+ """Phase 2: Decompose query into sub-questions with entity awareness."""
1318
+ state.add_trace("decomposition", "Analyzing query for decomposition")
1319
+
1320
+ if self._llm_fn is None:
1321
+ state.sub_questions = [state.question]
1322
+ state.pending_questions = [state.question]
1323
+ return state
1324
+
1325
+ # Try entity-aware decomposition first if knowledge graph is available
1326
+ if self.knowledge_graph:
1327
+ try:
1328
+ entity_result = self.entity_decomposer.decompose_with_entities(
1329
+ state.question
1330
+ )
1331
+
1332
+ if entity_result.get("entities_found"):
1333
+ # Store entity information in state
1334
+ state.metadata["entities_found"] = entity_result.get("entities_found", [])
1335
+ state.metadata["entity_nodes"] = entity_result.get("entity_nodes", {})
1336
+ state.metadata["retrieval_plan"] = entity_result.get("retrieval_plan", [])
1337
+ state.metadata["relationships"] = entity_result.get("relationships", [])
1338
+
1339
+ sub_tasks = entity_result.get("sub_queries", [state.question])
1340
+ state.sub_questions = sub_tasks
1341
+ state.pending_questions = sub_tasks.copy()
1342
+ state.current_sub_question = sub_tasks[0] if sub_tasks else state.question
1343
+
1344
+ # Prioritize nodes from retrieval plan in pre-filtering
1345
+ for item in entity_result.get("retrieval_plan", []):
1346
+ node_id = item.get("node_id")
1347
+ if node_id and node_id not in state.pre_filtered_nodes:
1348
+ state.pre_filtered_nodes[node_id] = ["entity_match"]
1349
+
1350
+ state.add_trace(
1351
+ "decomposition",
1352
+ f"Entity-aware decomposition: {len(sub_tasks)} sub-tasks, {len(entity_result.get('entities_found', []))} entities",
1353
+ {
1354
+ "sub_tasks": sub_tasks,
1355
+ "entities": [e.canonical_name for e in entity_result.get("entities_found", [])],
1356
+ },
1357
+ )
1358
+
1359
+ return state
1360
+
1361
+ except Exception as e:
1362
+ logger.debug("entity_aware_decomposition_failed", error=str(e))
1363
+ # Fall through to standard decomposition
1364
+
1365
+ # Standard LLM decomposition
1366
+ decomposition_prompt = f"""Analyze this query and decompose it into specific sub-tasks.
1367
+
1368
+ Query: {state.question}
1369
+
1370
+ Available document sections (pre-filtered matches):
1371
+ {chr(10).join(f"- {self.skeleton[nid].header}" for nid in list(state.pre_filtered_nodes.keys())[:10])}
1372
+
1373
+ RULES:
1374
+ 1. Each sub-task should target a specific piece of information
1375
+ 2. For comparison queries, create one sub-task per item
1376
+ 3. Maximum 5 sub-tasks
1377
+ 4. If the query is simple, return just one sub-task
1378
+
1379
+ OUTPUT FORMAT (JSON):
1380
+ {{
1381
+ "sub_tasks": ["task1", "task2", ...],
1382
+ "synthesis_plan": "how to combine results"
1383
+ }}
1384
+
1385
+ Respond with JSON only:"""
1386
+
1387
+ try:
1388
+ import json
1389
+
1390
+ response = self._llm_fn(decomposition_prompt)
1391
+ json_match = re.search(r'\{[\s\S]*\}', response)
1392
+
1393
+ if json_match:
1394
+ result = json.loads(json_match.group())
1395
+ sub_tasks = result.get("sub_tasks", [state.question])
1396
+ state.sub_questions = sub_tasks
1397
+ state.pending_questions = sub_tasks.copy()
1398
+ state.current_sub_question = sub_tasks[0] if sub_tasks else state.question
1399
+
1400
+ state.add_trace(
1401
+ "decomposition",
1402
+ f"Decomposed into {len(sub_tasks)} sub-tasks",
1403
+ {"sub_tasks": sub_tasks},
1404
+ )
1405
+ else:
1406
+ state.sub_questions = [state.question]
1407
+ state.pending_questions = [state.question]
1408
+
1409
+ except Exception as e:
1410
+ logger.warning("decomposition_failed", error=str(e))
1411
+ state.sub_questions = [state.question]
1412
+ state.pending_questions = [state.question]
1413
+
1414
+ return state
1415
+
1416
+ def _phase_navigate(self, state: RLMAgentState) -> RLMAgentState:
1417
+ """Phase 3: Navigate the tree using ToT with pre-filtering."""
1418
+ state.add_trace("navigation", "Starting tree navigation")
1419
+
1420
+ # Initialize navigation at root
1421
+ state.current_node_id = self.root_id
1422
+
1423
+ while state.iteration < self.config.max_iterations:
1424
+ state.iteration += 1
1425
+
1426
+ # Check termination conditions
1427
+ if state.current_node_id is None and not state.nodes_to_visit:
1428
+ break
1429
+
1430
+ # Pop from queue if needed
1431
+ if state.current_node_id is None and state.nodes_to_visit:
1432
+ state.current_node_id = state.nodes_to_visit.pop(0)
1433
+
1434
+ if state.current_node_id is None:
1435
+ break
1436
+
1437
+ node = self.skeleton.get(state.current_node_id)
1438
+ if node is None:
1439
+ state.current_node_id = None
1440
+ continue
1441
+
1442
+ # Already visited?
1443
+ if state.current_node_id in state.visited_nodes:
1444
+ state.current_node_id = None
1445
+ continue
1446
+
1447
+ # Decide: expand or traverse
1448
+ action = self._decide_action(state, node)
1449
+
1450
+ if action == "expand":
1451
+ state = self._do_expand(state, node)
1452
+ elif action == "traverse":
1453
+ state = self._do_traverse(state, node)
1454
+ elif action == "backtrack":
1455
+ state = self._do_backtrack(state)
1456
+ else:
1457
+ break
1458
+
1459
+ state.add_trace(
1460
+ "navigation",
1461
+ f"Navigation complete after {state.iteration} iterations",
1462
+ {"variables_found": len(state.variables)},
1463
+ )
1464
+
1465
+ return state
+
+    def _decide_action(
+        self,
+        state: RLMAgentState,
+        node: SkeletonNode,
+    ) -> Literal["expand", "traverse", "backtrack", "done"]:
+        """Decide what action to take at the current node."""
+        # Leaf node -> expand
+        if not node.child_ids:
+            if node.node_id in state.visited_nodes:
+                return "done"
+            return "expand"
+
+        # Check unvisited children
+        unvisited = [
+            cid for cid in node.child_ids
+            if cid not in state.visited_nodes and cid not in state.dead_ends
+        ]
+
+        if not unvisited:
+            if state.backtrack_stack:
+                return "backtrack"
+            return "done"
+
+        # Has unvisited children -> traverse
+        return "traverse"
+
+    def _do_expand(self, state: RLMAgentState, node: SkeletonNode) -> RLMAgentState:
+        """Expand current node: fetch content and store as variable."""
+        content = self.kv_store.get(node.node_id)
+
+        if content:
+            pointer = generate_pointer_name(node.header)
+            self.variable_store.assign(pointer, content, node.node_id)
+            state.variables.append(pointer)
+            state.context += f"\nFound: {pointer} (from {node.header})"
+
+            state.add_trace(
+                "variable_stitching",
+                f"Stored {pointer}",
+                {"node": node.node_id, "chars": len(content)},
+            )
+
+        state.visited_nodes.append(node.node_id)
+        state.current_node_id = None
+        return state
+
+    def _do_traverse(self, state: RLMAgentState, node: SkeletonNode) -> RLMAgentState:
+        """Traverse to children using ToT with pre-filtering."""
+        # Get children
+        children = [self.skeleton.get(cid) for cid in node.child_ids]
+        children = [c for c in children if c is not None]
+
+        # Apply pre-filtering
+        if state.extracted_keywords and self.config.enable_pre_filtering:
+            matching, remaining = self.pre_filter.filter_nodes_by_keywords(
+                children,
+                state.extracted_keywords,
+            )
+
+            # If we have matching nodes, prioritize them
+            if matching:
+                selected = matching[:self.config.top_k]
+                state.add_trace(
+                    "navigation",
+                    f"Pre-filter selected {len(selected)}/{len(children)} children",
+                    {"selected": [n.node_id for n in selected]},
+                )
+            else:
+                # Fall back to ToT evaluation
+                selected = self._tot_evaluate_children(state, children)
+        else:
+            # Use ToT evaluation
+            selected = self._tot_evaluate_children(state, children)
+
+        # Queue selected children
+        if selected:
+            for child in selected:
+                if child.node_id not in state.nodes_to_visit:
+                    state.nodes_to_visit.append(child.node_id)
+
+            # Push current node to backtrack stack
+            state.backtrack_stack.append(node.node_id)
+        else:
+            # Dead end
+            state.dead_ends.append(node.node_id)
+
+        state.visited_nodes.append(node.node_id)
+        state.current_node_id = None
+        return state
+
+    def _tot_evaluate_children(
+        self,
+        state: RLMAgentState,
+        children: list[SkeletonNode],
+    ) -> list[SkeletonNode]:
+        """Use Tree of Thoughts to evaluate children."""
+        if not self._llm_fn or not children:
+            return children[:self.config.top_k]
+
+        # Format children for evaluation
+        children_text = "\n".join(
+            f" - [{c.node_id}] {c.header}: {c.summary[:150]}"
+            for c in children
+        )
+
+        current_node = self.skeleton.get(state.current_node_id or self.root_id)
+        current_summary = f"{current_node.header}: {current_node.summary}" if current_node else ""
+
+        tot_prompt = f"""You are evaluating document sections for relevance.
+
+Current location: {current_summary}
+
+Children sections:
+{children_text}
+
+Query: {state.current_sub_question or state.question}
+
+TASK: Evaluate each child's probability (0.0-1.0) of containing relevant information.
+
+OUTPUT FORMAT (JSON):
+{{
+    "evaluations": [
+        {{"node_id": "...", "probability": 0.85, "reasoning": "..."}}
+    ],
+    "selected_nodes": ["node_id_1", "node_id_2"],
+    "is_dead_end": false
+}}
+
+JSON only:"""
+
+        try:
+            response = self._llm_fn(tot_prompt)
+            json_match = re.search(r'\{[\s\S]*\}', response)
+
+            if json_match:
+                result = json.loads(json_match.group())
+                selected_ids = result.get("selected_nodes", [])
+
+                # Map back to nodes
+                selected = [c for c in children if c.node_id in selected_ids]
+
+                if not selected and not result.get("is_dead_end", False):
+                    # Fallback: take top-k by probability
+                    evaluations = result.get("evaluations", [])
+                    sorted_evals = sorted(
+                        evaluations,
+                        key=lambda x: x.get("probability", 0),
+                        reverse=True,
+                    )
+                    top_ids = [e["node_id"] for e in sorted_evals[:self.config.top_k]]
+                    selected = [c for c in children if c.node_id in top_ids]
+
+                return selected
+
+        except Exception as e:
+            logger.warning("tot_evaluation_failed", error=str(e))
+
+        # Fallback: return the first top_k children
+        return children[:self.config.top_k]
1628
+
1629
+ def _do_backtrack(self, state: RLMAgentState) -> RLMAgentState:
1630
+ """Backtrack to previous node."""
1631
+ if state.backtrack_stack:
1632
+ parent_id = state.backtrack_stack.pop()
1633
+ state.dead_ends.append(state.current_node_id or "")
1634
+ state.current_node_id = parent_id
1635
+
1636
+ state.add_trace(
1637
+ "navigation",
1638
+ f"Backtracked to {parent_id}",
1639
+ {"from": state.current_node_id},
1640
+ )
1641
+ else:
1642
+ state.current_node_id = None
1643
+
1644
+ return state
1645
+
+    def _phase_synthesize(self, state: RLMAgentState) -> RLMAgentState:
+        """Phase 4: Synthesize answer from variables."""
+        state.add_trace("synthesis", "Synthesizing answer from variables")
+
+        if not state.variables:
+            state.answer = "No relevant content found in the document."
+            state.confidence = 0.0
+            return state
+
+        # Collect all variable content
+        contents = []
+        for pointer in state.variables:
+            content = self.variable_store.resolve(pointer)
+            if content:
+                contents.append(f"=== {pointer} ===\n{content}")
+
+        context_text = "\n\n".join(contents)
+
+        if not self._llm_fn:
+            state.answer = context_text
+            state.confidence = 0.5
+            return state
+
+        # Handle multiple choice
+        options = state.metadata.get("options")
+        if options:
+            options_text = "\n".join(f"{chr(65+i)}. {opt}" for i, opt in enumerate(options))
+            synthesis_prompt = f"""Based on the context, answer this multiple-choice question.
+
+Question: {state.question}
+
+Options:
+{options_text}
+
+Context:
+{context_text}
+
+Respond with ONLY the letter and full option text (e.g., "A. [option text]"):"""
+        else:
+            synthesis_prompt = f"""Based on the context, answer the question concisely.
+
+Question: {state.question}
+
+Context:
+{context_text}
+
+Answer:"""
+
+        try:
+            answer = self._llm_fn(synthesis_prompt)
+            state.answer = answer.strip()
+            state.confidence = min(1.0, len(state.variables) * 0.25)
+
+            # Normalize multiple choice answer
+            if options:
+                state.answer = self._normalize_mc_answer(state.answer, options)
+
+        except Exception as e:
+            logger.error("synthesis_failed", error=str(e))
+            state.answer = f"Error during synthesis: {str(e)}"
+            state.confidence = 0.0
+
+        return state
+
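+    # Confidence above is a simple variable-count heuristic, min(1.0, n * 0.25):
+    # one resolved pointer gives 0.25, four or more give 1.0.
+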
+    def _normalize_mc_answer(self, answer: str, options: list) -> str:
+        """Normalize multiple choice answer to match option text."""
+        answer_lower = answer.lower().strip()
+
+        for i, opt in enumerate(options):
+            letter = chr(65 + i)
+            opt_lower = opt.lower()
+
+            if (answer_lower.startswith(f"{letter.lower()}.") or
+                    answer_lower.startswith(f"{letter.lower()})") or
+                    opt_lower in answer_lower):
+                return opt
+
+        return answer
+
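+    # Illustrative behavior (options hypothetical):
+    #   _normalize_mc_answer("b) Net income rose", ["Revenue fell", "Net income rose"])
+    #   -> "Net income rose"
+    # Answers matching no option fall through unchanged.
+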
+    def _phase_verify(self, state: RLMAgentState) -> RLMAgentState:
+        """Phase 5: Verify the answer."""
+        state.add_trace("verification", "Verifying answer")
+
+        # Collect evidence
+        evidence = [
+            self.variable_store.resolve(p) or ""
+            for p in state.variables
+        ]
+
+        result = self.verification_engine.verify_answer(
+            state.question,
+            state.answer or "",
+            evidence,
+        )
+
+        state.verification_result = result
+
+        if result.get("improved_answer"):
+            state.answer = result["improved_answer"]
+
+        state.confidence = result.get("confidence", state.confidence)
+
+        state.add_trace(
+            "verification",
+            f"Verification complete: valid={result.get('is_valid', True)}",
+            {"issues": result.get("issues", [])},
+        )
+
+        return state
+
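+    # The verification result is read with .get(), so only the keys used
+    # above matter: "is_valid", "confidence", "issues", and an optional
+    # "improved_answer" that replaces the answer when present.
+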
+
+
+# =============================================================================
+# Entity-Aware Query Decomposition
+# =============================================================================
+
+
+class EntityAwareDecomposer:
+    """
+    Enhances query decomposition by leveraging entity relationships
+    from the knowledge graph.
+
+    This allows the navigator to:
+    1. Identify entities mentioned in the query
+    2. Look up related entities via the knowledge graph
+    3. Plan retrieval based on entity relationships
+    4. Generate entity-focused sub-queries
+    """
+
+    def __init__(
+        self,
+        knowledge_graph=None,
+        llm_fn: Callable[[str], str] | None = None,
+    ):
+        """
+        Initialize the entity-aware decomposer.
+
+        Args:
+            knowledge_graph: Optional knowledge graph for entity lookup.
+            llm_fn: LLM function for query analysis.
+        """
+        self.kg = knowledge_graph
+        self._llm_fn = llm_fn
+
+    def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
+        """Set the LLM function."""
+        self._llm_fn = llm_fn
+
+    def set_knowledge_graph(self, kg) -> None:
+        """Set the knowledge graph."""
+        self.kg = kg
+
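+    # Typical wiring (sketch; kg and llm_fn are whatever the caller built):
+    #   decomposer = EntityAwareDecomposer(knowledge_graph=kg, llm_fn=llm_fn)
+    # or inject later via set_knowledge_graph(kg) / set_llm_function(llm_fn).
+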
+    def decompose_with_entities(
+        self,
+        query: str,
+        doc_id: str | None = None,
+    ) -> dict[str, Any]:
+        """
+        Decompose a query using entity awareness.
+
+        Args:
+            query: The user's query.
+            doc_id: Optional document ID to scope entity lookup.
+
+        Returns:
+            Dict with sub_queries, entities_found, and retrieval_plan.
+        """
+        result = {
+            "original_query": query,
+            "sub_queries": [query],
+            "entities_found": [],
+            "entity_nodes": {},
+            "retrieval_plan": [],
+        }
+
+        if not self.kg:
+            return result
+
+        # Step 1: Extract entity names from query
+        entity_names = self._extract_entity_names(query)
+
+        if not entity_names:
+            return result
+
+        # Step 2: Look up entities in knowledge graph
+        entities_found = []
+        entity_nodes: dict[str, list[str]] = {}
+
+        for name in entity_names:
+            matches = self.kg.find_entities_by_name(name, fuzzy=True)
+
+            # Filter by document if specified
+            if doc_id:
+                matches = [e for e in matches if doc_id in e.document_ids]
+
+            for entity in matches:
+                if entity not in entities_found:
+                    entities_found.append(entity)
+                    # Get nodes where this entity is mentioned
+                    entity_nodes[entity.id] = list(entity.node_ids)
+
+        result["entities_found"] = entities_found
+        result["entity_nodes"] = entity_nodes
+
+        if not entities_found:
+            return result
+
+        # Step 3: Get related entities and relationships
+        related_entities = []
+        relationships = []
+
+        for entity in entities_found:
+            # Get entities co-mentioned with this one
+            co_mentions = self.kg.get_entities_mentioned_together(entity.id)
+            for related, count in co_mentions[:5]:  # Top 5 co-mentions
+                if related not in related_entities:
+                    related_entities.append(related)
+
+            # Get relationships
+            rels = self.kg.get_entity_relationships(entity.id)
+            relationships.extend(rels)
+
+        # Step 4: Generate entity-focused sub-queries
+        sub_queries = self._generate_entity_sub_queries(
+            query, entities_found, related_entities, relationships
+        )
+
+        result["sub_queries"] = sub_queries
+        result["related_entities"] = related_entities
+        result["relationships"] = relationships
+
+        # Step 5: Create retrieval plan
+        result["retrieval_plan"] = self._create_retrieval_plan(
+            query, entities_found, entity_nodes, relationships
+        )
+
+        logger.debug(
+            "entity_aware_decomposition",
+            entities=len(entities_found),
+            sub_queries=len(sub_queries),
+            relationships=len(relationships),
+        )
+
+        return result
+
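+    # Illustrative return value (ids and node names hypothetical):
+    #   {"original_query": "Who audited Acme Corp?",
+    #    "sub_queries": ["Who audited Acme Corp?", ...],
+    #    "entities_found": [<Entity: Acme Corp>],
+    #    "entity_nodes": {"ent_acme": ["node_12", "node_40"]},
+    #    "related_entities": [...], "relationships": [...],
+    #    "retrieval_plan": [{"node_id": "node_12", "priority": 1, ...}]}
+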
+    def _extract_entity_names(self, query: str) -> list[str]:
+        """Extract potential entity names from a query."""
+        entity_names = []
+
+        # Use LLM if available
+        if self._llm_fn:
+            try:
+                prompt = f"""Extract entity names (people, organizations, places, documents) from this query.
+
+Query: {query}
+
+Return as JSON array of names:
+["Name 1", "Name 2"]
+
+JSON only:"""
+
+                response = self._llm_fn(prompt)
+                json_match = re.search(r'\[[\s\S]*?\]', response)
+                if json_match:
+                    import json
+                    entity_names = json.loads(json_match.group())
+
+            except Exception as e:
+                logger.debug("entity_extraction_llm_failed", error=str(e))
+
+        # Fallback: extract capitalized phrases
+        if not entity_names:
+            # Find capitalized words (likely proper nouns)
+            proper_nouns = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
+            entity_names = proper_nouns
+
+        return entity_names
+
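+    # Note: the capitalized-phrase fallback keeps any run of capitalized
+    # words, so "Who audited Acme Corp?" yields ["Who", "Acme Corp"];
+    # non-entities like "Who" are expected to find no match in the graph.
+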
+    def _generate_entity_sub_queries(
+        self,
+        query: str,
+        entities: list,
+        related: list,
+        relationships: list,
+    ) -> list[str]:
+        """Generate sub-queries focused on entities."""
+        sub_queries = []
+
+        if not self._llm_fn:
+            # Simple decomposition: one query per entity
+            for entity in entities[:3]:
+                sub_queries.append(
+                    f"Find information about {entity.canonical_name}: {query}"
+                )
+            return sub_queries if sub_queries else [query]
+
+        # Use LLM for intelligent decomposition
+        try:
+            entity_names = [e.canonical_name for e in entities]
+            related_names = [e.canonical_name for e in related[:5]]
+
+            rel_descriptions = []
+            for rel in relationships[:10]:
+                rel_descriptions.append(
+                    f"- {rel.source_id} {rel.type.value} {rel.target_id}"
+                )
+
+            prompt = f"""Decompose this query into focused sub-queries based on the entities.
+
+Query: {query}
+
+Key entities found: {', '.join(entity_names)}
+Related entities: {', '.join(related_names)}
+
+Known relationships:
+{chr(10).join(rel_descriptions) if rel_descriptions else '(none)'}
+
+Generate 1-5 focused sub-queries. Each should target specific entities or relationships.
+
+Return as JSON:
+{{"sub_queries": ["query 1", "query 2"]}}
+
+JSON only:"""
+
+            response = self._llm_fn(prompt)
+            json_match = re.search(r'\{[\s\S]*?\}', response)
+            if json_match:
+                import json
+                result = json.loads(json_match.group())
+                sub_queries = result.get("sub_queries", [])
+
+        except Exception as e:
+            logger.debug("sub_query_generation_failed", error=str(e))
+
+        return sub_queries if sub_queries else [query]
+
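+    # Without an LLM the decomposition is mechanical; e.g. entities
+    # ["Acme Corp", "John Smith"] (hypothetical) yield:
+    #   "Find information about Acme Corp: <query>"
+    #   "Find information about John Smith: <query>"
+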
+    def _create_retrieval_plan(
+        self,
+        query: str,
+        entities: list,
+        entity_nodes: dict[str, list[str]],
+        relationships: list,
+    ) -> list[dict[str, Any]]:
+        """Create a retrieval plan based on entities."""
+        plan = []
+
+        # Priority 1: Nodes with direct entity mentions
+        priority_nodes = set()
+        for entity in entities:
+            nodes = entity_nodes.get(entity.id, [])
+            for node_id in nodes:
+                priority_nodes.add(node_id)
+                plan.append({
+                    "node_id": node_id,
+                    "priority": 1,
+                    "reason": f"Contains {entity.canonical_name}",
+                    "entity_id": entity.id,
+                })
+
+        # Priority 2: Nodes involved in relationships
+        for rel in relationships:
+            if rel.source_type == "node" and rel.source_id not in priority_nodes:
+                plan.append({
+                    "node_id": rel.source_id,
+                    "priority": 2,
+                    "reason": f"Related via {rel.type.value}",
+                    "relationship_id": rel.id,
+                })
+            if rel.target_type == "node" and rel.target_id not in priority_nodes:
+                plan.append({
+                    "node_id": rel.target_id,
+                    "priority": 2,
+                    "reason": f"Related via {rel.type.value}",
+                    "relationship_id": rel.id,
+                })
+
+        # Sort by priority
+        plan.sort(key=lambda x: x["priority"])
+
+        return plan
+
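+    # Note: list.sort is stable, so within each priority tier entries keep
+    # their discovery order (entity-mention nodes before relationship nodes).
+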
+
+
+# =============================================================================
+# Factory Function
+# =============================================================================
+
+
+def create_rlm_navigator(
+    skeleton: dict[str, SkeletonNode],
+    kv_store: KVStore,
+    config: RLMConfig | None = None,
+    knowledge_graph=None,
+) -> RLMNavigator:
+    """
+    Create an RLM Navigator instance.
+
+    Args:
+        skeleton: Skeleton index.
+        kv_store: KV store with full content.
+        config: Optional configuration.
+        knowledge_graph: Optional knowledge graph for entity-aware queries.
+
+    Returns:
+        Configured RLMNavigator.
+
+    Example:
+        from rnsr import ingest_document, build_skeleton_index
+        from rnsr.agent.rlm_navigator import create_rlm_navigator, RLMConfig
+        from rnsr.indexing.knowledge_graph import KnowledgeGraph
+
+        result = ingest_document("contract.pdf")
+        skeleton, kv_store = build_skeleton_index(result.tree)
+
+        # With knowledge graph for entity-aware queries
+        kg = KnowledgeGraph("./data/kg.db")
+
+        # With custom config
+        config = RLMConfig(
+            max_recursion_depth=3,
+            enable_pre_filtering=True,
+            enable_verification=True,
+        )
+
+        navigator = create_rlm_navigator(skeleton, kv_store, config, kg)
+        result = navigator.navigate("What are the liability terms?")
+        print(result["answer"])
+    """
+    nav = RLMNavigator(skeleton, kv_store, config, knowledge_graph)
+
+    # Configure LLM
+    try:
+        from rnsr.llm import get_llm
+        llm = get_llm()
+        nav.set_llm_function(lambda p: str(llm.complete(p)))
+    except Exception as e:
+        logger.warning("llm_config_failed", error=str(e))
+
+    return nav
+
+
+def run_rlm_navigator(
+    question: str,
+    skeleton: dict[str, SkeletonNode],
+    kv_store: KVStore,
+    config: RLMConfig | None = None,
+    metadata: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """
+    Run the RLM Navigator on a question.
+
+    Convenience function that creates and runs the navigator.
+
+    Args:
+        question: The user's question.
+        skeleton: Skeleton index.
+        kv_store: KV store.
+        config: Optional configuration.
+        metadata: Optional metadata.
+
+    Returns:
+        Dict with answer, confidence, trace.
+    """
+    navigator = create_rlm_navigator(skeleton, kv_store, config)
+    return navigator.navigate(question, metadata)
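+
+
+# Illustrative one-shot call (mirrors the create_rlm_navigator example):
+#     result = run_rlm_navigator("What are the liability terms?", skeleton, kv_store)
+#     print(result["answer"], result["confidence"])
+# Note: run_rlm_navigator builds the navigator without a knowledge graph;
+# use create_rlm_navigator directly for entity-aware decomposition.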