rnsr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnsr/__init__.py +118 -0
- rnsr/__main__.py +242 -0
- rnsr/agent/__init__.py +218 -0
- rnsr/agent/cross_doc_navigator.py +767 -0
- rnsr/agent/graph.py +1557 -0
- rnsr/agent/llm_cache.py +575 -0
- rnsr/agent/navigator_api.py +497 -0
- rnsr/agent/provenance.py +772 -0
- rnsr/agent/query_clarifier.py +617 -0
- rnsr/agent/reasoning_memory.py +736 -0
- rnsr/agent/repl_env.py +709 -0
- rnsr/agent/rlm_navigator.py +2108 -0
- rnsr/agent/self_reflection.py +602 -0
- rnsr/agent/variable_store.py +308 -0
- rnsr/benchmarks/__init__.py +118 -0
- rnsr/benchmarks/comprehensive_benchmark.py +733 -0
- rnsr/benchmarks/evaluation_suite.py +1210 -0
- rnsr/benchmarks/finance_bench.py +147 -0
- rnsr/benchmarks/pdf_merger.py +178 -0
- rnsr/benchmarks/performance.py +321 -0
- rnsr/benchmarks/quality.py +321 -0
- rnsr/benchmarks/runner.py +298 -0
- rnsr/benchmarks/standard_benchmarks.py +995 -0
- rnsr/client.py +560 -0
- rnsr/document_store.py +394 -0
- rnsr/exceptions.py +74 -0
- rnsr/extraction/__init__.py +172 -0
- rnsr/extraction/candidate_extractor.py +357 -0
- rnsr/extraction/entity_extractor.py +581 -0
- rnsr/extraction/entity_linker.py +825 -0
- rnsr/extraction/grounded_extractor.py +722 -0
- rnsr/extraction/learned_types.py +599 -0
- rnsr/extraction/models.py +232 -0
- rnsr/extraction/relationship_extractor.py +600 -0
- rnsr/extraction/relationship_patterns.py +511 -0
- rnsr/extraction/relationship_validator.py +392 -0
- rnsr/extraction/rlm_extractor.py +589 -0
- rnsr/extraction/rlm_unified_extractor.py +990 -0
- rnsr/extraction/tot_validator.py +610 -0
- rnsr/extraction/unified_extractor.py +342 -0
- rnsr/indexing/__init__.py +60 -0
- rnsr/indexing/knowledge_graph.py +1128 -0
- rnsr/indexing/kv_store.py +313 -0
- rnsr/indexing/persistence.py +323 -0
- rnsr/indexing/semantic_retriever.py +237 -0
- rnsr/indexing/semantic_search.py +320 -0
- rnsr/indexing/skeleton_index.py +395 -0
- rnsr/ingestion/__init__.py +161 -0
- rnsr/ingestion/chart_parser.py +569 -0
- rnsr/ingestion/document_boundary.py +662 -0
- rnsr/ingestion/font_histogram.py +334 -0
- rnsr/ingestion/header_classifier.py +595 -0
- rnsr/ingestion/hierarchical_cluster.py +515 -0
- rnsr/ingestion/layout_detector.py +356 -0
- rnsr/ingestion/layout_model.py +379 -0
- rnsr/ingestion/ocr_fallback.py +177 -0
- rnsr/ingestion/pipeline.py +936 -0
- rnsr/ingestion/semantic_fallback.py +417 -0
- rnsr/ingestion/table_parser.py +799 -0
- rnsr/ingestion/text_builder.py +460 -0
- rnsr/ingestion/tree_builder.py +402 -0
- rnsr/ingestion/vision_retrieval.py +965 -0
- rnsr/ingestion/xy_cut.py +555 -0
- rnsr/llm.py +733 -0
- rnsr/models.py +167 -0
- rnsr/py.typed +2 -0
- rnsr-0.1.0.dist-info/METADATA +592 -0
- rnsr-0.1.0.dist-info/RECORD +72 -0
- rnsr-0.1.0.dist-info/WHEEL +5 -0
- rnsr-0.1.0.dist-info/entry_points.txt +2 -0
- rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
- rnsr-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RLM Navigator - Recursive Language Model Navigator with Full REPL Integration
|
|
3
|
+
|
|
4
|
+
This module implements the full RLM (Recursive Language Model) pattern from the
|
|
5
|
+
arxiv paper "Recursive Language Models" combined with RNSR's tree-based retrieval.
|
|
6
|
+
|
|
7
|
+
Key Features:
|
|
8
|
+
1. Full REPL environment with code execution for document filtering
|
|
9
|
+
2. Pre-LLM filtering using regex/keyword search before ToT evaluation
|
|
10
|
+
3. Deep recursive sub-LLM calls (configurable depth)
|
|
11
|
+
4. Answer verification loops
|
|
12
|
+
5. Async parallel sub-LLM processing
|
|
13
|
+
6. Adaptive learning for stop words and query patterns
|
|
14
|
+
|
|
15
|
+
This is the state-of-the-art combination of:
|
|
16
|
+
- PageIndex: Vectorless, reasoning-based tree search
|
|
17
|
+
- RLMs: REPL environment with recursive sub-LLM calls
|
|
18
|
+
- RNSR: Latent hierarchy reconstruction + variable stitching
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import asyncio
|
|
24
|
+
import json
|
|
25
|
+
import os
|
|
26
|
+
import re
|
|
27
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
28
|
+
from dataclasses import dataclass, field
|
|
29
|
+
from datetime import datetime, timezone
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
from threading import Lock
|
|
32
|
+
from typing import Any, Callable, Literal
|
|
33
|
+
|
|
34
|
+
import structlog
|
|
35
|
+
|
|
36
|
+
from rnsr.agent.variable_store import VariableStore, generate_pointer_name
|
|
37
|
+
from rnsr.indexing.kv_store import KVStore
|
|
38
|
+
from rnsr.models import SkeletonNode, TraceEntry
|
|
39
|
+
|
|
40
|
+
# Module-level structlog logger shared by the registries and engines defined
# in this module (they emit structured events such as "stop_word_added").
logger = structlog.get_logger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# =============================================================================
|
|
44
|
+
# Learned Stop Words Registry
|
|
45
|
+
# =============================================================================
|
|
46
|
+
|
|
47
|
+
DEFAULT_STOP_WORDS_PATH = Path.home() / ".rnsr" / "learned_stop_words.json"


class LearnedStopWords:
    """
    Registry for learning domain-specific stop words.

    Learns:
    - Words that are generic in your domain (should be filtered)
    - Words that seem generic but are important in your domain (should be kept)

    Examples:
    - Legal: "hereby", "whereas" are filler (add_stop_word)
    - Legal: "shall" expresses an obligation (remove_stop_word)

    Only words outside BASE_STOP_WORDS can be added, and only words inside it
    can be removed. Every mutation, together with the JSON persistence it
    triggers, runs while holding ``self._lock``. This ensures ``json.dump`` never
    iterates a dict that another thread is modifying.
    """

    # Base stop words (always included unless explicitly removed)
    BASE_STOP_WORDS = {
        "what", "is", "the", "a", "an", "are", "was", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "will", "would",
        "could", "should", "may", "might", "must", "shall", "can", "need",
        "dare", "ought", "used", "to", "of", "in", "for", "on", "with", "at",
        "by", "from", "about", "into", "through", "during", "before", "after",
        "above", "below", "between", "under", "again", "further", "then",
        "once", "here", "there", "when", "where", "why", "how", "all", "each",
        "few", "more", "most", "other", "some", "such", "no", "nor", "not",
        "only", "own", "same", "so", "than", "too", "very", "just", "and",
        "but", "if", "or", "because", "as", "until", "while", "this", "that",
        "these", "those", "find", "show", "list", "describe", "explain", "tell",
    }

    def __init__(
        self,
        storage_path: Path | str | None = None,
        auto_save: bool = True,
    ):
        """
        Initialize the learned stop words registry.

        Args:
            storage_path: Path to JSON file for persistence. Falls back to
                DEFAULT_STOP_WORDS_PATH when empty or None.
            auto_save: Whether to save after every change. When False, call
                save() explicitly to persist.
        """
        self.storage_path = Path(storage_path) if storage_path else DEFAULT_STOP_WORDS_PATH
        self.auto_save = auto_save

        self._lock = Lock()
        # word -> {"count", "domain", "reason", "first_seen", "last_seen"}
        self._added_stop_words: dict[str, dict[str, Any]] = {}  # Domain-specific additions
        self._removed_stop_words: dict[str, dict[str, Any]] = {}  # Base words to keep
        self._dirty = False  # True when in-memory state has not been written yet

        self._load()

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string (timezone-aware)."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _well_formed_entries(raw: Any) -> dict[str, dict[str, Any]]:
        """Keep only ``word -> dict`` pairs from a section of the loaded JSON."""
        if not isinstance(raw, dict):
            return {}
        return {word: entry for word, entry in raw.items() if isinstance(entry, dict)}

    def _load(self) -> None:
        """
        Load learned stop words from storage.

        A missing file is not an error. An unreadable or malformed file is
        logged and ignored, leaving only the base stop words in effect.
        """
        if not self.storage_path.exists():
            return

        try:
            with open(self.storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            if not isinstance(data, dict):
                raise ValueError("stop word file must contain a JSON object")

            self._added_stop_words = self._well_formed_entries(data.get("added", {}))
            self._removed_stop_words = self._well_formed_entries(data.get("removed", {}))

            logger.info(
                "learned_stop_words_loaded",
                added=len(self._added_stop_words),
                removed=len(self._removed_stop_words),
            )

        except Exception as e:
            logger.warning("failed_to_load_stop_words", error=str(e))

    def _save(self) -> None:
        """
        Write the registry to ``storage_path`` if there are unsaved changes.

        The caller must hold ``self._lock``. Data is written to a sibling
        ``.tmp`` file and then swapped into place, so an interrupted write
        cannot leave a truncated JSON file behind. Failures are logged, not raised.
        """
        if not self._dirty:
            return

        try:
            self.storage_path.parent.mkdir(parents=True, exist_ok=True)

            data = {
                "version": "1.0",
                "updated_at": self._utc_now_iso(),
                "added": self._added_stop_words,
                "removed": self._removed_stop_words,
            }

            tmp_path = self.storage_path.with_name(self.storage_path.name + ".tmp")
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(self.storage_path)

            self._dirty = False

        except Exception as e:
            logger.warning("failed_to_save_stop_words", error=str(e))

    def save(self) -> None:
        """Persist pending changes now (useful when ``auto_save`` is False)."""
        with self._lock:
            self._save()

    def add_stop_word(
        self,
        word: str,
        domain: str = "general",
        reason: str = "",
    ) -> None:
        """
        Add a word to the stop word list, or bump its count if already added.

        Words that are empty after normalization, or already in
        BASE_STOP_WORDS, are ignored.

        Args:
            word: Word to add (lowercased and stripped).
            domain: Domain category (recorded only on first add).
            reason: Why this should be a stop word (recorded only on first add).
        """
        word = word.lower().strip()

        if not word or word in self.BASE_STOP_WORDS:
            return

        with self._lock:
            now = self._utc_now_iso()

            entry = self._added_stop_words.get(word)
            if entry is None:
                entry = {
                    "count": 0,
                    "domain": domain,
                    "reason": reason,
                    "first_seen": now,
                    "last_seen": now,
                }
                self._added_stop_words[word] = entry
                logger.info("stop_word_added", word=word)

            # .get() tolerates entries loaded from files that lack "count".
            entry["count"] = entry.get("count", 0) + 1
            entry["last_seen"] = now

            self._dirty = True

            if self.auto_save:
                self._save()

    def remove_stop_word(
        self,
        word: str,
        domain: str = "general",
        reason: str = "",
    ) -> None:
        """
        Mark a base stop word as important (should not be filtered).

        Words that are empty after normalization, or not in BASE_STOP_WORDS,
        are ignored.

        Args:
            word: Word to keep (lowercased and stripped).
            domain: Domain where this is important (recorded only on first mark).
            reason: Why this should be kept (recorded only on first mark).
        """
        word = word.lower().strip()

        if not word or word not in self.BASE_STOP_WORDS:
            return

        with self._lock:
            now = self._utc_now_iso()

            entry = self._removed_stop_words.get(word)
            if entry is None:
                entry = {
                    "count": 0,
                    "domain": domain,
                    "reason": reason,
                    "first_seen": now,
                    "last_seen": now,
                }
                self._removed_stop_words[word] = entry
                logger.info("stop_word_marked_important", word=word)

            entry["count"] = entry.get("count", 0) + 1
            entry["last_seen"] = now

            self._dirty = True

            if self.auto_save:
                self._save()

    def get_stop_words(self, min_count: int = 1) -> set[str]:
        """
        Get the effective stop word set.

        Args:
            min_count: Minimum number of times an addition/removal must have
                been recorded before it takes effect.

        Returns:
            A new set: base + added - removed.
        """
        with self._lock:
            result = set(self.BASE_STOP_WORDS)

            for word, data in self._added_stop_words.items():
                if data.get("count", 0) >= min_count:
                    result.add(word)

            for word, data in self._removed_stop_words.items():
                if data.get("count", 0) >= min_count:
                    result.discard(word)

        return result

    def get_stats(self) -> dict[str, Any]:
        """Get counts of base, added, removed and effective stop words."""
        with self._lock:
            added_count = len(self._added_stop_words)
            removed_count = len(self._removed_stop_words)

        # get_stop_words() takes the (non-reentrant) lock itself.
        return {
            "base_count": len(self.BASE_STOP_WORDS),
            "added_count": added_count,
            "removed_count": removed_count,
            "effective_count": len(self.get_stop_words()),
        }
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# Global stop words registry
|
|
258
|
+
_global_stop_words: LearnedStopWords | None = None


def get_learned_stop_words() -> LearnedStopWords:
    """
    Return the process-wide learned stop words registry, creating it lazily.

    The first call builds the registry. A non-empty ``RNSR_STOP_WORDS_PATH``
    environment variable overrides the storage path; otherwise the
    registry's own default path is used.
    """
    global _global_stop_words

    if _global_stop_words is not None:
        return _global_stop_words

    override_path = os.getenv("RNSR_STOP_WORDS_PATH") or None
    _global_stop_words = LearnedStopWords(storage_path=override_path)
    return _global_stop_words
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# =============================================================================
|
|
275
|
+
# Learned Query Patterns Registry
|
|
276
|
+
# =============================================================================
|
|
277
|
+
|
|
278
|
+
DEFAULT_QUERY_PATTERNS_PATH = Path.home() / ".rnsr" / "learned_query_patterns.json"


class LearnedQueryPatterns:
    """
    Registry for learning successful query patterns.

    Tracks:
    - Query patterns that lead to high-confidence answers
    - Patterns that need decomposition vs. direct retrieval
    - Entity-focused vs. section-focused queries

    Used to:
    - Inform decomposition strategy
    - Adjust confidence thresholds
    - Route to specialized handlers

    All reads and writes of the pattern table happen under ``self._lock``
    (a non-reentrant ``threading.Lock``). Lock-holding code therefore calls
    the ``*_unlocked`` helpers rather than the public locking methods.
    """

    # Keyword rules used by detect_pattern_type(), checked in order; the first
    # match wins. Keywords must be whole words, optionally followed by a simple
    # inflection (s/es/d/ed). Otherwise "update" would count as "date",
    # "fall"/"overall" as "all", and "because" as "cause".
    _PATTERN_RULES: tuple[tuple[str, re.Pattern[str]], ...] = (
        ("comparison", re.compile(r"\b(?:compare|difference|versus|vs)(?:s|es|d|ed)?\b")),
        ("enumeration", re.compile(r"\b(?:list|all|every|enumerate)(?:s|es|d|ed)?\b")),
        ("temporal", re.compile(r"\b(?:when|date|time|timeline)(?:s|es|d|ed)?\b")),
        ("entity_person", re.compile(r"\b(?:who|whom|whose|person|name)(?:s|es|d|ed)?\b")),
        (
            "entity_organization",
            re.compile(r"\b(?:company|companies|organization|entity|entities)(?:s|es|d|ed)?\b"),
        ),
        ("monetary", re.compile(r"\$|\b(?:how much|amount|price|cost)(?:s|es|d|ed)?\b")),
        ("section_lookup", re.compile(r"\b(?:section|clause|paragraph|article)(?:s|es|d|ed)?\b")),
        ("definition", re.compile(r"\b(?:what is|define|explain|describe)(?:s|es|d|ed)?\b")),
        ("causal", re.compile(r"\b(?:why|reason|cause)(?:s|es|d|ed)?\b")),
    )

    def __init__(
        self,
        storage_path: Path | str | None = None,
        auto_save: bool = True,
    ):
        """
        Initialize the query patterns registry.

        Args:
            storage_path: Path to JSON file for persistence. Falls back to
                DEFAULT_QUERY_PATTERNS_PATH when empty or None.
            auto_save: Whether to save after every change. When False, call
                save() explicitly to persist.
        """
        self.storage_path = Path(storage_path) if storage_path else DEFAULT_QUERY_PATTERNS_PATH
        self.auto_save = auto_save

        self._lock = Lock()
        # pattern_type -> aggregate record (see _new_pattern_record)
        self._patterns: dict[str, dict[str, Any]] = {}
        self._dirty = False  # True when in-memory state has not been written yet

        self._load()

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string (timezone-aware)."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _new_pattern_record(now: str) -> dict[str, Any]:
        """Return a fresh, zeroed aggregate record for one pattern type."""
        return {
            "total_queries": 0,
            "successful_queries": 0,
            "total_confidence": 0.0,
            "decomposition_count": 0,
            "total_sub_questions": 0,
            "entity_types": {},
            "first_seen": now,
            "last_seen": now,
            "example_queries": [],
        }

    def _load(self) -> None:
        """
        Load learned patterns from storage.

        A missing file is not an error. An unreadable or malformed file is
        logged and ignored, leaving the registry empty.
        """
        if not self.storage_path.exists():
            return

        try:
            with open(self.storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            if not isinstance(data, dict):
                raise ValueError("query pattern file must contain a JSON object")

            patterns = data.get("patterns", {})
            if not isinstance(patterns, dict):
                patterns = {}
            self._patterns = {
                name: record for name, record in patterns.items() if isinstance(record, dict)
            }

            logger.info(
                "query_patterns_loaded",
                patterns=len(self._patterns),
            )

        except Exception as e:
            logger.warning("failed_to_load_query_patterns", error=str(e))

    def _save(self) -> None:
        """
        Write the registry to ``storage_path`` if there are unsaved changes.

        The caller must hold ``self._lock``. Data is written to a sibling
        ``.tmp`` file and then swapped into place, so an interrupted write
        cannot leave a truncated JSON file. Failures are logged, not raised.
        """
        if not self._dirty:
            return

        try:
            self.storage_path.parent.mkdir(parents=True, exist_ok=True)

            data = {
                "version": "1.0",
                "updated_at": self._utc_now_iso(),
                "patterns": self._patterns,
            }

            tmp_path = self.storage_path.with_name(self.storage_path.name + ".tmp")
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(self.storage_path)

            self._dirty = False

        except Exception as e:
            logger.warning("failed_to_save_query_patterns", error=str(e))

    def save(self) -> None:
        """Persist pending changes now (useful when ``auto_save`` is False)."""
        with self._lock:
            self._save()

    def record_query(
        self,
        query: str,
        pattern_type: str,
        success: bool,
        confidence: float,
        needed_decomposition: bool,
        sub_questions_count: int = 0,
        entities_involved: list[str] | None = None,
    ) -> None:
        """
        Record a query and its outcome.

        Args:
            query: The original query (only the first 200 chars are stored,
                and only for the first 5 examples of each pattern).
            pattern_type: Detected pattern type (entity_lookup, comparison, etc.);
                lowercased and stripped.
            success: Whether the query was answered successfully.
            confidence: Answer confidence score.
            needed_decomposition: Whether decomposition was required.
            sub_questions_count: Number of sub-questions generated (counted
                only when needed_decomposition is True).
            entities_involved: Entity types involved in the query.
        """
        pattern_type = pattern_type.lower().strip()

        with self._lock:
            now = self._utc_now_iso()

            pt = self._patterns.get(pattern_type)
            if pt is None:
                pt = self._new_pattern_record(now)
                self._patterns[pattern_type] = pt
                logger.info("new_query_pattern_discovered", pattern_type=pattern_type)
            else:
                # Backfill keys missing from older or hand-edited files.
                for key, default in self._new_pattern_record(now).items():
                    pt.setdefault(key, default)

            pt["total_queries"] += 1
            pt["total_confidence"] += confidence
            pt["last_seen"] = now

            if success:
                pt["successful_queries"] += 1

            if needed_decomposition:
                pt["decomposition_count"] += 1
                pt["total_sub_questions"] += sub_questions_count

            for entity_type in entities_involved or ():
                pt["entity_types"][entity_type] = pt["entity_types"].get(entity_type, 0) + 1

            if len(pt["example_queries"]) < 5:
                pt["example_queries"].append({
                    "query": query[:200],
                    "success": success,
                    "confidence": confidence,
                    "timestamp": now,
                })

            self._dirty = True

            if self.auto_save:
                self._save()

    def detect_pattern_type(self, query: str) -> str:
        """
        Detect the pattern type of a query with keyword heuristics.

        Args:
            query: The query to analyze.

        Returns:
            The first matching type from _PATTERN_RULES, or "general".
        """
        query_lower = query.lower()

        for pattern_type, rule in self._PATTERN_RULES:
            if rule.search(query_lower):
                return pattern_type

        return "general"

    def _pattern_stats_unlocked(self, pattern_type: str) -> dict[str, Any] | None:
        """Compute stats for an already-normalized type; caller holds the lock."""
        pt = self._patterns.get(pattern_type)
        if pt is None:
            return None

        total = pt.get("total_queries", 0)
        decompositions = pt.get("decomposition_count", 0)

        return {
            "pattern_type": pattern_type,
            "total_queries": total,
            "success_rate": pt.get("successful_queries", 0) / total if total > 0 else 0,
            "avg_confidence": pt.get("total_confidence", 0.0) / total if total > 0 else 0,
            "decomposition_rate": decompositions / total if total > 0 else 0,
            "avg_sub_questions": (
                pt.get("total_sub_questions", 0) / decompositions if decompositions > 0 else 0
            ),
            "top_entity_types": sorted(
                pt.get("entity_types", {}).items(),
                key=lambda x: -x[1],
            )[:5],
        }

    def get_pattern_stats(self, pattern_type: str) -> dict[str, Any] | None:
        """
        Get statistics for a pattern type.

        Args:
            pattern_type: The pattern type to look up (lowercased and stripped).

        Returns:
            Pattern statistics or None if not found.
        """
        pattern_type = pattern_type.lower().strip()

        with self._lock:
            return self._pattern_stats_unlocked(pattern_type)

    def should_decompose(self, pattern_type: str) -> bool:
        """
        Determine if a pattern type typically needs decomposition.

        Args:
            pattern_type: The pattern type.

        Returns:
            True if the type was decomposed in more than half of its recorded
            queries. For unseen types, True for comparison/enumeration/temporal.
        """
        stats = self.get_pattern_stats(pattern_type)

        if not stats:
            # Default recommendations for unknown patterns
            always_decompose = {"comparison", "enumeration", "temporal"}
            return pattern_type.lower().strip() in always_decompose

        # Recommend decomposition if historically needed > 50% of the time
        return stats["decomposition_rate"] > 0.5

    def get_confidence_threshold(self, pattern_type: str) -> float:
        """
        Get recommended confidence threshold for a pattern type.

        Args:
            pattern_type: The pattern type.

        Returns:
            0.7 when fewer than 5 queries are recorded. Otherwise the average
            confidence minus 0.1, clamped to [0.5, 0.9].
        """
        stats = self.get_pattern_stats(pattern_type)

        if not stats or stats["total_queries"] < 5:
            return 0.7  # Default threshold

        # Fixed 0.1 margin below the historical average confidence.
        avg_conf = stats["avg_confidence"]
        return max(0.5, min(0.9, avg_conf - 0.1))

    def get_all_patterns(self) -> list[dict[str, Any]]:
        """Get statistics for all known patterns, most-queried first."""
        # Use the unlocked helper: calling get_pattern_stats() here would try
        # to re-acquire the non-reentrant lock and deadlock.
        with self._lock:
            results = [
                stats
                for stats in (self._pattern_stats_unlocked(name) for name in self._patterns)
                if stats
            ]

        return sorted(results, key=lambda x: -x["total_queries"])
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
# Global query patterns registry
|
|
555
|
+
_global_query_patterns: LearnedQueryPatterns | None = None


def get_learned_query_patterns() -> LearnedQueryPatterns:
    """
    Return the process-wide learned query patterns registry, creating it lazily.

    The first call builds the registry. A non-empty
    ``RNSR_QUERY_PATTERNS_PATH`` environment variable overrides the storage
    path; otherwise the registry's own default path is used.
    """
    global _global_query_patterns

    if _global_query_patterns is not None:
        return _global_query_patterns

    override_path = os.getenv("RNSR_QUERY_PATTERNS_PATH") or None
    _global_query_patterns = LearnedQueryPatterns(storage_path=override_path)
    return _global_query_patterns
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
# =============================================================================
|
|
572
|
+
# RLM Configuration
|
|
573
|
+
# =============================================================================
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
@dataclass
class RLMConfig:
    """
    Tunable settings for the RLM Navigator.

    Recursion control:
        max_recursion_depth: Depth at which sub-LLM calls stop being allowed
            to spawn further sub-calls.
        max_iterations: Upper bound on navigation iterations.

    Tree of Thoughts:
        top_k: Base number of children to explore.
        selection_threshold: Minimum probability for a child to be selected.
        dead_end_threshold: Probability threshold for declaring a dead end.

    Pre-filtering:
        enable_pre_filtering: Run regex/keyword filtering before ToT evaluation.
        pre_filter_min_matches: Keyword hits a node needs to pass the filter.

    REPL execution:
        enable_code_execution: Let the LLM write and execute code.
        max_code_execution_time: Execution time limit, in seconds.

    Answer verification:
        enable_verification: Verify answers with a sub-LLM.
        verification_retries: Maximum number of verification attempts.

    Async processing:
        enable_async: Use async for parallel sub-LLM calls.
        max_concurrent_calls: Maximum number of LLM calls in flight at once.

    Vision mode:
        enable_vision: Use a vision LLM on page images.
        vision_model: Name of the vision model to use.
    """

    max_recursion_depth: int = 3
    max_iterations: int = 30

    top_k: int = 3
    selection_threshold: float = 0.4
    dead_end_threshold: float = 0.1

    enable_pre_filtering: bool = True
    pre_filter_min_matches: int = 1

    enable_code_execution: bool = True
    max_code_execution_time: int = 30

    enable_verification: bool = True
    verification_retries: int = 2

    enable_async: bool = True
    max_concurrent_calls: int = 5

    enable_vision: bool = False
    vision_model: str = "gemini-2.5-flash"
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
# =============================================================================
|
|
611
|
+
# Pre-Filtering Engine (Before ToT Evaluation)
|
|
612
|
+
# =============================================================================
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
class PreFilterEngine:
    """
    Pre-filters nodes before expensive ToT LLM evaluation.

    Implements the key RLM insight: use code (regex, keywords) to filter
    before sending to LLM. This dramatically reduces LLM calls.

    Uses adaptive stop words that learn from domain-specific usage.

    Example:
        # Query: "What is the liability clause?"
        # Instead of evaluating all 50 children with LLM:
        # 1. Extract keywords: ["liability", "clause"]
        # 2. Keyword-match children headers/summaries
        # 3. Only send matching children to ToT evaluation
    """

    def __init__(self, config: RLMConfig, enable_stop_word_learning: bool = True):
        self.config = config
        # Per-instance memo of query -> keywords. Entries are not invalidated
        # if learned stop words change afterwards.
        self._keyword_cache: dict[str, list[str]] = {}
        self._stop_word_registry = get_learned_stop_words() if enable_stop_word_learning else None

    def extract_keywords(self, query: str) -> list[str]:
        """
        Extract searchable, lowercased keywords from a query.

        Keywords are:
        - words of 3+ ASCII letters that are not stop words;
        - whole "quoted phrases";
        - capitalized words (likely proper nouns) that are not stop words.

        Duplicates are dropped and first-seen order is kept.

        Returns:
            A new list on every call, so callers may mutate it freely.
        """
        cached = self._keyword_cache.get(query)
        if cached is not None:
            return list(cached)

        # Get stop words (base + learned)
        if self._stop_word_registry is not None:
            stop_words = self._stop_word_registry.get_stop_words()
        else:
            stop_words = LearnedStopWords.BASE_STOP_WORDS

        # Tokenize and filter
        words = re.findall(r'\b[a-zA-Z]{3,}\b', query.lower())
        keywords = [w for w in words if w not in stop_words]

        # Quoted phrases become single keywords. They are lowercased because
        # filter_nodes_by_keywords() matches against lowercased text.
        for phrase in re.findall(r'"([^"]+)"', query):
            phrase = phrase.strip().lower()
            if phrase:
                keywords.append(phrase)

        # Capitalized words (likely proper nouns). Sentence-initial words such
        # as "What" are capitalized too, so apply the stop-word filter here.
        for proper_noun in re.findall(r'\b[A-Z][a-z]+\b', query):
            lowered = proper_noun.lower()
            if lowered not in stop_words:
                keywords.append(lowered)

        # Deduplicate while preserving order
        unique_keywords = list(dict.fromkeys(keywords))

        self._keyword_cache[query] = unique_keywords
        logger.debug("keywords_extracted", query=query[:50], keywords=unique_keywords)
        return list(unique_keywords)

    def filter_nodes_by_keywords(
        self,
        nodes: list[SkeletonNode],
        keywords: list[str],
        min_matches: int | None = None,
    ) -> tuple[list[SkeletonNode], list[SkeletonNode]]:
        """
        Filter nodes by case-insensitive substring matching of keywords.

        Each node's ``header`` and ``summary`` are searched; a missing (None)
        value is treated as empty text.

        Args:
            nodes: Candidate nodes.
            keywords: Keywords to look for (any case).
            min_matches: Distinct keywords a node must contain to match.
                None means use ``config.pre_filter_min_matches``; 0 is honored
                and lets every node through.

        Returns:
            Tuple of (matching_nodes, remaining_nodes). When pre-filtering is
            disabled or no keywords are given, returns (nodes, []).
        """
        if not self.config.enable_pre_filtering or not keywords:
            return nodes, []

        if min_matches is None:
            min_matches = self.config.pre_filter_min_matches

        lowered_keywords = [kw.lower() for kw in keywords]

        matching: list[SkeletonNode] = []
        remaining: list[SkeletonNode] = []

        for node in nodes:
            search_text = f"{node.header or ''} {node.summary or ''}".lower()
            matches = sum(1 for kw in lowered_keywords if kw in search_text)

            if matches >= min_matches:
                matching.append(node)
            else:
                remaining.append(node)

        logger.debug(
            "pre_filter_complete",
            total=len(nodes),
            matching=len(matching),
            remaining=len(remaining),
            keywords=keywords[:5],
        )

        return matching, remaining

    def regex_search_nodes(
        self,
        nodes: list[SkeletonNode],
        pattern: str,
    ) -> list[tuple[SkeletonNode, list[str]]]:
        """
        Search node headers and summaries with a case-insensitive regex.

        Args:
            nodes: Nodes to search.
            pattern: Regular expression. An invalid pattern is logged and
                yields an empty result.

        Returns:
            List of (node, matches) tuples for nodes with at least one match,
            where ``matches`` is what ``re.findall`` returns.
        """
        results: list[tuple[SkeletonNode, list[str]]] = []

        try:
            regex = re.compile(pattern, re.IGNORECASE)
        except re.error as e:
            logger.warning("invalid_regex_pattern", pattern=pattern, error=str(e))
            return results

        for node in nodes:
            search_text = f"{node.header or ''}\n{node.summary or ''}"
            matches = regex.findall(search_text)
            if matches:
                results.append((node, matches))

        return results
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
# =============================================================================
|
|
745
|
+
# Deep Recursive Sub-LLM Engine
|
|
746
|
+
# =============================================================================
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
class RecursiveSubLLMEngine:
    """
    Enables true multi-level recursive sub-LLM calls.

    Unlike single-level decomposition, a sub-LLM may reply with
    ``SUB_TASK[n]: <description>`` lines. Each sub-task is answered by a
    further call one level deeper, with the parent's context passed down.
    The sub-results are then synthesized into a single answer by one more
    LLM call. Decomposition is disabled once ``config.max_recursion_depth``
    is reached.

    Example:
        Query: "Compare the liability clauses in 2023 vs 2024 contracts"

        Depth 0 (Root): Decompose into sub-tasks
        ├── Depth 1: "Find 2023 liability clause"
        │   └── Depth 2: "Extract specific terms"
        └── Depth 1: "Find 2024 liability clause"
            └── Depth 2: "Extract specific terms"
    """

    def __init__(
        self,
        config: RLMConfig,
        llm_fn: Callable[[str], str] | None = None,
    ):
        import threading

        self.config = config
        self._llm_fn = llm_fn
        self._call_count = 0
        self._depth_stats: dict[int, int] = {}
        # batch_recursive_calls runs recursive_call on worker threads, so the
        # read-modify-write updates of the counters above must be serialized.
        self._stats_lock = threading.Lock()

    def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
        """Set the LLM function for sub-calls."""
        self._llm_fn = llm_fn

    def recursive_call(
        self,
        prompt: str,
        context: str,
        depth: int = 0,
        allow_sub_calls: bool = True,
    ) -> str:
        """
        Execute a recursive LLM call.

        Args:
            prompt: The task/question for the LLM.
            context: Context to process. It is also handed down to any
                sub-tasks the LLM declares.
            depth: Current recursion depth.
            allow_sub_calls: Whether this call can spawn sub-calls. Forced to
                False once ``depth >= config.max_recursion_depth``.

        Returns:
            The LLM response, or the synthesized answer if sub-tasks were
            declared. If no LLM function is configured, or the LLM call
            raises, an ``"[ERROR: ...]"`` string is returned instead.
        """
        if self._llm_fn is None:
            return "[ERROR: LLM function not configured]"

        if depth >= self.config.max_recursion_depth:
            allow_sub_calls = False
            logger.debug("max_recursion_depth_reached", depth=depth)

        # Track stats
        with self._stats_lock:
            self._call_count += 1
            self._depth_stats[depth] = self._depth_stats.get(depth, 0) + 1

        # Build the prompt with recursion capability
        if allow_sub_calls:
            system_instruction = f"""You are a sub-LLM at recursion depth {depth}.
You can decompose complex tasks into sub-tasks.
If you need to process multiple items independently, list them as:
SUB_TASK[1]: <task description>
SUB_TASK[2]: <task description>
...
These will be processed by sub-LLMs and results aggregated.
"""
        else:
            system_instruction = f"""You are a sub-LLM at max recursion depth {depth}.
Provide a direct answer without further decomposition."""

        full_prompt = f"""{system_instruction}

Task: {prompt}

Context:
{context}

Response:"""

        try:
            response = self._llm_fn(full_prompt)

            # Check for sub-task declarations and process them
            if allow_sub_calls and "SUB_TASK[" in response:
                response = self._process_sub_tasks(
                    response,
                    depth + 1,
                    context=context,
                    parent_prompt=prompt,
                )

            return response

        except Exception as e:
            logger.error("recursive_call_failed", depth=depth, error=str(e))
            return f"[ERROR: {str(e)}]"

    def _process_sub_tasks(
        self,
        response: str,
        depth: int,
        context: str = "(inherited from parent)",
        parent_prompt: str | None = None,
    ) -> str:
        """
        Answer the SUB_TASK declarations in ``response`` and synthesize them.

        Args:
            response: LLM output containing ``SUB_TASK[n]: ...`` lines.
            depth: Depth at which the sub-tasks are executed.
            context: Context given to every sub-task. Callers should pass the
                parent's context; the default keeps the legacy placeholder.
            parent_prompt: The task that produced ``response``. It is used as
                the "Original task" in the synthesis prompt. When omitted, the
                text before the first SUB_TASK marker is used instead.

        Returns:
            The synthesized answer. If there is no LLM function, the joined
            per-task results are returned. If no sub-task could be parsed,
            ``response`` is returned unchanged.
        """
        # Extract sub-tasks
        sub_tasks = re.findall(r'SUB_TASK\[(\d+)\]:\s*(.+?)(?=SUB_TASK\[|$)', response, re.DOTALL)

        if not sub_tasks:
            return response

        logger.debug("processing_sub_tasks", count=len(sub_tasks), depth=depth)

        # Process each sub-task recursively
        results = []
        for task_num, task_desc in sub_tasks:
            result = self.recursive_call(
                prompt=task_desc.strip(),
                context=context,
                depth=depth,
                allow_sub_calls=(depth < self.config.max_recursion_depth),
            )
            results.append(f"Result[{task_num}]: {result}")

        if parent_prompt is None:
            original_task = response.split('SUB_TASK[')[0].strip()
        else:
            original_task = parent_prompt

        # Synthesize results
        synthesis_prompt = f"""Synthesize the following sub-task results into a coherent answer:

{chr(10).join(results)}

Original task: {original_task}

Synthesized answer:"""

        return self._llm_fn(synthesis_prompt) if self._llm_fn else "\n".join(results)

    async def async_recursive_call(
        self,
        prompt: str,
        context: str,
        depth: int = 0,
    ) -> str:
        """Async version of recursive_call; runs it in the default executor."""
        # get_running_loop() is the supported way to reach the loop from a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.recursive_call, prompt, context, depth)

    def batch_recursive_calls(
        self,
        prompts: list[str],
        contexts: list[str],
        depth: int = 0,
        timeout: float | None = 60.0,
    ) -> list[str]:
        """
        Execute multiple recursive calls in parallel.

        Uses a ThreadPoolExecutor sized by ``config.max_concurrent_calls``.

        Args:
            prompts: One task per call.
            contexts: Context for each prompt; must be the same length.
            depth: Recursion depth for every call.
            timeout: Seconds to wait for each individual result (None = no
                limit). Leaving the executor still waits for all workers,
                so this does not bound total wall-clock time.

        Returns:
            Results in input order. A call that fails or times out yields an
            ``"[ERROR: ...]"`` string.

        Raises:
            ValueError: If ``prompts`` and ``contexts`` differ in length.
        """
        if len(prompts) != len(contexts):
            raise ValueError("prompts and contexts must have same length")

        if not prompts:
            return []

        results: list[str] = [""] * len(prompts)

        with ThreadPoolExecutor(max_workers=self.config.max_concurrent_calls) as executor:
            futures = {
                executor.submit(self.recursive_call, prompt, context, depth): idx
                for idx, (prompt, context) in enumerate(zip(prompts, contexts))
            }

            for future, idx in futures.items():
                try:
                    results[idx] = future.result(timeout=timeout)
                except Exception as e:
                    # TimeoutError has an empty message; fall back to its type name.
                    results[idx] = f"[ERROR: {str(e) or type(e).__name__}]"

        return results

    def get_stats(self) -> dict[str, Any]:
        """Return the total call count and the number of calls per depth."""
        with self._stats_lock:
            return {
                "total_calls": self._call_count,
                "calls_by_depth": dict(self._depth_stats),
            }
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
# =============================================================================
|
|
940
|
+
# Answer Verification Engine
|
|
941
|
+
# =============================================================================
|
|
942
|
+
|
|
943
|
+
|
|
944
|
+
class AnswerVerificationEngine:
    """
    Verifies answers using sub-LLM calls.

    Implements the RLM pattern of asking a sub-LLM to fact-check a proposed
    answer against the evidence before returning it. An answer judged
    invalid can be re-verified with the LLM's improved answer, up to
    ``config.verification_retries`` times.
    """

    def __init__(
        self,
        config: RLMConfig,
        llm_fn: Callable[[str], str] | None = None,
    ):
        self.config = config
        self._llm_fn = llm_fn

    def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
        """Set the LLM function."""
        self._llm_fn = llm_fn

    @staticmethod
    def _normalize_result(result: dict[str, Any], proposed_answer: str) -> dict[str, Any]:
        """
        Coerce an LLM verification verdict into the documented result shape.

        Guarantees the four documented keys:
        - ``improved_answer`` falls back to ``proposed_answer`` when the LLM
          returns null or empty (the prompt asks for null when valid).
        - ``confidence`` becomes a float clamped to [0, 1]; 0.5 if missing
          or unparsable.
        - ``issues`` becomes a list.

        Any extra keys the LLM returned are kept.
        """
        normalized = dict(result)
        normalized["is_valid"] = bool(result.get("is_valid", True))

        try:
            confidence = float(result.get("confidence", 0.5))
        except (TypeError, ValueError):
            confidence = 0.5
        normalized["confidence"] = min(max(confidence, 0.0), 1.0)

        issues = result.get("issues") or []
        if isinstance(issues, str):
            issues = [issues]
        normalized["issues"] = list(issues)

        normalized["improved_answer"] = result.get("improved_answer") or proposed_answer
        return normalized

    def verify_answer(
        self,
        question: str,
        proposed_answer: str,
        evidence: list[str],
        attempt: int = 0,
    ) -> dict[str, Any]:
        """
        Verify an answer using sub-LLM evaluation.

        Args:
            question: The question being answered.
            proposed_answer: Candidate answer to check.
            evidence: Evidence passages, joined with ``---`` separators.
            attempt: Retry counter used by the recursive re-verification.

        Returns:
            Dict with 'is_valid', 'confidence', 'issues', 'improved_answer'.
            'improved_answer' is always a usable answer: the corrected one,
            or ``proposed_answer`` when no correction was given. If
            verification is disabled, no LLM is configured, or the response
            cannot be parsed, a permissive result (``is_valid=True``) with a
            reduced confidence is returned.
        """
        if not self.config.enable_verification:
            return {
                "is_valid": True,
                "confidence": 0.7,
                "issues": [],
                "improved_answer": proposed_answer,
            }

        if self._llm_fn is None:
            return {
                "is_valid": True,
                "confidence": 0.5,
                "issues": ["LLM not configured for verification"],
                "improved_answer": proposed_answer,
            }

        evidence_text = "\n---\n".join(evidence) if evidence else "(no evidence provided)"

        verification_prompt = f"""You are a fact-checker verifying an answer.

Question: {question}

Proposed Answer: {proposed_answer}

Evidence:
{evidence_text}

VERIFICATION TASK:
1. Check if the answer is supported by the evidence
2. Check if the answer directly addresses the question
3. Check for any factual errors or unsupported claims

OUTPUT FORMAT (JSON):
{{
"is_valid": true/false,
"confidence": 0.0-1.0,
"issues": ["list of issues if any"],
"improved_answer": "corrected answer if needed, or null if valid"
}}

Respond ONLY with JSON:"""

        try:
            import json

            response = self._llm_fn(verification_prompt)

            # Parse the outermost {...} span; LLMs often wrap JSON in prose.
            json_match = re.search(r'\{[\s\S]*\}', response)
            if json_match:
                result = json.loads(json_match.group())
                if not isinstance(result, dict):
                    raise ValueError("verification response is not a JSON object")

                # If not valid and we have retries left, try to improve
                if not result.get("is_valid", True) and attempt < self.config.verification_retries:
                    logger.debug(
                        "answer_verification_failed",
                        attempt=attempt,
                        issues=result.get("issues", []),
                    )

                    # Re-verify the LLM's corrected answer, if it offered one.
                    if result.get("improved_answer"):
                        return self.verify_answer(
                            question,
                            result["improved_answer"],
                            evidence,
                            attempt + 1,
                        )

                return self._normalize_result(result, proposed_answer)
            else:
                logger.warning("verification_json_parse_failed", response=response[:200])
                return {
                    "is_valid": True,
                    "confidence": 0.5,
                    "issues": ["Could not parse verification response"],
                    "improved_answer": proposed_answer,
                }

        except Exception as e:
            logger.error("verification_failed", error=str(e))
            return {
                "is_valid": True,
                "confidence": 0.3,
                "issues": [str(e)],
                "improved_answer": proposed_answer,
            }
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
# =============================================================================
|
|
1067
|
+
# Enhanced RLM Navigator Agent State
|
|
1068
|
+
# =============================================================================
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
class RLMAgentState:
    """
    State for the RLM Navigator Agent.

    Extends the base AgentState with RLM-specific fields: pre-filtering
    results, sub-question decomposition, recursion counters and the
    verification result.
    """

    def __init__(
        self,
        question: str,
        root_node_id: str,
        config: RLMConfig | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        self.question = question
        self.config = config or RLMConfig()
        # Shallow-copy so later writes (e.g. entity info stored during
        # decomposition) never leak back into the caller's dict.
        self.metadata: dict[str, Any] = dict(metadata) if metadata else {}

        # Navigation state
        self.current_node_id: str | None = root_node_id
        self.visited_nodes: list[str] = []
        self.navigation_path: list[str] = [root_node_id]
        self.nodes_to_visit: list[str] = []  # FIFO queue of node ids
        self.dead_ends: list[str] = []
        self.backtrack_stack: list[str] = []

        # Variable stitching
        self.variables: list[str] = []
        self.context: str = ""

        # Sub-questions (RLM decomposition)
        self.sub_questions: list[str] = []
        self.pending_questions: list[str] = []
        self.current_sub_question: str | None = None

        # Pre-filtering state
        self.extracted_keywords: list[str] = []
        self.pre_filtered_nodes: dict[str, list[str]] = {}  # node_id -> matched keywords

        # Recursion tracking
        self.current_recursion_depth: int = 0
        self.recursion_call_count: int = 0

        # Output
        self.answer: str | None = None
        self.confidence: float = 0.0
        self.verification_result: dict[str, Any] | None = None

        # Traceability
        self.trace: list[dict[str, Any]] = []
        self.iteration: int = 0

    def add_trace(
        self,
        node_type: str,
        action: str,
        details: dict | None = None,
    ) -> None:
        """Append a UTC-timestamped trace entry tagged with the current iteration."""
        entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "node_type": node_type,
            "action": action,
            "details": details or {},
            "iteration": self.iteration,
        }
        self.trace.append(entry)

    def to_dict(self) -> dict[str, Any]:
        """Convert the externally relevant parts of the state to a dictionary."""
        return {
            "question": self.question,
            "answer": self.answer,
            "confidence": self.confidence,
            "variables": self.variables,
            "visited_nodes": self.visited_nodes,
            "iteration": self.iteration,
            "recursion_call_count": self.recursion_call_count,
            "verification_result": self.verification_result,
            "trace": self.trace,
        }
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
# =============================================================================
|
|
1155
|
+
# RLM Navigator - Main Class
|
|
1156
|
+
# =============================================================================
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
class RLMNavigator:
|
|
1160
|
+
"""
|
|
1161
|
+
The RLM Navigator combines:
|
|
1162
|
+
1. PageIndex-style tree search with reasoning
|
|
1163
|
+
2. RLM-style REPL environment with code execution
|
|
1164
|
+
3. RNSR-style variable stitching and skeleton indexing
|
|
1165
|
+
4. Entity-aware query decomposition (when knowledge graph available)
|
|
1166
|
+
|
|
1167
|
+
This is the unified, state-of-the-art document retrieval agent.
|
|
1168
|
+
"""
|
|
1169
|
+
|
|
1170
|
+
def __init__(
    self,
    skeleton: dict[str, SkeletonNode],
    kv_store: KVStore,
    config: RLMConfig | None = None,
    knowledge_graph=None,
):
    # Core inputs; a default RLMConfig is created when none is supplied.
    self.skeleton = skeleton
    self.kv_store = kv_store
    self.config = config or RLMConfig()
    self.knowledge_graph = knowledge_graph

    # Helper engines share the resolved config (the decomposer gets the graph).
    cfg = self.config
    self.variable_store = VariableStore()
    self.pre_filter = PreFilterEngine(cfg)
    self.recursive_engine = RecursiveSubLLMEngine(cfg)
    self.verification_engine = AnswerVerificationEngine(cfg)
    self.entity_decomposer = EntityAwareDecomposer(knowledge_graph)

    # No LLM until set_llm_function() is called (navigate() falls back to a default).
    self._llm_fn: Callable[[str], str] | None = None

    # Resolved once up front; _find_root_id raises ValueError if there is no level-0 node.
    self.root_id = self._find_root_id()
|
|
1194
|
+
|
|
1195
|
+
def _find_root_id(self) -> str:
|
|
1196
|
+
"""Find the root node ID."""
|
|
1197
|
+
for node in self.skeleton.values():
|
|
1198
|
+
if node.level == 0:
|
|
1199
|
+
return node.node_id
|
|
1200
|
+
raise ValueError("No root node found in skeleton")
|
|
1201
|
+
|
|
1202
|
+
def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
    """Install ``llm_fn`` on the navigator and propagate it to each sub-engine."""
    self._llm_fn = llm_fn
    # Same order as before: recursion engine, verifier, entity decomposer.
    for component in (
        self.recursive_engine,
        self.verification_engine,
        self.entity_decomposer,
    ):
        component.set_llm_function(llm_fn)
|
|
1208
|
+
|
|
1209
|
+
def set_knowledge_graph(self, kg) -> None:
    """Attach knowledge graph ``kg`` and forward it to the entity-aware decomposer."""
    self.knowledge_graph = kg
    decomposer = self.entity_decomposer
    decomposer.set_knowledge_graph(kg)
|
|
1213
|
+
|
|
1214
|
+
def navigate(
    self,
    question: str,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Answer ``question`` by running the RLM pipeline over the document tree.

    This is the main entry point. It runs, in order: keyword pre-filtering,
    query decomposition, tree navigation, answer synthesis and, when
    ``config.enable_verification`` is set, answer verification. If no LLM
    function has been set yet, the package default is configured first.

    Args:
        question: The user's question.
        metadata: Optional metadata attached to the state (e.g. multiple
            choice options).

    Returns:
        ``to_dict()`` of the final agent state (answer, confidence, trace,
        etc.). Exceptions are logged and reported as an error answer with
        confidence 0.0 rather than propagated.
    """
    state = RLMAgentState(
        question=question,
        root_node_id=self.root_id,
        config=self.config,
        metadata=metadata,
    )

    # Lazily fall back to the package-default LLM.
    if self._llm_fn is None:
        self._configure_default_llm()

    logger.info("rlm_navigation_started", question=question[:100])

    try:
        pipeline = (
            self._phase_pre_filter,  # 1: keyword extraction + node pre-filter
            self._phase_decompose,   # 2: split into sub-questions
            self._phase_navigate,    # 3: tree navigation (ToT)
            self._phase_synthesize,  # 4: answer synthesis
        )
        for run_phase in pipeline:
            state = run_phase(state)

        # 5: optional verification
        if self.config.enable_verification:
            state = self._phase_verify(state)

        logger.info(
            "rlm_navigation_complete",
            confidence=state.confidence,
            variables=len(state.variables),
            iterations=state.iteration,
        )
        return state.to_dict()

    except Exception as exc:
        logger.error("rlm_navigation_failed", error=str(exc))
        state.answer = f"Error during navigation: {str(exc)}"
        state.confidence = 0.0
        return state.to_dict()
|
|
1276
|
+
|
|
1277
|
+
def _configure_default_llm(self) -> None:
    """Best-effort: use ``rnsr.llm.get_llm()`` as the LLM; failures are only logged."""
    try:
        from rnsr.llm import get_llm

        default_llm = get_llm()

        def _complete(prompt_text):
            # The LLM's completion object is stringified to a plain str.
            return str(default_llm.complete(prompt_text))

        self.set_llm_function(_complete)
    except Exception as e:
        logger.warning("default_llm_config_failed", error=str(e))
|
|
1285
|
+
|
|
1286
|
+
def _phase_pre_filter(self, state: RLMAgentState) -> RLMAgentState:
|
|
1287
|
+
"""Phase 1: Extract keywords and pre-filter nodes."""
|
|
1288
|
+
state.add_trace("pre_filter", "Extracting keywords from query")
|
|
1289
|
+
|
|
1290
|
+
# Extract keywords
|
|
1291
|
+
keywords = self.pre_filter.extract_keywords(state.question)
|
|
1292
|
+
state.extracted_keywords = keywords
|
|
1293
|
+
|
|
1294
|
+
if not keywords:
|
|
1295
|
+
state.add_trace("pre_filter", "No keywords extracted, skipping pre-filter")
|
|
1296
|
+
return state
|
|
1297
|
+
|
|
1298
|
+
# Pre-filter all leaf nodes
|
|
1299
|
+
all_nodes = list(self.skeleton.values())
|
|
1300
|
+
matching, remaining = self.pre_filter.filter_nodes_by_keywords(all_nodes, keywords)
|
|
1301
|
+
|
|
1302
|
+
# Store which nodes matched which keywords
|
|
1303
|
+
for node in matching:
|
|
1304
|
+
search_text = f"{node.header} {node.summary}".lower()
|
|
1305
|
+
matched_keywords = [kw for kw in keywords if kw in search_text]
|
|
1306
|
+
state.pre_filtered_nodes[node.node_id] = matched_keywords
|
|
1307
|
+
|
|
1308
|
+
state.add_trace(
|
|
1309
|
+
"pre_filter",
|
|
1310
|
+
f"Pre-filtered {len(matching)}/{len(all_nodes)} nodes",
|
|
1311
|
+
{"keywords": keywords, "matching_nodes": len(matching)},
|
|
1312
|
+
)
|
|
1313
|
+
|
|
1314
|
+
return state
|
|
1315
|
+
|
|
1316
|
+
def _phase_decompose(self, state: RLMAgentState) -> RLMAgentState:
|
|
1317
|
+
"""Phase 2: Decompose query into sub-questions with entity awareness."""
|
|
1318
|
+
state.add_trace("decomposition", "Analyzing query for decomposition")
|
|
1319
|
+
|
|
1320
|
+
if self._llm_fn is None:
|
|
1321
|
+
state.sub_questions = [state.question]
|
|
1322
|
+
state.pending_questions = [state.question]
|
|
1323
|
+
return state
|
|
1324
|
+
|
|
1325
|
+
# Try entity-aware decomposition first if knowledge graph is available
|
|
1326
|
+
if self.knowledge_graph:
|
|
1327
|
+
try:
|
|
1328
|
+
entity_result = self.entity_decomposer.decompose_with_entities(
|
|
1329
|
+
state.question
|
|
1330
|
+
)
|
|
1331
|
+
|
|
1332
|
+
if entity_result.get("entities_found"):
|
|
1333
|
+
# Store entity information in state
|
|
1334
|
+
state.metadata["entities_found"] = entity_result.get("entities_found", [])
|
|
1335
|
+
state.metadata["entity_nodes"] = entity_result.get("entity_nodes", {})
|
|
1336
|
+
state.metadata["retrieval_plan"] = entity_result.get("retrieval_plan", [])
|
|
1337
|
+
state.metadata["relationships"] = entity_result.get("relationships", [])
|
|
1338
|
+
|
|
1339
|
+
sub_tasks = entity_result.get("sub_queries", [state.question])
|
|
1340
|
+
state.sub_questions = sub_tasks
|
|
1341
|
+
state.pending_questions = sub_tasks.copy()
|
|
1342
|
+
state.current_sub_question = sub_tasks[0] if sub_tasks else state.question
|
|
1343
|
+
|
|
1344
|
+
# Prioritize nodes from retrieval plan in pre-filtering
|
|
1345
|
+
for item in entity_result.get("retrieval_plan", []):
|
|
1346
|
+
node_id = item.get("node_id")
|
|
1347
|
+
if node_id and node_id not in state.pre_filtered_nodes:
|
|
1348
|
+
state.pre_filtered_nodes[node_id] = ["entity_match"]
|
|
1349
|
+
|
|
1350
|
+
state.add_trace(
|
|
1351
|
+
"decomposition",
|
|
1352
|
+
f"Entity-aware decomposition: {len(sub_tasks)} sub-tasks, {len(entity_result.get('entities_found', []))} entities",
|
|
1353
|
+
{
|
|
1354
|
+
"sub_tasks": sub_tasks,
|
|
1355
|
+
"entities": [e.canonical_name for e in entity_result.get("entities_found", [])],
|
|
1356
|
+
},
|
|
1357
|
+
)
|
|
1358
|
+
|
|
1359
|
+
return state
|
|
1360
|
+
|
|
1361
|
+
except Exception as e:
|
|
1362
|
+
logger.debug("entity_aware_decomposition_failed", error=str(e))
|
|
1363
|
+
# Fall through to standard decomposition
|
|
1364
|
+
|
|
1365
|
+
# Standard LLM decomposition
|
|
1366
|
+
decomposition_prompt = f"""Analyze this query and decompose it into specific sub-tasks.
|
|
1367
|
+
|
|
1368
|
+
Query: {state.question}
|
|
1369
|
+
|
|
1370
|
+
Available document sections (pre-filtered matches):
|
|
1371
|
+
{chr(10).join(f"- {self.skeleton[nid].header}" for nid in list(state.pre_filtered_nodes.keys())[:10])}
|
|
1372
|
+
|
|
1373
|
+
RULES:
|
|
1374
|
+
1. Each sub-task should target a specific piece of information
|
|
1375
|
+
2. For comparison queries, create one sub-task per item
|
|
1376
|
+
3. Maximum 5 sub-tasks
|
|
1377
|
+
4. If the query is simple, return just one sub-task
|
|
1378
|
+
|
|
1379
|
+
OUTPUT FORMAT (JSON):
|
|
1380
|
+
{{
|
|
1381
|
+
"sub_tasks": ["task1", "task2", ...],
|
|
1382
|
+
"synthesis_plan": "how to combine results"
|
|
1383
|
+
}}
|
|
1384
|
+
|
|
1385
|
+
Respond with JSON only:"""
|
|
1386
|
+
|
|
1387
|
+
try:
|
|
1388
|
+
import json
|
|
1389
|
+
|
|
1390
|
+
response = self._llm_fn(decomposition_prompt)
|
|
1391
|
+
json_match = re.search(r'\{[\s\S]*\}', response)
|
|
1392
|
+
|
|
1393
|
+
if json_match:
|
|
1394
|
+
result = json.loads(json_match.group())
|
|
1395
|
+
sub_tasks = result.get("sub_tasks", [state.question])
|
|
1396
|
+
state.sub_questions = sub_tasks
|
|
1397
|
+
state.pending_questions = sub_tasks.copy()
|
|
1398
|
+
state.current_sub_question = sub_tasks[0] if sub_tasks else state.question
|
|
1399
|
+
|
|
1400
|
+
state.add_trace(
|
|
1401
|
+
"decomposition",
|
|
1402
|
+
f"Decomposed into {len(sub_tasks)} sub-tasks",
|
|
1403
|
+
{"sub_tasks": sub_tasks},
|
|
1404
|
+
)
|
|
1405
|
+
else:
|
|
1406
|
+
state.sub_questions = [state.question]
|
|
1407
|
+
state.pending_questions = [state.question]
|
|
1408
|
+
|
|
1409
|
+
except Exception as e:
|
|
1410
|
+
logger.warning("decomposition_failed", error=str(e))
|
|
1411
|
+
state.sub_questions = [state.question]
|
|
1412
|
+
state.pending_questions = [state.question]
|
|
1413
|
+
|
|
1414
|
+
return state
|
|
1415
|
+
|
|
1416
|
+
def _phase_navigate(self, state: RLMAgentState) -> RLMAgentState:
|
|
1417
|
+
"""Phase 3: Navigate the tree using ToT with pre-filtering."""
|
|
1418
|
+
state.add_trace("navigation", "Starting tree navigation")
|
|
1419
|
+
|
|
1420
|
+
# Initialize navigation at root
|
|
1421
|
+
state.current_node_id = self.root_id
|
|
1422
|
+
|
|
1423
|
+
while state.iteration < self.config.max_iterations:
|
|
1424
|
+
state.iteration += 1
|
|
1425
|
+
|
|
1426
|
+
# Check termination conditions
|
|
1427
|
+
if state.current_node_id is None and not state.nodes_to_visit:
|
|
1428
|
+
break
|
|
1429
|
+
|
|
1430
|
+
# Pop from queue if needed
|
|
1431
|
+
if state.current_node_id is None and state.nodes_to_visit:
|
|
1432
|
+
state.current_node_id = state.nodes_to_visit.pop(0)
|
|
1433
|
+
|
|
1434
|
+
if state.current_node_id is None:
|
|
1435
|
+
break
|
|
1436
|
+
|
|
1437
|
+
node = self.skeleton.get(state.current_node_id)
|
|
1438
|
+
if node is None:
|
|
1439
|
+
state.current_node_id = None
|
|
1440
|
+
continue
|
|
1441
|
+
|
|
1442
|
+
# Already visited?
|
|
1443
|
+
if state.current_node_id in state.visited_nodes:
|
|
1444
|
+
state.current_node_id = None
|
|
1445
|
+
continue
|
|
1446
|
+
|
|
1447
|
+
# Decide: expand or traverse
|
|
1448
|
+
action = self._decide_action(state, node)
|
|
1449
|
+
|
|
1450
|
+
if action == "expand":
|
|
1451
|
+
state = self._do_expand(state, node)
|
|
1452
|
+
elif action == "traverse":
|
|
1453
|
+
state = self._do_traverse(state, node)
|
|
1454
|
+
elif action == "backtrack":
|
|
1455
|
+
state = self._do_backtrack(state)
|
|
1456
|
+
else:
|
|
1457
|
+
break
|
|
1458
|
+
|
|
1459
|
+
state.add_trace(
|
|
1460
|
+
"navigation",
|
|
1461
|
+
f"Navigation complete after {state.iteration} iterations",
|
|
1462
|
+
{"variables_found": len(state.variables)},
|
|
1463
|
+
)
|
|
1464
|
+
|
|
1465
|
+
return state
|
|
1466
|
+
|
|
1467
|
+
def _decide_action(
|
|
1468
|
+
self,
|
|
1469
|
+
state: RLMAgentState,
|
|
1470
|
+
node: SkeletonNode,
|
|
1471
|
+
) -> Literal["expand", "traverse", "backtrack", "done"]:
|
|
1472
|
+
"""Decide what action to take at current node."""
|
|
1473
|
+
# Leaf node -> expand
|
|
1474
|
+
if not node.child_ids:
|
|
1475
|
+
if node.node_id in state.visited_nodes:
|
|
1476
|
+
return "done"
|
|
1477
|
+
return "expand"
|
|
1478
|
+
|
|
1479
|
+
# Check unvisited children
|
|
1480
|
+
unvisited = [
|
|
1481
|
+
cid for cid in node.child_ids
|
|
1482
|
+
if cid not in state.visited_nodes and cid not in state.dead_ends
|
|
1483
|
+
]
|
|
1484
|
+
|
|
1485
|
+
if not unvisited:
|
|
1486
|
+
if state.backtrack_stack:
|
|
1487
|
+
return "backtrack"
|
|
1488
|
+
return "done"
|
|
1489
|
+
|
|
1490
|
+
# Has unvisited children -> traverse
|
|
1491
|
+
return "traverse"
|
|
1492
|
+
|
|
1493
|
+
def _do_expand(self, state: RLMAgentState, node: SkeletonNode) -> RLMAgentState:
|
|
1494
|
+
"""Expand current node: fetch content and store as variable."""
|
|
1495
|
+
content = self.kv_store.get(node.node_id)
|
|
1496
|
+
|
|
1497
|
+
if content:
|
|
1498
|
+
pointer = generate_pointer_name(node.header)
|
|
1499
|
+
self.variable_store.assign(pointer, content, node.node_id)
|
|
1500
|
+
state.variables.append(pointer)
|
|
1501
|
+
state.context += f"\nFound: {pointer} (from {node.header})"
|
|
1502
|
+
|
|
1503
|
+
state.add_trace(
|
|
1504
|
+
"variable_stitching",
|
|
1505
|
+
f"Stored {pointer}",
|
|
1506
|
+
{"node": node.node_id, "chars": len(content)},
|
|
1507
|
+
)
|
|
1508
|
+
|
|
1509
|
+
state.visited_nodes.append(node.node_id)
|
|
1510
|
+
state.current_node_id = None
|
|
1511
|
+
return state
|
|
1512
|
+
|
|
1513
|
+
def _do_traverse(self, state: RLMAgentState, node: SkeletonNode) -> RLMAgentState:
|
|
1514
|
+
"""Traverse to children using ToT with pre-filtering."""
|
|
1515
|
+
# Get children
|
|
1516
|
+
children = [self.skeleton.get(cid) for cid in node.child_ids]
|
|
1517
|
+
children = [c for c in children if c is not None]
|
|
1518
|
+
|
|
1519
|
+
# Apply pre-filtering
|
|
1520
|
+
if state.extracted_keywords and self.config.enable_pre_filtering:
|
|
1521
|
+
matching, remaining = self.pre_filter.filter_nodes_by_keywords(
|
|
1522
|
+
children,
|
|
1523
|
+
state.extracted_keywords,
|
|
1524
|
+
)
|
|
1525
|
+
|
|
1526
|
+
# If we have matching nodes, prioritize them
|
|
1527
|
+
if matching:
|
|
1528
|
+
selected = matching[:self.config.top_k]
|
|
1529
|
+
state.add_trace(
|
|
1530
|
+
"navigation",
|
|
1531
|
+
f"Pre-filter selected {len(selected)}/{len(children)} children",
|
|
1532
|
+
{"selected": [n.node_id for n in selected]},
|
|
1533
|
+
)
|
|
1534
|
+
else:
|
|
1535
|
+
# Fall back to ToT evaluation
|
|
1536
|
+
selected = self._tot_evaluate_children(state, children)
|
|
1537
|
+
else:
|
|
1538
|
+
# Use ToT evaluation
|
|
1539
|
+
selected = self._tot_evaluate_children(state, children)
|
|
1540
|
+
|
|
1541
|
+
# Queue selected children
|
|
1542
|
+
if selected:
|
|
1543
|
+
for child in selected:
|
|
1544
|
+
if child.node_id not in state.nodes_to_visit:
|
|
1545
|
+
state.nodes_to_visit.append(child.node_id)
|
|
1546
|
+
|
|
1547
|
+
# Push current node to backtrack stack
|
|
1548
|
+
state.backtrack_stack.append(node.node_id)
|
|
1549
|
+
else:
|
|
1550
|
+
# Dead end
|
|
1551
|
+
state.dead_ends.append(node.node_id)
|
|
1552
|
+
|
|
1553
|
+
state.visited_nodes.append(node.node_id)
|
|
1554
|
+
state.current_node_id = None
|
|
1555
|
+
return state
|
|
1556
|
+
|
|
1557
|
+
def _tot_evaluate_children(
|
|
1558
|
+
self,
|
|
1559
|
+
state: RLMAgentState,
|
|
1560
|
+
children: list[SkeletonNode],
|
|
1561
|
+
) -> list[SkeletonNode]:
|
|
1562
|
+
"""Use Tree of Thoughts to evaluate children."""
|
|
1563
|
+
if not self._llm_fn or not children:
|
|
1564
|
+
return children[:self.config.top_k]
|
|
1565
|
+
|
|
1566
|
+
# Format children for evaluation
|
|
1567
|
+
children_text = "\n".join(
|
|
1568
|
+
f" - [{c.node_id}] {c.header}: {c.summary[:150]}"
|
|
1569
|
+
for c in children
|
|
1570
|
+
)
|
|
1571
|
+
|
|
1572
|
+
current_node = self.skeleton.get(state.current_node_id or self.root_id)
|
|
1573
|
+
current_summary = f"{current_node.header}: {current_node.summary}" if current_node else ""
|
|
1574
|
+
|
|
1575
|
+
tot_prompt = f"""You are evaluating document sections for relevance.
|
|
1576
|
+
|
|
1577
|
+
Current location: {current_summary}
|
|
1578
|
+
|
|
1579
|
+
Children sections:
|
|
1580
|
+
{children_text}
|
|
1581
|
+
|
|
1582
|
+
Query: {state.current_sub_question or state.question}
|
|
1583
|
+
|
|
1584
|
+
TASK: Evaluate each child's probability (0.0-1.0) of containing relevant information.
|
|
1585
|
+
|
|
1586
|
+
OUTPUT FORMAT (JSON):
|
|
1587
|
+
{{
|
|
1588
|
+
"evaluations": [
|
|
1589
|
+
{{"node_id": "...", "probability": 0.85, "reasoning": "..."}}
|
|
1590
|
+
],
|
|
1591
|
+
"selected_nodes": ["node_id_1", "node_id_2"],
|
|
1592
|
+
"is_dead_end": false
|
|
1593
|
+
}}
|
|
1594
|
+
|
|
1595
|
+
JSON only:"""
|
|
1596
|
+
|
|
1597
|
+
try:
|
|
1598
|
+
import json
|
|
1599
|
+
|
|
1600
|
+
response = self._llm_fn(tot_prompt)
|
|
1601
|
+
json_match = re.search(r'\{[\s\S]*\}', response)
|
|
1602
|
+
|
|
1603
|
+
if json_match:
|
|
1604
|
+
result = json.loads(json_match.group())
|
|
1605
|
+
selected_ids = result.get("selected_nodes", [])
|
|
1606
|
+
|
|
1607
|
+
# Map back to nodes
|
|
1608
|
+
selected = [c for c in children if c.node_id in selected_ids]
|
|
1609
|
+
|
|
1610
|
+
if not selected and not result.get("is_dead_end", False):
|
|
1611
|
+
# Fallback: take top-k by probability
|
|
1612
|
+
evaluations = result.get("evaluations", [])
|
|
1613
|
+
sorted_evals = sorted(
|
|
1614
|
+
evaluations,
|
|
1615
|
+
key=lambda x: x.get("probability", 0),
|
|
1616
|
+
reverse=True,
|
|
1617
|
+
)
|
|
1618
|
+
top_ids = [e["node_id"] for e in sorted_evals[:self.config.top_k]]
|
|
1619
|
+
selected = [c for c in children if c.node_id in top_ids]
|
|
1620
|
+
|
|
1621
|
+
return selected
|
|
1622
|
+
|
|
1623
|
+
except Exception as e:
|
|
1624
|
+
logger.warning("tot_evaluation_failed", error=str(e))
|
|
1625
|
+
|
|
1626
|
+
# Fallback: return first top_k children
|
|
1627
|
+
return children[:self.config.top_k]
|
|
1628
|
+
|
|
1629
|
+
def _do_backtrack(self, state: RLMAgentState) -> RLMAgentState:
|
|
1630
|
+
"""Backtrack to previous node."""
|
|
1631
|
+
if state.backtrack_stack:
|
|
1632
|
+
parent_id = state.backtrack_stack.pop()
|
|
1633
|
+
state.dead_ends.append(state.current_node_id or "")
|
|
1634
|
+
state.current_node_id = parent_id
|
|
1635
|
+
|
|
1636
|
+
state.add_trace(
|
|
1637
|
+
"navigation",
|
|
1638
|
+
f"Backtracked to {parent_id}",
|
|
1639
|
+
{"from": state.current_node_id},
|
|
1640
|
+
)
|
|
1641
|
+
else:
|
|
1642
|
+
state.current_node_id = None
|
|
1643
|
+
|
|
1644
|
+
return state
|
|
1645
|
+
|
|
1646
|
+
def _phase_synthesize(self, state: RLMAgentState) -> RLMAgentState:
|
|
1647
|
+
"""Phase 4: Synthesize answer from variables."""
|
|
1648
|
+
state.add_trace("synthesis", "Synthesizing answer from variables")
|
|
1649
|
+
|
|
1650
|
+
if not state.variables:
|
|
1651
|
+
state.answer = "No relevant content found in the document."
|
|
1652
|
+
state.confidence = 0.0
|
|
1653
|
+
return state
|
|
1654
|
+
|
|
1655
|
+
# Collect all variable content
|
|
1656
|
+
contents = []
|
|
1657
|
+
for pointer in state.variables:
|
|
1658
|
+
content = self.variable_store.resolve(pointer)
|
|
1659
|
+
if content:
|
|
1660
|
+
contents.append(f"=== {pointer} ===\n{content}")
|
|
1661
|
+
|
|
1662
|
+
context_text = "\n\n".join(contents)
|
|
1663
|
+
|
|
1664
|
+
if not self._llm_fn:
|
|
1665
|
+
state.answer = context_text
|
|
1666
|
+
state.confidence = 0.5
|
|
1667
|
+
return state
|
|
1668
|
+
|
|
1669
|
+
# Handle multiple choice
|
|
1670
|
+
options = state.metadata.get("options")
|
|
1671
|
+
if options:
|
|
1672
|
+
options_text = "\n".join(f"{chr(65+i)}. {opt}" for i, opt in enumerate(options))
|
|
1673
|
+
synthesis_prompt = f"""Based on the context, answer this multiple-choice question.
|
|
1674
|
+
|
|
1675
|
+
Question: {state.question}
|
|
1676
|
+
|
|
1677
|
+
Options:
|
|
1678
|
+
{options_text}
|
|
1679
|
+
|
|
1680
|
+
Context:
|
|
1681
|
+
{context_text}
|
|
1682
|
+
|
|
1683
|
+
Respond with ONLY the letter and full option text (e.g., "A. [option text]"):"""
|
|
1684
|
+
else:
|
|
1685
|
+
synthesis_prompt = f"""Based on the context, answer the question concisely.
|
|
1686
|
+
|
|
1687
|
+
Question: {state.question}
|
|
1688
|
+
|
|
1689
|
+
Context:
|
|
1690
|
+
{context_text}
|
|
1691
|
+
|
|
1692
|
+
Answer:"""
|
|
1693
|
+
|
|
1694
|
+
try:
|
|
1695
|
+
answer = self._llm_fn(synthesis_prompt)
|
|
1696
|
+
state.answer = answer.strip()
|
|
1697
|
+
state.confidence = min(1.0, len(state.variables) * 0.25)
|
|
1698
|
+
|
|
1699
|
+
# Normalize multiple choice answer
|
|
1700
|
+
if options:
|
|
1701
|
+
state.answer = self._normalize_mc_answer(state.answer, options)
|
|
1702
|
+
|
|
1703
|
+
except Exception as e:
|
|
1704
|
+
logger.error("synthesis_failed", error=str(e))
|
|
1705
|
+
state.answer = f"Error during synthesis: {str(e)}"
|
|
1706
|
+
state.confidence = 0.0
|
|
1707
|
+
|
|
1708
|
+
return state
|
|
1709
|
+
|
|
1710
|
+
def _normalize_mc_answer(self, answer: str, options: list) -> str:
|
|
1711
|
+
"""Normalize multiple choice answer to match option text."""
|
|
1712
|
+
answer_lower = answer.lower().strip()
|
|
1713
|
+
|
|
1714
|
+
for i, opt in enumerate(options):
|
|
1715
|
+
letter = chr(65 + i)
|
|
1716
|
+
opt_lower = opt.lower()
|
|
1717
|
+
|
|
1718
|
+
if (answer_lower.startswith(f"{letter.lower()}.") or
|
|
1719
|
+
answer_lower.startswith(f"{letter.lower()})") or
|
|
1720
|
+
opt_lower in answer_lower):
|
|
1721
|
+
return opt
|
|
1722
|
+
|
|
1723
|
+
return answer
|
|
1724
|
+
|
|
1725
|
+
def _phase_verify(self, state: RLMAgentState) -> RLMAgentState:
|
|
1726
|
+
"""Phase 5: Verify the answer."""
|
|
1727
|
+
state.add_trace("verification", "Verifying answer")
|
|
1728
|
+
|
|
1729
|
+
# Collect evidence
|
|
1730
|
+
evidence = [
|
|
1731
|
+
self.variable_store.resolve(p) or ""
|
|
1732
|
+
for p in state.variables
|
|
1733
|
+
]
|
|
1734
|
+
|
|
1735
|
+
result = self.verification_engine.verify_answer(
|
|
1736
|
+
state.question,
|
|
1737
|
+
state.answer or "",
|
|
1738
|
+
evidence,
|
|
1739
|
+
)
|
|
1740
|
+
|
|
1741
|
+
state.verification_result = result
|
|
1742
|
+
|
|
1743
|
+
if result.get("improved_answer"):
|
|
1744
|
+
state.answer = result["improved_answer"]
|
|
1745
|
+
|
|
1746
|
+
state.confidence = result.get("confidence", state.confidence)
|
|
1747
|
+
|
|
1748
|
+
state.add_trace(
|
|
1749
|
+
"verification",
|
|
1750
|
+
f"Verification complete: valid={result.get('is_valid', True)}",
|
|
1751
|
+
{"issues": result.get("issues", [])},
|
|
1752
|
+
)
|
|
1753
|
+
|
|
1754
|
+
return state
|
|
1755
|
+
|
|
1756
|
+
|
|
1757
|
+
# =============================================================================
|
|
1758
|
+
# Entity-Aware Query Decomposition
|
|
1759
|
+
# =============================================================================
|
|
1760
|
+
|
|
1761
|
+
|
|
1762
|
+
class EntityAwareDecomposer:
|
|
1763
|
+
"""
|
|
1764
|
+
Enhances query decomposition by leveraging entity relationships
|
|
1765
|
+
from the knowledge graph.
|
|
1766
|
+
|
|
1767
|
+
This allows the navigator to:
|
|
1768
|
+
1. Identify entities mentioned in the query
|
|
1769
|
+
2. Look up related entities via the knowledge graph
|
|
1770
|
+
3. Plan retrieval based on entity relationships
|
|
1771
|
+
4. Generate entity-focused sub-queries
|
|
1772
|
+
"""
|
|
1773
|
+
|
|
1774
|
+
def __init__(
|
|
1775
|
+
self,
|
|
1776
|
+
knowledge_graph=None,
|
|
1777
|
+
llm_fn: Callable[[str], str] | None = None,
|
|
1778
|
+
):
|
|
1779
|
+
"""
|
|
1780
|
+
Initialize the entity-aware decomposer.
|
|
1781
|
+
|
|
1782
|
+
Args:
|
|
1783
|
+
knowledge_graph: Optional knowledge graph for entity lookup.
|
|
1784
|
+
llm_fn: LLM function for query analysis.
|
|
1785
|
+
"""
|
|
1786
|
+
self.kg = knowledge_graph
|
|
1787
|
+
self._llm_fn = llm_fn
|
|
1788
|
+
|
|
1789
|
+
def set_llm_function(self, llm_fn: Callable[[str], str]) -> None:
|
|
1790
|
+
"""Set the LLM function."""
|
|
1791
|
+
self._llm_fn = llm_fn
|
|
1792
|
+
|
|
1793
|
+
def set_knowledge_graph(self, kg) -> None:
|
|
1794
|
+
"""Set the knowledge graph."""
|
|
1795
|
+
self.kg = kg
|
|
1796
|
+
|
|
1797
|
+
def decompose_with_entities(
|
|
1798
|
+
self,
|
|
1799
|
+
query: str,
|
|
1800
|
+
doc_id: str | None = None,
|
|
1801
|
+
) -> dict[str, Any]:
|
|
1802
|
+
"""
|
|
1803
|
+
Decompose a query using entity awareness.
|
|
1804
|
+
|
|
1805
|
+
Args:
|
|
1806
|
+
query: The user's query.
|
|
1807
|
+
doc_id: Optional document ID to scope entity lookup.
|
|
1808
|
+
|
|
1809
|
+
Returns:
|
|
1810
|
+
Dict with sub_queries, entities_found, and retrieval_plan.
|
|
1811
|
+
"""
|
|
1812
|
+
result = {
|
|
1813
|
+
"original_query": query,
|
|
1814
|
+
"sub_queries": [query],
|
|
1815
|
+
"entities_found": [],
|
|
1816
|
+
"entity_nodes": {},
|
|
1817
|
+
"retrieval_plan": [],
|
|
1818
|
+
}
|
|
1819
|
+
|
|
1820
|
+
if not self.kg:
|
|
1821
|
+
return result
|
|
1822
|
+
|
|
1823
|
+
# Step 1: Extract entity names from query
|
|
1824
|
+
entity_names = self._extract_entity_names(query)
|
|
1825
|
+
|
|
1826
|
+
if not entity_names:
|
|
1827
|
+
return result
|
|
1828
|
+
|
|
1829
|
+
# Step 2: Look up entities in knowledge graph
|
|
1830
|
+
entities_found = []
|
|
1831
|
+
entity_nodes: dict[str, list[str]] = {}
|
|
1832
|
+
|
|
1833
|
+
for name in entity_names:
|
|
1834
|
+
matches = self.kg.find_entities_by_name(name, fuzzy=True)
|
|
1835
|
+
|
|
1836
|
+
# Filter by document if specified
|
|
1837
|
+
if doc_id:
|
|
1838
|
+
matches = [e for e in matches if doc_id in e.document_ids]
|
|
1839
|
+
|
|
1840
|
+
for entity in matches:
|
|
1841
|
+
if entity not in entities_found:
|
|
1842
|
+
entities_found.append(entity)
|
|
1843
|
+
# Get nodes where this entity is mentioned
|
|
1844
|
+
entity_nodes[entity.id] = list(entity.node_ids)
|
|
1845
|
+
|
|
1846
|
+
result["entities_found"] = entities_found
|
|
1847
|
+
result["entity_nodes"] = entity_nodes
|
|
1848
|
+
|
|
1849
|
+
if not entities_found:
|
|
1850
|
+
return result
|
|
1851
|
+
|
|
1852
|
+
# Step 3: Get related entities and relationships
|
|
1853
|
+
related_entities = []
|
|
1854
|
+
relationships = []
|
|
1855
|
+
|
|
1856
|
+
for entity in entities_found:
|
|
1857
|
+
# Get entities co-mentioned with this one
|
|
1858
|
+
co_mentions = self.kg.get_entities_mentioned_together(entity.id)
|
|
1859
|
+
for related, count in co_mentions[:5]: # Top 5 co-mentions
|
|
1860
|
+
if related not in related_entities:
|
|
1861
|
+
related_entities.append(related)
|
|
1862
|
+
|
|
1863
|
+
# Get relationships
|
|
1864
|
+
rels = self.kg.get_entity_relationships(entity.id)
|
|
1865
|
+
relationships.extend(rels)
|
|
1866
|
+
|
|
1867
|
+
# Step 4: Generate entity-focused sub-queries
|
|
1868
|
+
sub_queries = self._generate_entity_sub_queries(
|
|
1869
|
+
query, entities_found, related_entities, relationships
|
|
1870
|
+
)
|
|
1871
|
+
|
|
1872
|
+
result["sub_queries"] = sub_queries
|
|
1873
|
+
result["related_entities"] = related_entities
|
|
1874
|
+
result["relationships"] = relationships
|
|
1875
|
+
|
|
1876
|
+
# Step 5: Create retrieval plan
|
|
1877
|
+
result["retrieval_plan"] = self._create_retrieval_plan(
|
|
1878
|
+
query, entities_found, entity_nodes, relationships
|
|
1879
|
+
)
|
|
1880
|
+
|
|
1881
|
+
logger.debug(
|
|
1882
|
+
"entity_aware_decomposition",
|
|
1883
|
+
entities=len(entities_found),
|
|
1884
|
+
sub_queries=len(sub_queries),
|
|
1885
|
+
relationships=len(relationships),
|
|
1886
|
+
)
|
|
1887
|
+
|
|
1888
|
+
return result
|
|
1889
|
+
|
|
1890
|
+
def _extract_entity_names(self, query: str) -> list[str]:
|
|
1891
|
+
"""Extract potential entity names from a query."""
|
|
1892
|
+
entity_names = []
|
|
1893
|
+
|
|
1894
|
+
# Use LLM if available
|
|
1895
|
+
if self._llm_fn:
|
|
1896
|
+
try:
|
|
1897
|
+
prompt = f"""Extract entity names (people, organizations, places, documents) from this query.
|
|
1898
|
+
|
|
1899
|
+
Query: {query}
|
|
1900
|
+
|
|
1901
|
+
Return as JSON array of names:
|
|
1902
|
+
["Name 1", "Name 2"]
|
|
1903
|
+
|
|
1904
|
+
JSON only:"""
|
|
1905
|
+
|
|
1906
|
+
response = self._llm_fn(prompt)
|
|
1907
|
+
json_match = re.search(r'\[[\s\S]*?\]', response)
|
|
1908
|
+
if json_match:
|
|
1909
|
+
import json
|
|
1910
|
+
entity_names = json.loads(json_match.group())
|
|
1911
|
+
|
|
1912
|
+
except Exception as e:
|
|
1913
|
+
logger.debug("entity_extraction_llm_failed", error=str(e))
|
|
1914
|
+
|
|
1915
|
+
# Fallback: extract capitalized phrases
|
|
1916
|
+
if not entity_names:
|
|
1917
|
+
# Find capitalized words (likely proper nouns)
|
|
1918
|
+
proper_nouns = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
|
|
1919
|
+
entity_names = proper_nouns
|
|
1920
|
+
|
|
1921
|
+
return entity_names
|
|
1922
|
+
|
|
1923
|
+
def _generate_entity_sub_queries(
|
|
1924
|
+
self,
|
|
1925
|
+
query: str,
|
|
1926
|
+
entities: list,
|
|
1927
|
+
related: list,
|
|
1928
|
+
relationships: list,
|
|
1929
|
+
) -> list[str]:
|
|
1930
|
+
"""Generate sub-queries focused on entities."""
|
|
1931
|
+
sub_queries = []
|
|
1932
|
+
|
|
1933
|
+
if not self._llm_fn:
|
|
1934
|
+
# Simple decomposition: one query per entity
|
|
1935
|
+
for entity in entities[:3]:
|
|
1936
|
+
sub_queries.append(
|
|
1937
|
+
f"Find information about {entity.canonical_name}: {query}"
|
|
1938
|
+
)
|
|
1939
|
+
return sub_queries if sub_queries else [query]
|
|
1940
|
+
|
|
1941
|
+
# Use LLM for intelligent decomposition
|
|
1942
|
+
try:
|
|
1943
|
+
entity_names = [e.canonical_name for e in entities]
|
|
1944
|
+
related_names = [e.canonical_name for e in related[:5]]
|
|
1945
|
+
|
|
1946
|
+
rel_descriptions = []
|
|
1947
|
+
for rel in relationships[:10]:
|
|
1948
|
+
rel_descriptions.append(
|
|
1949
|
+
f"- {rel.source_id} {rel.type.value} {rel.target_id}"
|
|
1950
|
+
)
|
|
1951
|
+
|
|
1952
|
+
prompt = f"""Decompose this query into focused sub-queries based on the entities.
|
|
1953
|
+
|
|
1954
|
+
Query: {query}
|
|
1955
|
+
|
|
1956
|
+
Key entities found: {', '.join(entity_names)}
|
|
1957
|
+
Related entities: {', '.join(related_names)}
|
|
1958
|
+
|
|
1959
|
+
Known relationships:
|
|
1960
|
+
{chr(10).join(rel_descriptions) if rel_descriptions else '(none)'}
|
|
1961
|
+
|
|
1962
|
+
Generate 1-5 focused sub-queries. Each should target specific entities or relationships.
|
|
1963
|
+
|
|
1964
|
+
Return as JSON:
|
|
1965
|
+
{{"sub_queries": ["query 1", "query 2"]}}
|
|
1966
|
+
|
|
1967
|
+
JSON only:"""
|
|
1968
|
+
|
|
1969
|
+
response = self._llm_fn(prompt)
|
|
1970
|
+
json_match = re.search(r'\{[\s\S]*?\}', response)
|
|
1971
|
+
if json_match:
|
|
1972
|
+
import json
|
|
1973
|
+
result = json.loads(json_match.group())
|
|
1974
|
+
sub_queries = result.get("sub_queries", [])
|
|
1975
|
+
|
|
1976
|
+
except Exception as e:
|
|
1977
|
+
logger.debug("sub_query_generation_failed", error=str(e))
|
|
1978
|
+
|
|
1979
|
+
return sub_queries if sub_queries else [query]
|
|
1980
|
+
|
|
1981
|
+
def _create_retrieval_plan(
|
|
1982
|
+
self,
|
|
1983
|
+
query: str,
|
|
1984
|
+
entities: list,
|
|
1985
|
+
entity_nodes: dict[str, list[str]],
|
|
1986
|
+
relationships: list,
|
|
1987
|
+
) -> list[dict[str, Any]]:
|
|
1988
|
+
"""Create a retrieval plan based on entities."""
|
|
1989
|
+
plan = []
|
|
1990
|
+
|
|
1991
|
+
# Priority 1: Nodes with direct entity mentions
|
|
1992
|
+
priority_nodes = set()
|
|
1993
|
+
for entity in entities:
|
|
1994
|
+
nodes = entity_nodes.get(entity.id, [])
|
|
1995
|
+
for node_id in nodes:
|
|
1996
|
+
priority_nodes.add(node_id)
|
|
1997
|
+
plan.append({
|
|
1998
|
+
"node_id": node_id,
|
|
1999
|
+
"priority": 1,
|
|
2000
|
+
"reason": f"Contains {entity.canonical_name}",
|
|
2001
|
+
"entity_id": entity.id,
|
|
2002
|
+
})
|
|
2003
|
+
|
|
2004
|
+
# Priority 2: Nodes involved in relationships
|
|
2005
|
+
for rel in relationships:
|
|
2006
|
+
if rel.source_type == "node" and rel.source_id not in priority_nodes:
|
|
2007
|
+
plan.append({
|
|
2008
|
+
"node_id": rel.source_id,
|
|
2009
|
+
"priority": 2,
|
|
2010
|
+
"reason": f"Related via {rel.type.value}",
|
|
2011
|
+
"relationship_id": rel.id,
|
|
2012
|
+
})
|
|
2013
|
+
if rel.target_type == "node" and rel.target_id not in priority_nodes:
|
|
2014
|
+
plan.append({
|
|
2015
|
+
"node_id": rel.target_id,
|
|
2016
|
+
"priority": 2,
|
|
2017
|
+
"reason": f"Related via {rel.type.value}",
|
|
2018
|
+
"relationship_id": rel.id,
|
|
2019
|
+
})
|
|
2020
|
+
|
|
2021
|
+
# Sort by priority
|
|
2022
|
+
plan.sort(key=lambda x: x["priority"])
|
|
2023
|
+
|
|
2024
|
+
return plan
|
|
2025
|
+
|
|
2026
|
+
|
|
2027
|
+
# =============================================================================
|
|
2028
|
+
# Factory Function
|
|
2029
|
+
# =============================================================================
|
|
2030
|
+
|
|
2031
|
+
|
|
2032
|
+
def create_rlm_navigator(
|
|
2033
|
+
skeleton: dict[str, SkeletonNode],
|
|
2034
|
+
kv_store: KVStore,
|
|
2035
|
+
config: RLMConfig | None = None,
|
|
2036
|
+
knowledge_graph=None,
|
|
2037
|
+
) -> RLMNavigator:
|
|
2038
|
+
"""
|
|
2039
|
+
Create an RLM Navigator instance.
|
|
2040
|
+
|
|
2041
|
+
Args:
|
|
2042
|
+
skeleton: Skeleton index.
|
|
2043
|
+
kv_store: KV store with full content.
|
|
2044
|
+
config: Optional configuration.
|
|
2045
|
+
knowledge_graph: Optional knowledge graph for entity-aware queries.
|
|
2046
|
+
|
|
2047
|
+
Returns:
|
|
2048
|
+
Configured RLMNavigator.
|
|
2049
|
+
|
|
2050
|
+
Example:
|
|
2051
|
+
from rnsr import ingest_document, build_skeleton_index
|
|
2052
|
+
from rnsr.agent.rlm_navigator import create_rlm_navigator, RLMConfig
|
|
2053
|
+
from rnsr.indexing.knowledge_graph import KnowledgeGraph
|
|
2054
|
+
|
|
2055
|
+
result = ingest_document("contract.pdf")
|
|
2056
|
+
skeleton, kv_store = build_skeleton_index(result.tree)
|
|
2057
|
+
|
|
2058
|
+
# With knowledge graph for entity-aware queries
|
|
2059
|
+
kg = KnowledgeGraph("./data/kg.db")
|
|
2060
|
+
|
|
2061
|
+
# With custom config
|
|
2062
|
+
config = RLMConfig(
|
|
2063
|
+
max_recursion_depth=3,
|
|
2064
|
+
enable_pre_filtering=True,
|
|
2065
|
+
enable_verification=True,
|
|
2066
|
+
)
|
|
2067
|
+
|
|
2068
|
+
navigator = create_rlm_navigator(skeleton, kv_store, config, kg)
|
|
2069
|
+
result = navigator.navigate("What are the liability terms?")
|
|
2070
|
+
print(result["answer"])
|
|
2071
|
+
"""
|
|
2072
|
+
nav = RLMNavigator(skeleton, kv_store, config, knowledge_graph)
|
|
2073
|
+
|
|
2074
|
+
# Configure LLM
|
|
2075
|
+
try:
|
|
2076
|
+
from rnsr.llm import get_llm
|
|
2077
|
+
llm = get_llm()
|
|
2078
|
+
nav.set_llm_function(lambda p: str(llm.complete(p)))
|
|
2079
|
+
except Exception as e:
|
|
2080
|
+
logger.warning("llm_config_failed", error=str(e))
|
|
2081
|
+
|
|
2082
|
+
return nav
|
|
2083
|
+
|
|
2084
|
+
|
|
2085
|
+
def run_rlm_navigator(
|
|
2086
|
+
question: str,
|
|
2087
|
+
skeleton: dict[str, SkeletonNode],
|
|
2088
|
+
kv_store: KVStore,
|
|
2089
|
+
config: RLMConfig | None = None,
|
|
2090
|
+
metadata: dict[str, Any] | None = None,
|
|
2091
|
+
) -> dict[str, Any]:
|
|
2092
|
+
"""
|
|
2093
|
+
Run the RLM Navigator on a question.
|
|
2094
|
+
|
|
2095
|
+
Convenience function that creates and runs the navigator.
|
|
2096
|
+
|
|
2097
|
+
Args:
|
|
2098
|
+
question: The user's question.
|
|
2099
|
+
skeleton: Skeleton index.
|
|
2100
|
+
kv_store: KV store.
|
|
2101
|
+
config: Optional configuration.
|
|
2102
|
+
metadata: Optional metadata.
|
|
2103
|
+
|
|
2104
|
+
Returns:
|
|
2105
|
+
Dict with answer, confidence, trace.
|
|
2106
|
+
"""
|
|
2107
|
+
navigator = create_rlm_navigator(skeleton, kv_store, config)
|
|
2108
|
+
return navigator.navigate(question, metadata)
|