rnsr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnsr/__init__.py +118 -0
- rnsr/__main__.py +242 -0
- rnsr/agent/__init__.py +218 -0
- rnsr/agent/cross_doc_navigator.py +767 -0
- rnsr/agent/graph.py +1557 -0
- rnsr/agent/llm_cache.py +575 -0
- rnsr/agent/navigator_api.py +497 -0
- rnsr/agent/provenance.py +772 -0
- rnsr/agent/query_clarifier.py +617 -0
- rnsr/agent/reasoning_memory.py +736 -0
- rnsr/agent/repl_env.py +709 -0
- rnsr/agent/rlm_navigator.py +2108 -0
- rnsr/agent/self_reflection.py +602 -0
- rnsr/agent/variable_store.py +308 -0
- rnsr/benchmarks/__init__.py +118 -0
- rnsr/benchmarks/comprehensive_benchmark.py +733 -0
- rnsr/benchmarks/evaluation_suite.py +1210 -0
- rnsr/benchmarks/finance_bench.py +147 -0
- rnsr/benchmarks/pdf_merger.py +178 -0
- rnsr/benchmarks/performance.py +321 -0
- rnsr/benchmarks/quality.py +321 -0
- rnsr/benchmarks/runner.py +298 -0
- rnsr/benchmarks/standard_benchmarks.py +995 -0
- rnsr/client.py +560 -0
- rnsr/document_store.py +394 -0
- rnsr/exceptions.py +74 -0
- rnsr/extraction/__init__.py +172 -0
- rnsr/extraction/candidate_extractor.py +357 -0
- rnsr/extraction/entity_extractor.py +581 -0
- rnsr/extraction/entity_linker.py +825 -0
- rnsr/extraction/grounded_extractor.py +722 -0
- rnsr/extraction/learned_types.py +599 -0
- rnsr/extraction/models.py +232 -0
- rnsr/extraction/relationship_extractor.py +600 -0
- rnsr/extraction/relationship_patterns.py +511 -0
- rnsr/extraction/relationship_validator.py +392 -0
- rnsr/extraction/rlm_extractor.py +589 -0
- rnsr/extraction/rlm_unified_extractor.py +990 -0
- rnsr/extraction/tot_validator.py +610 -0
- rnsr/extraction/unified_extractor.py +342 -0
- rnsr/indexing/__init__.py +60 -0
- rnsr/indexing/knowledge_graph.py +1128 -0
- rnsr/indexing/kv_store.py +313 -0
- rnsr/indexing/persistence.py +323 -0
- rnsr/indexing/semantic_retriever.py +237 -0
- rnsr/indexing/semantic_search.py +320 -0
- rnsr/indexing/skeleton_index.py +395 -0
- rnsr/ingestion/__init__.py +161 -0
- rnsr/ingestion/chart_parser.py +569 -0
- rnsr/ingestion/document_boundary.py +662 -0
- rnsr/ingestion/font_histogram.py +334 -0
- rnsr/ingestion/header_classifier.py +595 -0
- rnsr/ingestion/hierarchical_cluster.py +515 -0
- rnsr/ingestion/layout_detector.py +356 -0
- rnsr/ingestion/layout_model.py +379 -0
- rnsr/ingestion/ocr_fallback.py +177 -0
- rnsr/ingestion/pipeline.py +936 -0
- rnsr/ingestion/semantic_fallback.py +417 -0
- rnsr/ingestion/table_parser.py +799 -0
- rnsr/ingestion/text_builder.py +460 -0
- rnsr/ingestion/tree_builder.py +402 -0
- rnsr/ingestion/vision_retrieval.py +965 -0
- rnsr/ingestion/xy_cut.py +555 -0
- rnsr/llm.py +733 -0
- rnsr/models.py +167 -0
- rnsr/py.typed +2 -0
- rnsr-0.1.0.dist-info/METADATA +592 -0
- rnsr-0.1.0.dist-info/RECORD +72 -0
- rnsr-0.1.0.dist-info/WHEEL +5 -0
- rnsr-0.1.0.dist-info/entry_points.txt +2 -0
- rnsr-0.1.0.dist-info/licenses/LICENSE +21 -0
- rnsr-0.1.0.dist-info/top_level.txt +1 -0
rnsr/agent/graph.py
ADDED
|
@@ -0,0 +1,1557 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Agent Graph - LangGraph State Machine for Document Navigation
|
|
3
|
+
|
|
4
|
+
Implements the Navigator Agent with full RLM (Recursive Language Model) support:
|
|
5
|
+
|
|
6
|
+
1. Decomposes queries into sub-questions (Section 2.2 - Recursive Loop)
|
|
7
|
+
2. Navigates the document tree (expand/traverse decisions)
|
|
8
|
+
3. Stores findings as pointers (Variable Stitching - Section 2.2)
|
|
9
|
+
4. Synthesizes final answer from stored pointers
|
|
10
|
+
5. Tree of Thoughts prompting (Section 7.2)
|
|
11
|
+
6. Recursive sub-LLM invocation for complex queries
|
|
12
|
+
|
|
13
|
+
Agent State follows Appendix C specification:
|
|
14
|
+
- question: Current question being answered
|
|
15
|
+
- sub_questions: Decomposed sub-questions (via LLM)
|
|
16
|
+
- current_node_id: Where we are in the document tree
|
|
17
|
+
- visited_nodes: Navigation history
|
|
18
|
+
- variables: Stored findings as $POINTER -> content
|
|
19
|
+
- pending_questions: Sub-questions not yet answered
|
|
20
|
+
- context: Accumulated context (pointers only!)
|
|
21
|
+
- answer: Final synthesized answer
|
|
22
|
+
- trace: Full retrieval trace for transparency
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import re
|
|
28
|
+
from datetime import datetime, timezone
|
|
29
|
+
from typing import Any, Literal, TypedDict, cast
|
|
30
|
+
|
|
31
|
+
import structlog
|
|
32
|
+
|
|
33
|
+
from rnsr.agent.variable_store import VariableStore, generate_pointer_name
|
|
34
|
+
from rnsr.indexing.kv_store import KVStore
|
|
35
|
+
from rnsr.indexing.semantic_search import SemanticSearcher
|
|
36
|
+
from rnsr.models import SkeletonNode, TraceEntry
|
|
37
|
+
|
|
38
|
+
logger = structlog.get_logger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# =============================================================================
|
|
42
|
+
# Agent State Definition (Appendix C)
|
|
43
|
+
# =============================================================================
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AgentState(TypedDict):
    """
    State for the RNSR Navigator Agent.

    All fields use pointer-based Variable Stitching:
    - Full content stored in VariableStore
    - Only $POINTER names in state fields

    Mutated in place / shallow-copied by the node functions below
    (see decompose_query, navigate_tree, backtrack_to_parent).
    """

    # Query processing
    question: str
    sub_questions: list[str]
    pending_questions: list[str]
    current_sub_question: str | None

    # Navigation state
    current_node_id: str | None
    visited_nodes: list[str]
    navigation_path: list[str]

    # Tree of Thoughts (ToT) state - Section 7.2
    nodes_to_visit: list[str]  # Queue for parallel exploration
    scored_candidates: list[dict[str, Any]]  # [{node_id, score, reasoning}]
    backtrack_stack: list[str]  # Stack of parent node IDs for backtracking
    dead_ends: list[str]  # Nodes marked as dead ends
    top_k: int  # Number of top candidates to explore
    tot_selection_threshold: float  # Minimum probability for selection
    tot_dead_end_threshold: float  # Probability threshold for dead end

    # Variable stitching (pointers only!)
    variables: list[str]  # List of $POINTER names
    context: str  # Contains pointers, not full content

    # Output
    answer: str | None  # Final synthesized answer (None until synthesis)
    confidence: float  # Confidence in the answer, 0.0-1.0

    # Question metadata (e.g., multiple-choice options)
    metadata: dict[str, Any]

    # Traceability
    trace: list[dict[str, Any]]  # Serialized TraceEntry records
    iteration: int  # Incremented by each node function
    max_iterations: int  # Hard cap on agent loop iterations
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# =============================================================================
|
|
93
|
+
# Navigator Tools
|
|
94
|
+
# =============================================================================
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def create_navigator_tools(
    skeleton: dict[str, SkeletonNode],
    kv_store: KVStore,
    variable_store: VariableStore,
) -> dict[str, Any]:
    """
    Create tools for the Navigator Agent.

    Each tool is a closure over the three stores, so the returned callables
    can be handed to the agent without exposing the stores themselves.
    All tools return human-readable strings (including error messages)
    rather than raising, so the agent loop never crashes on a bad tool call.

    Args:
        skeleton: Skeleton index (node_id -> SkeletonNode).
        kv_store: KV store with full content.
        variable_store: Variable store for findings.

    Returns:
        Dictionary of tool functions, keyed by tool name.
    """

    def get_node_summary(node_id: str) -> str:
        """Get the summary of a skeleton node, including a listing of its children."""
        node = skeleton.get(node_id)
        if node is None:
            return f"Node {node_id} not found"

        children_info = ""
        if node.child_ids:
            # Missing children (dangling ids) are silently skipped by the `if c` filter.
            children = [skeleton.get(cid) for cid in node.child_ids]
            children_info = "\nChildren:\n" + "\n".join(
                f"  - {c.node_id}: {c.header}" for c in children if c
            )

        return f"""
Node: {node.node_id}
Header: {node.header}
Level: {node.level}
Summary: {node.summary}
{children_info}
"""

    def navigate_to_child(node_id: str, child_id: str) -> str:
        """Navigate to a child node, validating the parent/child relationship first."""
        node = skeleton.get(node_id)
        if node is None:
            return f"Node {node_id} not found"

        if child_id not in node.child_ids:
            return f"{child_id} is not a child of {node_id}"

        child = skeleton.get(child_id)
        if child is None:
            return f"Child {child_id} not found"

        return get_node_summary(child_id)

    def expand_node(node_id: str) -> str:
        """
        EXPAND: Fetch full content and store as variable.

        Use when the summary answers the question.
        The pointer name is derived from the node header (Variable Stitching).
        """
        node = skeleton.get(node_id)
        if node is None:
            return f"Node {node_id} not found"

        # Fetch full content from KV store
        content = kv_store.get(node_id)
        if content is None:
            return f"No content found for {node_id}"

        # Generate pointer name
        pointer = generate_pointer_name(node.header)

        # Store as variable
        variable_store.assign(pointer, content, node_id)

        logger.info(
            "node_expanded",
            node_id=node_id,
            pointer=pointer,
            chars=len(content),
        )

        return f"Stored content as {pointer} ({len(content)} chars)"

    def store_finding(
        pointer_name: str,
        content: str,
        source_node_id: str = "",
    ) -> str:
        """
        Store a finding as a pointer variable.

        Args:
            pointer_name: Name like $LIABILITY_CLAUSE (must start with $).
                A missing "$" prefix is added automatically (name upper-cased).
            content: Full text content to store
            source_node_id: Source node for traceability
        """
        if not pointer_name.startswith("$"):
            pointer_name = "$" + pointer_name.upper()

        try:
            meta = variable_store.assign(pointer_name, content, source_node_id)
            return f"Stored as {pointer_name} ({meta.char_count} chars)"
        except Exception as e:
            # Tool contract: report errors as strings, never raise into the agent loop.
            return f"Error storing: {e}"

    def compare_variables(*pointers: str) -> str:
        """
        Compare multiple stored variables.

        Resolves pointers and returns content for comparison.
        Use during synthesis phase only.
        Unresolvable pointers are rendered as "[Not found]" rather than erroring.
        """
        results = []
        for pointer in pointers:
            content = variable_store.resolve(pointer)
            if content:
                results.append(f"=== {pointer} ===\n{content}")
            else:
                results.append(f"=== {pointer} ===\n[Not found]")

        return "\n\n".join(results)

    def list_stored_variables() -> str:
        """List all stored variables with metadata (pointer, size, source node)."""
        variables = variable_store.list_variables()
        if not variables:
            return "No variables stored yet."

        lines = ["Stored Variables:"]
        for v in variables:
            lines.append(f"  {v.pointer}: {v.char_count} chars from {v.source_node_id}")

        return "\n".join(lines)

    def synthesize_from_variables() -> str:
        """
        Get all stored content for final synthesis.

        Call this at the end to get full content for answer generation.
        Unlike compare_variables, silently skips pointers that no longer resolve.
        """
        variables = variable_store.list_variables()
        if not variables:
            return "No variables to synthesize from."

        parts = []
        for v in variables:
            content = variable_store.resolve(v.pointer)
            if content:
                parts.append(f"=== {v.pointer} ===\n{content}")

        return "\n\n".join(parts)

    return {
        "get_node_summary": get_node_summary,
        "navigate_to_child": navigate_to_child,
        "expand_node": expand_node,
        "store_finding": store_finding,
        "compare_variables": compare_variables,
        "list_stored_variables": list_stored_variables,
        "synthesize_from_variables": synthesize_from_variables,
    }
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# =============================================================================
|
|
261
|
+
# State Management Functions
|
|
262
|
+
# =============================================================================
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def create_initial_state(
    question: str,
    root_node_id: str,
    max_iterations: int = 20,
    top_k: int = 3,
    metadata: dict[str, Any] | None = None,
    tot_selection_threshold: float = 0.4,
    tot_dead_end_threshold: float = 0.1,
) -> AgentState:
    """
    Build a fresh agent state positioned at the document root.

    All collection fields start empty except navigation_path, which is
    seeded with the root so the trace always shows where exploration began.

    Args:
        question: The user's question to answer.
        root_node_id: Node id of the document tree root.
        max_iterations: Hard cap on agent loop iterations.
        top_k: Number of top ToT candidates to explore per step.
        metadata: Optional question metadata (e.g. multiple-choice options).
        tot_selection_threshold: Minimum relevance probability for selection.
        tot_dead_end_threshold: Probability below which a branch is a dead end.

    Returns:
        A fully initialized AgentState.
    """
    state: AgentState = {
        # Query processing
        "question": question,
        "sub_questions": [],
        "pending_questions": [],
        "current_sub_question": None,
        # Navigation
        "current_node_id": root_node_id,
        "visited_nodes": [],
        "navigation_path": [root_node_id],
        # Tree of Thoughts bookkeeping
        "nodes_to_visit": [],
        "scored_candidates": [],
        "backtrack_stack": [],
        "dead_ends": [],
        "top_k": top_k,
        "tot_selection_threshold": tot_selection_threshold,
        "tot_dead_end_threshold": tot_dead_end_threshold,
        # Variable stitching
        "variables": [],
        "context": "",
        # Output
        "answer": None,
        "confidence": 0.0,
        "metadata": metadata or {},
        # Traceability
        "trace": [],
        "iteration": 0,
        "max_iterations": max_iterations,
    }
    return state
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def add_trace_entry(
    state: AgentState,
    node_type: Literal["decomposition", "navigation", "variable_stitching", "synthesis"],
    action: str,
    details: dict | None = None,
) -> None:
    """
    Append a timestamped trace record to state["trace"].

    The record is a serialized TraceEntry (via model_dump) so the trace
    stays a plain list of dicts and is trivially JSON-serializable.
    Mutates `state` in place; returns nothing.
    """
    record = TraceEntry(
        timestamp=datetime.now(timezone.utc).isoformat(),
        node_type=node_type,
        action=action,
        details=details or {},
    )
    state["trace"].append(record.model_dump())
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# =============================================================================
|
|
320
|
+
# Tree of Thoughts (ToT) Prompting - Section 7.2
|
|
321
|
+
# =============================================================================
|
|
322
|
+
|
|
323
|
+
# ToT System Prompt as specified in the research paper (Section 7.2).
# Format placeholders filled by evaluate_children_with_tot():
#   {current_node_summary}, {children_summaries}, {query},
#   {top_k}, {selection_threshold}, {dead_end_threshold}
# Double braces {{ }} are literal braces in the JSON examples.
TOT_SYSTEM_PROMPT = """You are a Deep Research Agent navigating a document tree.

You are currently at Node: {current_node_summary}
Children Nodes: {children_summaries}
Your Goal: {query}

EVALUATION TASK:
For each child node, estimate the probability (0.0 to 1.0) that it contains relevant information for the goal.

INSTRUCTIONS:
1. Evaluate: For each child node, analyze its summary and estimate relevance probability.
2. Be OPEN-MINDED: Select nodes with probability > {selection_threshold} (moderate evidence of relevance).
3. Look for matches: Prefer nodes with facts/entities mentioned in the query, but also consider broad thematic matches.
4. Balance PRECISION and RECALL: Do not prune branches too early. If unsure, include the node.
5. Plan: Select the top-{top_k} most promising nodes.
6. Reasoning: Explain briefly what SPECIFIC content in the summary makes it relevant.
7. Backtrack Signal: If NO child seems relevant (all probabilities < {dead_end_threshold}), report "DEAD_END".

OUTPUT FORMAT (JSON):
{{
  "evaluations": [
    {{"node_id": "...", "probability": 0.85, "reasoning": "Summary mentions X which directly relates to query about Y"}},
    {{"node_id": "...", "probability": 0.60, "reasoning": "Contains information about Z"}},
    ...
  ],
  "selected_nodes": ["node_id_1", "node_id_2", ...],
  "is_dead_end": false,
  "backtrack_reason": null
}}

If this is a dead end:
{{
  "evaluations": [...],
  "selected_nodes": [],
  "is_dead_end": true,
  "backtrack_reason": "None of the children appear to contain information about X."
}}

Respond ONLY with the JSON, no other text."""
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _format_node_summary(node: SkeletonNode) -> str:
|
|
366
|
+
"""Format a node's summary for the ToT prompt."""
|
|
367
|
+
return f"[{node.node_id}] {node.header}: {node.summary or '(no summary)'}"
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _format_children_summaries(
    skeleton: dict[str, SkeletonNode],
    child_ids: list[str],
) -> str:
    """
    Render every resolvable child as an indented bullet for the ToT prompt.

    Child ids missing from the skeleton are skipped; distinct placeholder
    strings distinguish "leaf node" from "all children unresolvable".
    """
    if not child_ids:
        return "(no children - this is a leaf node)"

    bullets = [
        f"  - {_format_node_summary(child)}"
        for cid in child_ids
        if (child := skeleton.get(cid))
    ]
    return "\n".join(bullets) if bullets else "(no children found)"
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def evaluate_children_with_tot(
    state: AgentState,
    skeleton: dict[str, SkeletonNode],
    top_k_override: int | None = None,
) -> dict[str, Any]:
    """
    Use Tree of Thoughts prompting to evaluate child nodes.

    This implements Section 7.2 of the research paper:
    "We use Tree of Thoughts (ToT) which explicitly encourages the model
    to explore multiple branches of the index before committing to a path."

    On any LLM or parsing failure, degrades to a heuristic that selects
    the first unvisited, non-dead-end children.

    Args:
        state: Current agent state.
        skeleton: Skeleton index.
        top_k_override: Optional override for adaptive exploration.

    Returns:
        Dictionary with evaluations, selected_nodes, is_dead_end, and backtrack_reason.
        When the current node is a leaf, also contains "is_leaf": True.
    """
    node_id = state.get("current_node_id")
    if node_id is None:
        return {"evaluations": [], "selected_nodes": [], "is_dead_end": True, "backtrack_reason": "No current node"}

    node = skeleton.get(node_id)
    if node is None:
        return {"evaluations": [], "selected_nodes": [], "is_dead_end": True, "backtrack_reason": "Node not found"}

    # If no children, this is a leaf - expand instead of traverse
    if not node.child_ids:
        return {"evaluations": [], "selected_nodes": [], "is_dead_end": False, "backtrack_reason": None, "is_leaf": True}

    # Format the ToT prompt
    current_summary = _format_node_summary(node)
    children_summaries = _format_children_summaries(skeleton, node.child_ids)
    query = state.get("current_sub_question") or state.get("question", "")
    top_k = top_k_override if top_k_override is not None else state.get("top_k", 3)
    selection_threshold = state.get("tot_selection_threshold", 0.4)
    dead_end_threshold = state.get("tot_dead_end_threshold", 0.1)

    prompt = TOT_SYSTEM_PROMPT.format(
        current_node_summary=current_summary,
        children_summaries=children_summaries,
        query=query,
        top_k=top_k,
        selection_threshold=selection_threshold,
        dead_end_threshold=dead_end_threshold,
    )

    def _extract_json(text: str) -> str:
        # Strip markdown fences or surrounding prose by taking the outermost
        # {...} span. Returns the text unchanged when no braces are found so
        # json.loads produces a meaningful JSONDecodeError.
        match = re.search(r'\{[\s\S]*\}', text)
        return match.group(0) if match else text

    # Call the LLM
    try:
        from rnsr.llm import get_llm
        import json

        llm = get_llm()
        response = llm.complete(prompt)
        response_text = str(response).strip()

        # Parse JSON response with robust error handling. Extraction is
        # applied unconditionally (not only for ``` fences) so prose-wrapped
        # JSON parses directly without a repair round-trip.
        try:
            result = json.loads(_extract_json(response_text))
        except json.JSONDecodeError:
            logger.warning("tot_json_repair_attempt", original_response=response_text)
            # Fallback to asking the LLM to fix the JSON (reuse the llm
            # instance already in scope).
            repair_prompt = f"""The following text is a malformed JSON object. Please fix it and return ONLY the corrected JSON. Do not add any commentary.

Malformed JSON:
{response_text}"""
            repaired_response_text = str(llm.complete(repair_prompt)).strip()

            # Final attempt to parse the repaired JSON; a failure here falls
            # through to the heuristic fallback below.
            result = json.loads(_extract_json(repaired_response_text))

        logger.debug(
            "tot_evaluation_complete",
            node_id=node_id,
            evaluations=len(result.get("evaluations", [])),
            selected=result.get("selected_nodes", []),
            is_dead_end=result.get("is_dead_end", False),
        )

        return result

    except Exception as e:
        logger.warning("tot_evaluation_failed", error=str(e), node_id=node_id)

        # Fallback: use simple heuristic (select first unvisited children)
        visited = state.get("visited_nodes", [])
        dead_ends = state.get("dead_ends", [])

        unvisited = [
            cid for cid in node.child_ids
            if cid not in visited and cid not in dead_ends
        ]

        return {
            "evaluations": [{"node_id": cid, "probability": 0.5, "reasoning": "Fallback selection"} for cid in unvisited[:top_k]],
            "selected_nodes": unvisited[:top_k],
            "is_dead_end": len(unvisited) == 0,
            "backtrack_reason": "No unvisited children" if len(unvisited) == 0 else None,
        }
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
def backtrack_to_parent(
    state: AgentState,
    skeleton: dict[str, SkeletonNode],
) -> AgentState:
    """
    Backtrack to parent node when current path is a dead end.

    Implements the backtracking logic from Section 7.2:
    "Backtrack: If a node yields no useful info, report 'Dead End' and return to parent."

    Note: `skeleton` is unused here but kept for signature parity with the
    other navigation node functions. The returned state is a shallow copy;
    its list fields are shared with (and mutated alongside) the input state.
    """
    new_state = cast(AgentState, dict(state))
    dead_node = new_state["current_node_id"]

    # Record the exhausted node so it is never re-selected.
    if dead_node and dead_node not in new_state["dead_ends"]:
        new_state["dead_ends"].append(dead_node)

    stack = new_state["backtrack_stack"]
    if not stack:
        # Nowhere to go: at the root, or the stack was never populated.
        add_trace_entry(
            new_state,
            "navigation",
            "Cannot backtrack - at root or stack empty",
            {"current": dead_node},
        )
        logger.debug("backtrack_failed", reason="empty_stack")
    else:
        parent_id = stack.pop()
        new_state["current_node_id"] = parent_id
        new_state["navigation_path"].append(parent_id)

        add_trace_entry(
            new_state,
            "navigation",
            f"Backtracked to {parent_id} (dead end at {dead_node})",
            {"from": dead_node, "to": parent_id, "reason": "dead_end"},
        )
        logger.debug("backtrack_success", from_node=dead_node, to_node=parent_id)

    new_state["iteration"] += 1
    return new_state
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
# =============================================================================
|
|
551
|
+
# LangGraph Node Functions
|
|
552
|
+
# =============================================================================
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
# Decomposition prompt for RLM recursive sub-task generation (Section 2.2).
# Format placeholders filled by decompose_query(): {query}, {structure}.
# Double braces {{ }} are literal braces in the JSON example.
DECOMPOSITION_PROMPT = """You are analyzing a complex query to decompose it into sub-tasks.

Query: {query}

Document Structure (top-level sections):
{structure}

TASK: Decompose this query into specific sub-tasks that can be executed independently.

RULES:
1. Each sub-task should target a specific section or piece of information
2. Sub-tasks should be answerable by reading specific document sections
3. For comparison queries, create one sub-task per item being compared
4. For multi-hop queries, create sequential sub-tasks (find A, then use A to find B)
5. Maximum 5 sub-tasks to maintain efficiency

OUTPUT FORMAT (JSON):
{{
  "sub_tasks": [
    {{"id": 1, "task": "Find X in section Y", "target_section": "section_hint"}},
    {{"id": 2, "task": "Extract Z from W", "target_section": "section_hint"}}
  ],
  "synthesis_plan": "How to combine sub-task results into final answer"
}}

Return ONLY valid JSON."""
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def decompose_query(state: AgentState) -> AgentState:
    """
    Decompose the main question into sub-questions using LLM.

    Implements Section 2.2 "The Recursive Loop":
    "The model's capability to divide a complex reasoning task into
    smaller, manageable sub-tasks and invoke instances of itself to solve them."

    Control flow: the LLM path returns early only when it yields at least
    one sub-task; otherwise (no JSON, empty sub_tasks, or any exception)
    execution falls through to the pattern-based _simple_decompose fallback.
    Returns a shallow copy of the state with sub_questions/pending_questions
    populated and iteration incremented.
    """
    new_state = cast(AgentState, dict(state))  # Copy (shallow - lists shared)
    question = new_state["question"]

    # Try LLM-based decomposition
    try:
        from rnsr.llm import get_llm
        import json

        llm = get_llm()

        # Get document structure for context
        # Note: skeleton is not available here, so we'll do basic decomposition
        structure_hint = "Document sections available for navigation"

        prompt = DECOMPOSITION_PROMPT.format(
            query=question,
            structure=structure_hint,
        )

        response = llm.complete(prompt)
        response_text = str(response).strip()

        # Extract JSON from response (outermost {...} span tolerates
        # surrounding prose or markdown fences)
        json_match = re.search(r'\{[\s\S]*\}', response_text)
        if json_match:
            decomposition = json.loads(json_match.group())
            sub_tasks = decomposition.get("sub_tasks", [])

            if sub_tasks:
                sub_questions = [t["task"] for t in sub_tasks]
                new_state["sub_questions"] = sub_questions
                new_state["pending_questions"] = sub_questions.copy()
                new_state["current_sub_question"] = sub_questions[0]

                add_trace_entry(
                    new_state,
                    "decomposition",
                    f"LLM decomposed into {len(sub_questions)} sub-tasks",
                    {
                        "sub_questions": sub_questions,
                        "synthesis_plan": decomposition.get("synthesis_plan", ""),
                    },
                )

                new_state["iteration"] = new_state["iteration"] + 1
                return new_state

    except Exception as e:
        # Any failure (LLM call, JSON parse, missing "task" key) degrades
        # to the pattern-based fallback below rather than aborting.
        logger.warning("llm_decomposition_failed", error=str(e))

    # Fallback: simple decomposition patterns
    sub_questions = _simple_decompose(question)
    new_state["sub_questions"] = sub_questions
    new_state["pending_questions"] = sub_questions.copy()
    new_state["current_sub_question"] = sub_questions[0] if sub_questions else question

    add_trace_entry(
        new_state,
        "decomposition",
        f"Decomposed into {len(sub_questions)} sub-questions (fallback)",
        {"sub_questions": sub_questions},
    )

    new_state["iteration"] = new_state["iteration"] + 1
    return new_state
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def _simple_decompose(question: str) -> list[str]:
|
|
660
|
+
"""
|
|
661
|
+
Simple pattern-based decomposition fallback.
|
|
662
|
+
|
|
663
|
+
Handles common query patterns without LLM.
|
|
664
|
+
"""
|
|
665
|
+
question_lower = question.lower()
|
|
666
|
+
|
|
667
|
+
# Pattern: "Compare X and Y" or "X vs Y"
|
|
668
|
+
compare_patterns = [
|
|
669
|
+
r"compare\s+(.+?)\s+(?:and|with|to|vs\.?)\s+(.+)",
|
|
670
|
+
r"difference(?:s)?\s+between\s+(.+?)\s+and\s+(.+)",
|
|
671
|
+
r"(.+?)\s+vs\.?\s+(.+)",
|
|
672
|
+
]
|
|
673
|
+
|
|
674
|
+
for pattern in compare_patterns:
|
|
675
|
+
match = re.search(pattern, question_lower)
|
|
676
|
+
if match:
|
|
677
|
+
item1, item2 = match.groups()
|
|
678
|
+
return [
|
|
679
|
+
f"Find information about {item1.strip()}",
|
|
680
|
+
f"Find information about {item2.strip()}",
|
|
681
|
+
]
|
|
682
|
+
|
|
683
|
+
# Pattern: "What are the X in sections A, B, and C"
|
|
684
|
+
multi_section = re.search(
|
|
685
|
+
r"in\s+(?:sections?\s+)?(.+?,\s*.+?(?:,\s*.+)*)",
|
|
686
|
+
question_lower,
|
|
687
|
+
)
|
|
688
|
+
if multi_section:
|
|
689
|
+
sections = [s.strip() for s in multi_section.group(1).split(",")]
|
|
690
|
+
return [f"Find relevant information in {s}" for s in sections[:5]]
|
|
691
|
+
|
|
692
|
+
# Pattern: "List all X" or "Find all X" - may need iteration
|
|
693
|
+
if re.search(r"(list|find|show)\s+all", question_lower):
|
|
694
|
+
return [
|
|
695
|
+
f"Search for all relevant sections",
|
|
696
|
+
question, # Original as synthesis query
|
|
697
|
+
]
|
|
698
|
+
|
|
699
|
+
# Default: single question
|
|
700
|
+
return [question]
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def navigate_tree(
    state: AgentState,
    skeleton: dict[str, SkeletonNode],
) -> AgentState:
    """
    Advance navigation by one step for the current sub-question.

    Pops the next node id off the visitation queue when no node is
    currently selected, records a trace entry describing where the agent
    is, and bumps the iteration counter. Visited-node bookkeeping is
    deliberately left to the expand/traverse steps.

    Args:
        state: Current agent state (not mutated; a shallow copy is returned).
        skeleton: Skeleton index mapping node ids to SkeletonNode objects.

    Returns:
        An updated copy of the agent state.
    """
    updated = cast(AgentState, dict(state))

    # Detailed logging to help diagnose navigation loops.
    logger.debug(
        "navigate_step_start",
        iteration=updated["iteration"],
        current_node=updated["current_node_id"],
        queue=updated["nodes_to_visit"],
        visited=updated["visited_nodes"],
    )

    queue = updated["nodes_to_visit"]
    if updated["current_node_id"] is None and queue:
        # Pull the next queued node and record it on the navigation path.
        popped = queue.pop(0)
        updated["current_node_id"] = popped

        path = updated["navigation_path"]
        if popped not in path:
            path.append(popped)

        add_trace_entry(
            updated,
            "navigation",
            f"Visiting next node from queue: {popped}",
            {"queue_size": len(queue)},
        )

    node_id = updated["current_node_id"]
    if node_id is None:
        # Queue exhausted and nothing selected: nothing to do this step.
        add_trace_entry(updated, "navigation", "No current node to navigate to and queue is empty")
        return updated

    node = skeleton.get(node_id)
    if node is None:
        add_trace_entry(updated, "navigation", f"Node {node_id} not found")
        return updated

    # NOTE: the node is intentionally NOT marked visited here, so that
    # newly queued nodes can still be expanded later.
    add_trace_entry(
        updated,
        "navigation",
        f"At node {node_id}: {node.header}",
        {"summary": node.summary, "children": node.child_ids},
    )

    updated["iteration"] = updated["iteration"] + 1
    return updated
|
763
|
+
|
|
764
|
+
def should_expand(
    state: AgentState,
    skeleton: dict[str, SkeletonNode],
) -> Literal["expand", "traverse", "backtrack", "done"]:
    """
    Decide whether to expand (fetch content), traverse (go to children),
    backtrack (dead end), or finish.

    Uses Tree of Thoughts evaluation to make intelligent decisions once
    the cheap structural checks (missing node, iteration cap, exhausted
    children) have been ruled out.

    Args:
        state: Current agent state; reads the current node id, iteration
            counters, visited/dead-end bookkeeping, and ``top_k``.
        skeleton: Skeleton index mapping node ids to SkeletonNode objects.

    Returns:
        One of ``"expand"``, ``"traverse"``, ``"backtrack"``, or ``"done"``.
    """
    node_id = state.get("current_node_id")
    if node_id is None:
        return "done"

    node = skeleton.get(node_id)
    if node is None:
        return "done"

    # Check iteration limit.
    if state.get("iteration", 0) >= state.get("max_iterations", 20):
        return "done"

    # Bookkeeping, computed once (the original re-read "visited_nodes"
    # a second time with an identical result).
    visited = state.get("visited_nodes", [])
    dead_ends = state.get("dead_ends", [])

    # If no children, must expand (leaf node) - but only if not already visited.
    if not node.child_ids:
        if node_id in visited:
            # Already expanded this leaf, done with it.
            return "done"
        return "expand"

    # Check if all children are visited or dead ends.
    unvisited_children = [
        cid for cid in node.child_ids
        if cid not in visited and cid not in dead_ends
    ]

    if not unvisited_children:
        # All children explored - backtrack if possible, otherwise finish.
        if state.get("backtrack_stack"):
            return "backtrack"
        return "done"

    # Adaptive exploration: widen top_k when progress is slow.
    base_top_k = state.get("top_k", 3)
    variables_found = len(state.get("variables", []))
    iteration = state.get("iteration", 0)

    if iteration > 5 and variables_found == 0:
        adaptive_top_k = min(len(node.child_ids), base_top_k * 3)
    elif iteration > 3 and variables_found < 2:
        adaptive_top_k = min(len(node.child_ids), base_top_k * 2)
    else:
        adaptive_top_k = base_top_k

    # Use ToT evaluation to decide among the remaining options.
    tot_result = evaluate_children_with_tot(state, skeleton, top_k_override=adaptive_top_k)

    # Dead-end signal from ToT: back out if we can.
    if tot_result.get("is_dead_end", False):
        if state.get("backtrack_stack"):
            return "backtrack"
        return "done"

    # ToT judged this node a leaf: fetch its content.
    if tot_result.get("is_leaf", False):
        return "expand"

    # ToT selected promising children: descend.
    if tot_result.get("selected_nodes"):
        return "traverse"

    # Fallback: expand if at deep level (H3+).
    if node.level >= 3:
        return "expand"

    # Otherwise traverse to children.
    return "traverse"
|
848
|
+
def expand_current_node(
|
|
849
|
+
state: AgentState,
|
|
850
|
+
skeleton: dict[str, SkeletonNode],
|
|
851
|
+
kv_store: KVStore,
|
|
852
|
+
variable_store: VariableStore,
|
|
853
|
+
) -> AgentState:
|
|
854
|
+
"""
|
|
855
|
+
EXPAND: Fetch full content and store as variable.
|
|
856
|
+
"""
|
|
857
|
+
new_state = cast(AgentState, dict(state))
|
|
858
|
+
|
|
859
|
+
node_id = new_state["current_node_id"]
|
|
860
|
+
if node_id is None:
|
|
861
|
+
return new_state
|
|
862
|
+
|
|
863
|
+
# Check if already visited to prevent infinite loops
|
|
864
|
+
if node_id in new_state.get("visited_nodes", []):
|
|
865
|
+
add_trace_entry(
|
|
866
|
+
new_state,
|
|
867
|
+
"navigation",
|
|
868
|
+
f"Skipping already-visited node {node_id}",
|
|
869
|
+
{"node_id": node_id},
|
|
870
|
+
)
|
|
871
|
+
new_state["current_node_id"] = None
|
|
872
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
873
|
+
return new_state
|
|
874
|
+
|
|
875
|
+
node = skeleton.get(node_id)
|
|
876
|
+
|
|
877
|
+
if node is None:
|
|
878
|
+
new_state["current_node_id"] = None
|
|
879
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
880
|
+
return new_state
|
|
881
|
+
|
|
882
|
+
# Fetch full content
|
|
883
|
+
content = kv_store.get(node_id)
|
|
884
|
+
if content is None:
|
|
885
|
+
add_trace_entry(new_state, "variable_stitching", f"No content for {node_id}")
|
|
886
|
+
new_state["current_node_id"] = None
|
|
887
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
888
|
+
return new_state
|
|
889
|
+
|
|
890
|
+
# Generate and store as variable
|
|
891
|
+
pointer = generate_pointer_name(node.header)
|
|
892
|
+
variable_store.assign(pointer, content, node_id)
|
|
893
|
+
|
|
894
|
+
new_state["variables"].append(pointer)
|
|
895
|
+
new_state["context"] += f"\nFound: {pointer} (from {node.header})"
|
|
896
|
+
|
|
897
|
+
# Mark node as visited and clear current node
|
|
898
|
+
if node_id not in new_state["visited_nodes"]:
|
|
899
|
+
new_state["visited_nodes"].append(node_id)
|
|
900
|
+
new_state["current_node_id"] = None # Process queue next
|
|
901
|
+
|
|
902
|
+
add_trace_entry(
|
|
903
|
+
new_state,
|
|
904
|
+
"variable_stitching",
|
|
905
|
+
f"Stored {pointer}",
|
|
906
|
+
{"node_id": node_id, "header": node.header, "chars": len(content)},
|
|
907
|
+
)
|
|
908
|
+
|
|
909
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
910
|
+
return new_state
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
def traverse_to_children(
|
|
914
|
+
state: AgentState,
|
|
915
|
+
skeleton: dict[str, SkeletonNode],
|
|
916
|
+
semantic_searcher: SemanticSearcher | None = None,
|
|
917
|
+
) -> AgentState:
|
|
918
|
+
"""
|
|
919
|
+
TRAVERSE: Navigate to child nodes using Tree of Thoughts reasoning.
|
|
920
|
+
|
|
921
|
+
Primary Strategy: Tree of Thoughts (ToT) with adaptive exploration.
|
|
922
|
+
- Uses LLM reasoning to evaluate child nodes based on summaries
|
|
923
|
+
- Dynamically increases exploration when insufficient variables found
|
|
924
|
+
- Preserves document structure and context (per research paper Section 7.2)
|
|
925
|
+
|
|
926
|
+
Optional: Semantic search as shortcut for finding entry points (Section 9.1)
|
|
927
|
+
"""
|
|
928
|
+
new_state = cast(AgentState, dict(state))
|
|
929
|
+
|
|
930
|
+
node_id = new_state["current_node_id"]
|
|
931
|
+
if node_id is None:
|
|
932
|
+
return new_state
|
|
933
|
+
|
|
934
|
+
node = skeleton.get(node_id)
|
|
935
|
+
|
|
936
|
+
if node is None or not node.child_ids:
|
|
937
|
+
return new_state
|
|
938
|
+
|
|
939
|
+
question = new_state["question"]
|
|
940
|
+
base_top_k = new_state.get("top_k", 3)
|
|
941
|
+
|
|
942
|
+
# Adaptive exploration: increase top_k if we haven't found much information yet
|
|
943
|
+
# This addresses the original issue: "agent should be able to explore all nodes if it needs to"
|
|
944
|
+
variables_found = len(new_state.get("variables", []))
|
|
945
|
+
iteration = new_state.get("iteration", 0)
|
|
946
|
+
|
|
947
|
+
# Dynamic top_k based on progress
|
|
948
|
+
if iteration > 5 and variables_found == 0:
|
|
949
|
+
# Not finding anything - expand search significantly
|
|
950
|
+
adaptive_top_k = min(len(node.child_ids), base_top_k * 3)
|
|
951
|
+
add_trace_entry(
|
|
952
|
+
new_state,
|
|
953
|
+
"navigation",
|
|
954
|
+
f"Expanding search: {variables_found} variables after {iteration} iterations",
|
|
955
|
+
{"base_top_k": base_top_k, "adaptive_top_k": adaptive_top_k},
|
|
956
|
+
)
|
|
957
|
+
elif iteration > 3 and variables_found < 2:
|
|
958
|
+
# Finding very little - expand moderately
|
|
959
|
+
adaptive_top_k = min(len(node.child_ids), base_top_k * 2)
|
|
960
|
+
else:
|
|
961
|
+
# Normal exploration
|
|
962
|
+
adaptive_top_k = base_top_k
|
|
963
|
+
|
|
964
|
+
# Use Tree of Thoughts as primary navigation method (per research paper)
|
|
965
|
+
tot_result = evaluate_children_with_tot(new_state, skeleton, top_k_override=adaptive_top_k)
|
|
966
|
+
|
|
967
|
+
# Check for dead end
|
|
968
|
+
if tot_result.get("is_dead_end", False):
|
|
969
|
+
add_trace_entry(
|
|
970
|
+
new_state,
|
|
971
|
+
"navigation",
|
|
972
|
+
f"Dead end detected at {node_id}: {tot_result.get('backtrack_reason', 'Unknown')}",
|
|
973
|
+
{"node_id": node_id, "backtrack_reason": tot_result.get("backtrack_reason")},
|
|
974
|
+
)
|
|
975
|
+
# Mark as dead end and don't change current node
|
|
976
|
+
if node_id not in new_state["dead_ends"]:
|
|
977
|
+
new_state["dead_ends"].append(node_id)
|
|
978
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
979
|
+
return new_state
|
|
980
|
+
|
|
981
|
+
selected_nodes = tot_result.get("selected_nodes", [])
|
|
982
|
+
evaluations = tot_result.get("evaluations", [])
|
|
983
|
+
new_state["scored_candidates"] = evaluations
|
|
984
|
+
|
|
985
|
+
# Queue selected children for visitation
|
|
986
|
+
if selected_nodes:
|
|
987
|
+
new_state["nodes_to_visit"].extend(selected_nodes)
|
|
988
|
+
# Mark parent as visited since we've traversed its children
|
|
989
|
+
if node_id not in new_state["visited_nodes"]:
|
|
990
|
+
new_state["visited_nodes"].append(node_id)
|
|
991
|
+
# Unset current node to force the next navigate step to pull from the queue
|
|
992
|
+
new_state["current_node_id"] = None
|
|
993
|
+
add_trace_entry(
|
|
994
|
+
new_state,
|
|
995
|
+
"navigation",
|
|
996
|
+
f"Queued {len(selected_nodes)} children from {node_id}",
|
|
997
|
+
{"nodes": selected_nodes, "parent": node_id},
|
|
998
|
+
)
|
|
999
|
+
else:
|
|
1000
|
+
# No candidates available - this should trigger backtracking
|
|
1001
|
+
add_trace_entry(
|
|
1002
|
+
new_state,
|
|
1003
|
+
"navigation",
|
|
1004
|
+
f"No unvisited candidates at {node_id}",
|
|
1005
|
+
{"node_id": node_id},
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
1009
|
+
return new_state
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
def synthesize_answer(
    state: AgentState,
    variable_store: VariableStore,
) -> AgentState:
    """
    Synthesize final answer from stored variables using an LLM.

    Uses the configured LLM to generate a concise answer
    from the resolved variable content.

    Reads from the state: "variables" (pointer names), "question", and
    optional "metadata" whose "options" key (a non-empty list) switches
    the prompt into multiple-choice mode. Writes "answer" and
    "confidence" back into a copy of the state and appends a synthesis
    trace entry.

    If the LLM call fails for any reason, falls back to returning the
    concatenated context with a lower confidence score.
    """
    new_state = cast(AgentState, dict(state))

    # Resolve all variables
    pointers = new_state["variables"]

    if not pointers:
        # Nothing was collected during navigation; short-circuit.
        new_state["answer"] = "No relevant content found."
        new_state["confidence"] = 0.0
    else:
        # Collect all content; pointers that resolve to nothing are skipped.
        contents = []
        for pointer in pointers:
            content = variable_store.resolve(pointer)
            if content:
                contents.append(content)

        context_text = "\n\n---\n\n".join(contents)
        question = new_state["question"]

        # Use LLM to synthesize answer
        try:
            from rnsr.llm import get_llm

            llm = get_llm()

            metadata = new_state.get("metadata", {})
            options = metadata.get("options")

            if options and isinstance(options, list) and len(options) > 0:
                # QuALITY multiple-choice question
                # Ensure each option is a string and format with letters
                options_text = "\n".join([f"{chr(65+i)}. {str(opt)}" for i, opt in enumerate(options)])
                prompt = f"""Based on the provided context, answer this multiple-choice question.

Question: {question}

Options:
{options_text}

Context:
{context_text}

Instructions:
1. Read the context carefully
2. Determine which option is best supported by the evidence in the context
3. Respond with ONLY the letter and full text of the correct option
4. Format: "X. [complete option text]" where X is A, B, C, or D

Your answer:"""
            else:
                # Standard open-ended question
                prompt = f"""Based on the following context, answer the question concisely.

IMPORTANT INSTRUCTIONS:
- If the context contains information that DIRECTLY answers the question, provide that answer
- If the context contains information that allows you to INFER the answer, provide your best inference
- Use evidence from the context to support your answer
- Only say "Cannot determine from available context" if the context is completely unrelated or missing critical information
- It's better to provide a reasonable answer based on available evidence than to say "cannot determine"

Question: {question}

Context:
{context_text}

Answer (be concise, direct, and confident):"""

            response = llm.complete(prompt)
            answer = str(response).strip()

            # Normalize multiple-choice answers: map the raw LLM reply back
            # onto the exact option text so downstream scoring can compare.
            if options:
                answer_lower = answer.lower().strip()
                matched = False

                for i, opt in enumerate(options):
                    letter = chr(65 + i)  # A, B, C, D
                    # NOTE(review): assumes each option is a str here (the
                    # prompt builder above coerced with str(), but this loop
                    # calls opt.lower() directly) — confirm option types.
                    opt_lower = opt.lower()

                    # Match patterns: "A.", "A)", "(A)", "A. answer text"
                    if (answer_lower.startswith(f"{letter.lower()}.") or
                        answer_lower.startswith(f"{letter.lower()})") or
                        answer_lower.startswith(f"({letter.lower()})")):
                        answer = opt
                        matched = True
                        break

                    # Exact match (case insensitive)
                    if answer_lower == opt_lower:
                        answer = opt
                        matched = True
                        break

                    # Check if option text is contained in answer
                    # (first partial match wins; options are checked in order)
                    if opt_lower in answer_lower or answer_lower in opt_lower:
                        # Use similarity to pick best match if multiple partial matches
                        answer = opt
                        matched = True
                        break

                # Heuristic confidence: high when the reply maps to an option.
                new_state["confidence"] = 0.8 if matched else 0.1
            else:
                # Heuristic: confidence grows with the number of variables found.
                new_state["confidence"] = min(1.0, len(pointers) * 0.3)

            new_state["answer"] = answer

        except Exception as e:
            logger.warning("llm_synthesis_failed", error=str(e))
            # Fallback to context concatenation
            new_state["answer"] = f"Context found:\n{context_text}"
            new_state["confidence"] = min(1.0, len(pointers) * 0.2)

    add_trace_entry(
        new_state,
        "synthesis",
        "Generated answer from variables",
        {"variables_used": pointers, "confidence": new_state["confidence"]},
    )

    return new_state
|
1143
|
+
|
|
1144
|
+
# =============================================================================
|
|
1145
|
+
# RLM Recursive Execution Functions (Section 2.2)
|
|
1146
|
+
# =============================================================================
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
def execute_sub_task_with_llm(
    sub_task: str,
    context: str,
    llm_fn: Any = None,
) -> str:
    """
    Execute a single sub-task using an LLM.

    Implements the recursive LLM pattern from Section 2.2:
    "For each contract, call the LLM API (invoking itself) with a
    specific sub-prompt: 'Extract liability clause from this text.'"

    Args:
        sub_task: The sub-task description/prompt.
        context: The context to process.
        llm_fn: Optional LLM function taking a prompt string and returning
            a string. If None, uses the default configured LLM.

    Returns:
        LLM response as string, or an "[Error: ...]" marker string when the
        LLM is unavailable or the call fails (never raises).
    """
    if llm_fn is None:
        try:
            from rnsr.llm import get_llm

            llm = get_llm()

            # Named function instead of a lambda assignment (ruff E731).
            def llm_fn(p: str) -> str:
                return str(llm.complete(p))
        except Exception as e:
            logger.warning("llm_not_available", error=str(e))
            return f"[Error: LLM not available - {str(e)}]"

    prompt = f"""{sub_task}

Context:
{context}

Response:"""

    try:
        return llm_fn(prompt)
    except Exception as e:
        logger.error("sub_task_execution_failed", error=str(e))
        return f"[Error: {str(e)}]"
|
1191
|
+
|
|
1192
|
+
def batch_execute_sub_tasks(
    sub_tasks: list[str],
    contexts: list[str],
    batch_size: int = 5,
    max_parallel: int = 4,
) -> list[str]:
    """
    Execute multiple sub-tasks in parallel batches.

    Implements Section 2.3 "Optimization via Batching":
    "Instead of making 1,000 individual calls to summarize 1,000 paragraphs,
    the RLM writes code to group paragraphs into chunks of 5 and processes
    them in parallel threads."

    Args:
        sub_tasks: List of sub-task prompts.
        contexts: List of contexts for each sub-task.
        batch_size: Items per batch (default 5). NOTE: currently only
            recorded in the start log; tasks are submitted individually and
            parallelism is governed by ``max_parallel``.
        max_parallel: Max parallel threads (default 4).

    Returns:
        List of results aligned with ``sub_tasks``; failed items contain an
        "[Error: ...]" marker string instead of raising.

    Raises:
        ValueError: If ``sub_tasks`` and ``contexts`` differ in length.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    if len(sub_tasks) != len(contexts):
        raise ValueError("sub_tasks and contexts must have same length")

    if not sub_tasks:
        return []

    logger.info(
        "batch_execution_start",
        num_tasks=len(sub_tasks),
        batch_size=batch_size,
    )

    # Get LLM function
    try:
        from rnsr.llm import get_llm

        llm = get_llm()

        # Named function instead of a lambda assignment (ruff E731).
        def llm_fn(p: str) -> str:
            return str(llm.complete(p))
    except Exception as e:
        logger.error("batch_llm_failed", error=str(e))
        return [f"[Error: {str(e)}]"] * len(sub_tasks)

    # Pre-size the result list so out-of-order completion slots correctly.
    results: list[str] = [""] * len(sub_tasks)

    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
        future_to_idx: dict[Any, int] = {}

        for idx, (task, ctx) in enumerate(zip(sub_tasks, contexts)):
            future = executor.submit(execute_sub_task_with_llm, task, ctx, llm_fn)
            future_to_idx[future] = idx

        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                results[idx] = future.result(timeout=120)
            except Exception as e:
                # Per-task failures degrade to an error marker, never raise.
                results[idx] = f"[Error: {str(e)}]"

    logger.info("batch_execution_complete", num_results=len(results))
    return results
|
1257
|
+
|
|
1258
|
+
def process_pending_questions(
|
|
1259
|
+
state: AgentState,
|
|
1260
|
+
skeleton: dict[str, SkeletonNode],
|
|
1261
|
+
kv_store: KVStore,
|
|
1262
|
+
variable_store: VariableStore,
|
|
1263
|
+
) -> AgentState:
|
|
1264
|
+
"""
|
|
1265
|
+
Process all pending sub-questions using recursive LLM calls.
|
|
1266
|
+
|
|
1267
|
+
This implements the full RLM recursive loop:
|
|
1268
|
+
1. For each pending question, find relevant context
|
|
1269
|
+
2. Invoke sub-LLM to extract answer
|
|
1270
|
+
3. Store result as variable
|
|
1271
|
+
4. Move to next question
|
|
1272
|
+
"""
|
|
1273
|
+
new_state = cast(AgentState, dict(state))
|
|
1274
|
+
pending = new_state.get("pending_questions", [])
|
|
1275
|
+
|
|
1276
|
+
if not pending:
|
|
1277
|
+
return new_state
|
|
1278
|
+
|
|
1279
|
+
current_question = pending[0]
|
|
1280
|
+
|
|
1281
|
+
# Find relevant content for this question
|
|
1282
|
+
# Use current node and its children
|
|
1283
|
+
node_id = new_state.get("current_node_id") or "root"
|
|
1284
|
+
node = skeleton.get(node_id)
|
|
1285
|
+
|
|
1286
|
+
if node is None:
|
|
1287
|
+
# Pop and continue
|
|
1288
|
+
new_state["pending_questions"] = pending[1:]
|
|
1289
|
+
return new_state
|
|
1290
|
+
|
|
1291
|
+
# Get context from current node and children
|
|
1292
|
+
context_parts = []
|
|
1293
|
+
|
|
1294
|
+
# Add current node content
|
|
1295
|
+
current_content = kv_store.get(node_id) if node_id else None
|
|
1296
|
+
if current_content:
|
|
1297
|
+
context_parts.append(f"[{node.header}]\n{current_content}")
|
|
1298
|
+
|
|
1299
|
+
# Add children summaries
|
|
1300
|
+
for child_id in node.child_ids[:5]: # Limit to 5 children
|
|
1301
|
+
child = skeleton.get(child_id)
|
|
1302
|
+
if child:
|
|
1303
|
+
child_content = kv_store.get(child_id)
|
|
1304
|
+
if child_content:
|
|
1305
|
+
context_parts.append(f"[{child.header}]\n{child_content[:2000]}")
|
|
1306
|
+
|
|
1307
|
+
context = "\n\n---\n\n".join(context_parts)
|
|
1308
|
+
|
|
1309
|
+
# Execute sub-task with LLM
|
|
1310
|
+
result = execute_sub_task_with_llm(
|
|
1311
|
+
sub_task=f"Answer this question: {current_question}",
|
|
1312
|
+
context=context,
|
|
1313
|
+
)
|
|
1314
|
+
|
|
1315
|
+
# Store as variable
|
|
1316
|
+
pointer = generate_pointer_name(current_question[:30])
|
|
1317
|
+
variable_store.assign(pointer, result, node_id or "root")
|
|
1318
|
+
new_state["variables"].append(pointer)
|
|
1319
|
+
|
|
1320
|
+
add_trace_entry(
|
|
1321
|
+
new_state,
|
|
1322
|
+
"decomposition",
|
|
1323
|
+
f"Processed sub-question via recursive LLM",
|
|
1324
|
+
{
|
|
1325
|
+
"question": current_question,
|
|
1326
|
+
"pointer": pointer,
|
|
1327
|
+
"context_length": len(context),
|
|
1328
|
+
},
|
|
1329
|
+
)
|
|
1330
|
+
|
|
1331
|
+
# Pop processed question
|
|
1332
|
+
new_state["pending_questions"] = pending[1:]
|
|
1333
|
+
if pending[1:]:
|
|
1334
|
+
new_state["current_sub_question"] = pending[1]
|
|
1335
|
+
|
|
1336
|
+
new_state["iteration"] = new_state["iteration"] + 1
|
|
1337
|
+
return new_state
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
# =============================================================================
|
|
1341
|
+
# Graph Builder (LangGraph)
|
|
1342
|
+
# =============================================================================
|
|
1343
|
+
|
|
1344
|
+
|
|
1345
|
+
def build_navigator_graph(
    skeleton: dict[str, SkeletonNode],
    kv_store: KVStore,
    semantic_searcher: SemanticSearcher | None = None,
) -> Any:
    """
    Build the LangGraph state machine for document navigation.

    Implements Tree of Thoughts (ToT) navigation with backtracking support.

    Returns a compiled graph that can be invoked with a question.

    Args:
        skeleton: Skeleton index mapping node ids to SkeletonNode objects;
            captured by the node lambdas below.
        kv_store: Key-value store holding full node content.
        semantic_searcher: Optional semantic-search shortcut passed through
            to the traverse step.

    Raises:
        ImportError: If the optional ``langgraph`` dependency is missing.

    Usage:
        graph = build_navigator_graph(skeleton, kv_store)
        result = graph.invoke({"question": "What are the payment terms?"})
    """
    try:
        from langgraph.graph import END, StateGraph
    except ImportError:
        raise ImportError(
            "LangGraph not installed. Install with: pip install langgraph"
        )

    # One variable store is shared by the expand and synthesize nodes for
    # the lifetime of this compiled graph.
    variable_store = VariableStore()

    # Create graph
    graph = StateGraph(AgentState)

    # Add nodes
    graph.add_node("decompose", decompose_query)

    graph.add_node(
        "navigate",
        lambda state: navigate_tree(cast(AgentState, state), skeleton),
    )

    graph.add_node(
        "expand",
        lambda state: expand_current_node(cast(AgentState, state), skeleton, kv_store, variable_store),
    )

    graph.add_node(
        "traverse",
        lambda state: traverse_to_children(
            cast(AgentState, state),
            skeleton,
            semantic_searcher,
        ),
    )

    # Add backtrack node for ToT dead-end handling
    graph.add_node(
        "backtrack",
        lambda state: backtrack_to_parent(cast(AgentState, state), skeleton),
    )

    graph.add_node(
        "synthesize",
        lambda state: synthesize_answer(cast(AgentState, state), variable_store),
    )

    # Add edges
    graph.add_edge("decompose", "navigate")

    # Conditional edge based on expand/traverse/backtrack decision
    graph.add_conditional_edges(
        "navigate",
        lambda s: should_expand(s, skeleton),
        {
            "expand": "expand",
            "traverse": "traverse",
            "backtrack": "backtrack",
            "done": "synthesize",
        },
    )

    # After expand, traverse, or backtrack, always return to the main navigation
    # handler, which will decide what to do next (e.g., pull from queue)
    graph.add_edge("expand", "navigate")
    graph.add_edge("traverse", "navigate")
    graph.add_edge("backtrack", "navigate")

    graph.add_edge("synthesize", END)

    # Set entry point
    graph.set_entry_point("decompose")

    logger.info("navigator_graph_built", features=["ToT", "backtracking"])

    return graph.compile()
|
1436
|
+
|
|
1437
|
+
# =============================================================================
|
|
1438
|
+
# High-Level API
|
|
1439
|
+
# =============================================================================
|
|
1440
|
+
|
|
1441
|
+
|
|
1442
|
+
def run_navigator(
    question: str,
    skeleton: dict[str, SkeletonNode],
    kv_store: KVStore,
    max_iterations: int = 20,
    top_k: int | None = None,
    use_semantic_search: bool = True,
    semantic_searcher: SemanticSearcher | None = None,
    metadata: dict[str, Any] | None = None,
    tot_selection_threshold: float = 0.4,
    tot_dead_end_threshold: float = 0.1,
) -> dict[str, Any]:
    """
    Run the navigator agent on a question.

    Args:
        question: User's question.
        skeleton: Skeleton index.
        kv_store: KV store with full content.
        max_iterations: Maximum navigation iterations.
        top_k: Number of top children to explore (default: auto-detect based on tree depth).
        use_semantic_search: Use semantic search (O(log N)) instead of ToT evaluation (O(N)).
            Allows exploring ALL leaf nodes ranked by relevance, preventing missed data.
        semantic_searcher: Optional pre-built semantic searcher. If None and use_semantic_search=True, creates one.
        metadata: Optional extra state (e.g. multiple-choice options).
        tot_selection_threshold: Minimum probability for ToT node selection (0.0-1.0).
        tot_dead_end_threshold: Probability threshold for declaring a dead end (0.0-1.0).

    Returns:
        Dictionary with answer, confidence, trace, variables used, and
        nodes visited (or an error answer when no root node exists).

    Example:
        result = run_navigator(
            "What are the liability terms?",
            skeleton,
            kv_store,
            use_semantic_search=True,  # Enable semantic search
        )
        print(result["answer"])
    """
    # Locate the level-0 root of the skeleton tree.
    root_node = next((n for n in skeleton.values() if n.level == 0), None)
    if root_node is None:
        return {
            "answer": "Error: No root node found in skeleton index.",
            "confidence": 0.0,
            "trace": [],
        }
    root_id = root_node.node_id

    # Auto-detect top_k from the root fan-out: flat hierarchies (e.g.
    # QuALITY) warrant exploring more children than deep PDF trees.
    if top_k is None:
        fanout = len(root_node.child_ids)
        top_k = min(10, fanout) if fanout > 10 else 3

    # Semantic search is only a SHORTCUT per research paper Section 9.1:
    # the agent may use vector search to find a starting node, then switch
    # to tree traversal for local exploration. Primary navigation is ToT.
    if use_semantic_search and semantic_searcher is None:
        try:
            from rnsr.indexing.semantic_search import create_semantic_searcher

            semantic_searcher = create_semantic_searcher(skeleton, kv_store)
            logger.info(
                "semantic_search_optional_tool",
                nodes=len(skeleton),
                note="Available as shortcut for entry points only",
            )
        except Exception as e:
            logger.warning(
                "semantic_search_unavailable",
                error=str(e),
            )
            semantic_searcher = None

    logger.info(
        "using_tot_reasoning_navigation",
        method="Tree of Thoughts",
        adaptive_exploration=True,
        note="LLM reasons about document structure to navigate",
    )

    # Compile the state machine and seed it with the initial state.
    compiled = build_navigator_graph(skeleton, kv_store, semantic_searcher)

    seed_state = create_initial_state(
        question=question,
        root_node_id=root_id,
        max_iterations=max_iterations,
        top_k=top_k,
        metadata=metadata,
        tot_selection_threshold=tot_selection_threshold,
        tot_dead_end_threshold=tot_dead_end_threshold,
    )

    final_state = compiled.invoke(seed_state)

    return {
        "answer": final_state.get("answer", ""),
        "confidence": final_state.get("confidence", 0.0),
        "trace": final_state.get("trace", []),
        "variables_used": final_state.get("variables", []),
        "nodes_visited": final_state.get("visited_nodes", []),
    }
|