chatterer 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterer/__init__.py +93 -93
- chatterer/common_types/__init__.py +21 -21
- chatterer/common_types/io.py +19 -19
- chatterer/examples/__init__.py +0 -0
- chatterer/examples/anything_to_markdown.py +95 -91
- chatterer/examples/get_code_snippets.py +64 -62
- chatterer/examples/login_with_playwright.py +171 -167
- chatterer/examples/make_ppt.py +499 -497
- chatterer/examples/pdf_to_markdown.py +107 -107
- chatterer/examples/pdf_to_text.py +60 -56
- chatterer/examples/transcription_api.py +127 -123
- chatterer/examples/upstage_parser.py +95 -100
- chatterer/examples/webpage_to_markdown.py +79 -79
- chatterer/interactive.py +354 -354
- chatterer/language_model.py +533 -533
- chatterer/messages.py +21 -21
- chatterer/strategies/__init__.py +13 -13
- chatterer/strategies/atom_of_thoughts.py +975 -975
- chatterer/strategies/base.py +14 -14
- chatterer/tools/__init__.py +46 -46
- chatterer/tools/caption_markdown_images.py +384 -384
- chatterer/tools/citation_chunking/__init__.py +3 -3
- chatterer/tools/citation_chunking/chunks.py +53 -53
- chatterer/tools/citation_chunking/citation_chunker.py +118 -118
- chatterer/tools/citation_chunking/citations.py +285 -285
- chatterer/tools/citation_chunking/prompt.py +157 -157
- chatterer/tools/citation_chunking/reference.py +26 -26
- chatterer/tools/citation_chunking/utils.py +138 -138
- chatterer/tools/convert_pdf_to_markdown.py +302 -302
- chatterer/tools/convert_to_text.py +447 -447
- chatterer/tools/upstage_document_parser.py +705 -705
- chatterer/tools/webpage_to_markdown.py +739 -739
- chatterer/tools/youtube.py +146 -146
- chatterer/utils/__init__.py +15 -15
- chatterer/utils/base64_image.py +285 -285
- chatterer/utils/bytesio.py +59 -59
- chatterer/utils/code_agent.py +237 -237
- chatterer/utils/imghdr.py +148 -148
- {chatterer-0.1.18.dist-info → chatterer-0.1.19.dist-info}/METADATA +392 -392
- chatterer-0.1.19.dist-info/RECORD +44 -0
- {chatterer-0.1.18.dist-info → chatterer-0.1.19.dist-info}/WHEEL +1 -1
- chatterer-0.1.19.dist-info/entry_points.txt +10 -0
- chatterer-0.1.18.dist-info/RECORD +0 -42
- {chatterer-0.1.18.dist-info → chatterer-0.1.19.dist-info}/top_level.txt +0 -0
@@ -1,975 +1,975 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import asyncio
|
4
|
-
import logging
|
5
|
-
from dataclasses import dataclass, field
|
6
|
-
from enum import StrEnum
|
7
|
-
from typing import Optional, Type, TypeVar
|
8
|
-
|
9
|
-
from pydantic import BaseModel, Field, ValidationError
|
10
|
-
|
11
|
-
from ..language_model import Chatterer, LanguageModelInput
|
12
|
-
from ..messages import AIMessage, BaseMessage, HumanMessage
|
13
|
-
from .base import BaseStrategy
|
14
|
-
|
15
|
-
# ---------------------------------------------------------------------------------
|
16
|
-
# 0) Enums and Basic Models
|
17
|
-
# ---------------------------------------------------------------------------------
|
18
|
-
|
19
|
-
# Text template for rendering a single question/answer pair.
QA_TEMPLATE = "Q: {question}\nA: {answer}"

# Placeholder "thought" returned when recursion is cut off by the depth guard.
MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."

# Placeholder answer used wherever no real answer is available.
UNKNOWN = "Unknown"
|
22
|
-
|
23
|
-
|
24
|
-
class SubQuestionNode(BaseModel):
    """One node of a question-decomposition tree.

    Carries the sub-question text, its answer once one has been produced, and
    the indices of the sibling sub-questions it depends on.
    """

    question: str = Field(
        description="A sub-question string that arises from decomposition.",
    )
    answer: Optional[str] = Field(
        description="Answer for this sub-question, if resolved.",
    )
    depend: list[int] = Field(
        description="Indices of sub-questions that this node depends on.",
    )
|
30
|
-
|
31
|
-
|
32
|
-
class RecursiveDecomposeResponse(BaseModel):
    """Structured output of one recursive decomposition step.

    Bundles the model's reasoning, its current best answer to the question,
    and the sub-questions it produced.
    """

    thought: str = Field(
        description="Reasoning about decomposition.",
    )
    final_answer: str = Field(
        description="Best answer to the main question.",
    )
    sub_questions: list[SubQuestionNode] = Field(
        description="Root-level sub-questions.",
    )
|
38
|
-
|
39
|
-
|
40
|
-
class ContractQuestionResponse(BaseModel):
    """Structured output of the contraction step.

    Contraction condenses the available information into a single simplified,
    self-contained question.
    """

    thought: str = Field(
        description="Reasoning on how the question was compressed.",
    )
    question: str = Field(
        description="New, simplified, self-contained question.",
    )
|
45
|
-
|
46
|
-
|
47
|
-
class EnsembleResponse(BaseModel):
    """Structured output of an ensemble pass over candidate answers."""

    thought: str = Field(
        description="Explanation for choosing the final answer.",
    )
    answer: str = Field(
        description="Best final answer after ensemble.",
    )
    confidence: float = Field(
        description="Confidence score in [0, 1].",
    )

    def model_post_init(self, __context: object) -> None:
        """Clamp ``confidence`` into the closed interval [0, 1] after init."""
        capped = min(1.0, self.confidence)
        self.confidence = max(0.0, capped)
|
56
|
-
|
57
|
-
|
58
|
-
class LabelResponse(BaseModel):
    """Structured output of the labeling step.

    Labeling reviews the sub-questions and returns them refined and reordered,
    with corrected dependencies.
    """

    thought: str = Field(
        description="Explanation or reasoning about labeling.",
    )
    sub_questions: list[SubQuestionNode] = Field(
        description="Refined list of sub-questions with corrected dependencies.",
    )
|
65
|
-
|
66
|
-
|
67
|
-
class CritiqueResponse(BaseModel):
    """Structured output of a self-critique pass by the LLM."""

    thought: str = Field(
        description="Critical reflection on correctness.",
    )
    self_assessment: float = Field(
        description="Self-assessed confidence in the approach/answer. A float in [0,1].",
    )
|
72
|
-
|
73
|
-
|
74
|
-
# ---------------------------------------------------------------------------------
|
75
|
-
# [NEW] Additional classes to incorporate a separate sub-question devil's advocate
|
76
|
-
# ---------------------------------------------------------------------------------
|
77
|
-
|
78
|
-
|
79
|
-
class DevilsAdvocateResponse(BaseModel):
    """Structured output of a devil's-advocate pass.

    Holds an alternative or contradictory answer to an existing one, the
    reasoning behind it, and any sub-questions raised from that viewpoint.
    """

    thought: str = Field(
        description="Reasoning behind the contradictory viewpoint.",
    )
    final_answer: str = Field(
        description="Alternative or conflicting answer to challenge the main one.",
    )
    sub_questions: list[SubQuestionNode] = Field(
        description="Any additional sub-questions from the contrarian perspective.",
    )
|
90
|
-
|
91
|
-
|
92
|
-
# ---------------------------------------------------------------------------------
|
93
|
-
# 1) Prompter Classes with Multi-Hop + Devil's Advocate
|
94
|
-
# ---------------------------------------------------------------------------------
|
95
|
-
|
96
|
-
|
97
|
-
class AoTPrompter:
    """Default prompt builder for the pipeline.

    Every method takes the conversation so far and returns a NEW list made of
    that conversation followed by the step-specific messages; the input list
    is never mutated. Subclass and override individual methods to customise
    the wording of a step.
    """

    def recursive_decompose_prompt(
        self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """Build the prompt asking the model to decompose ``question``.

        Args:
            messages: Conversation so far.
            question: The question to decompose.
            sub_answers: ``(sub_question, answer)`` pairs resolved so far; may be empty.

        Returns:
            ``messages`` followed by the question, the known sub-answers (only
            when there are any) and the JSON-format instructions.
        """
        decompose_instructions = (
            "First, restate the main question.\n"
            "Decide if sub-questions are needed. If so, list them.\n"
            "In the 'thought' field, show your chain-of-thought.\n"
            "Return valid JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )

        prompt: list[BaseMessage] = messages + [HumanMessage(content=f"Main question:\n{question}")]
        content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
        # The first decomposition pass has no sub-answers yet. Skip the message
        # instead of sending an empty-content one, which some chat providers reject.
        if content_sub_answers:
            prompt.append(AIMessage(content=content_sub_answers))
        prompt.append(AIMessage(content=decompose_instructions))
        return prompt

    def label_prompt(
        self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
    ) -> list[BaseMessage]:
        """Build the prompt asking the model to refine sub-questions and dependencies.

        Args:
            messages: Conversation so far.
            question: The question that was decomposed.
            decompose_response: Decomposition whose ``sub_questions`` are to be reviewed.

        Returns:
            ``messages`` followed by the question, the current sub-questions
            and the JSON-format instructions.
        """
        label_instructions = (
            "Review each sub-question to ensure correctness and proper ordering.\n"
            "Return valid JSON in the form:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": "...", "depend": [...]},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Question: {question}"),
            AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
            AIMessage(content=label_instructions),
        ]

    def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
        """Build the prompt asking the model to merge sub-answers into one question.

        Args:
            messages: Conversation so far.
            sub_answers: ``(sub_question, answer)`` pairs to merge; may be empty.

        Returns:
            ``messages`` followed by a header, the Q/A listing (only when there
            is one) and the JSON-format instructions.
        """
        contract_instructions = (
            "Please merge sub-answers into a single short question that is fully self-contained.\n"
            "In 'thought', show how you unify the information.\n"
            "Then produce JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "question": "a short but self-contained question"\n'
            "}\n"
        )
        prompt: list[BaseMessage] = messages + [AIMessage(content="We have these sub-questions and answers:")]
        sub_q_content = "\n".join(f"Q: {q}\nA: {a}" for q, a in sub_answers)
        # Never emit an empty-content message (see recursive_decompose_prompt).
        if sub_q_content:
            prompt.append(AIMessage(content=sub_q_content))
        prompt.append(AIMessage(content=contract_instructions))
        return prompt

    def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
        """Build the prompt asking the model to answer the contracted question directly.

        Args:
            messages: Conversation so far.
            contracted_question: The simplified question to answer.

        Returns:
            ``messages`` followed by the simplified question and the JSON-format instructions.
        """
        direct_instructions = (
            "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "..."\n'
            "}\n"
        )
        return messages + [
            HumanMessage(content=f"Simplified question: {contracted_question}"),
            AIMessage(content=direct_instructions),
        ]

    def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
        """Build the prompt asking the model to critique a previous thought/answer pair.

        Args:
            messages: Conversation so far.
            thought: The reasoning to be critiqued.
            answer: The answer to be critiqued.

        Returns:
            ``messages`` followed by the previous thought, the previous answer
            and the JSON-format instructions.
        """
        critique_instructions = (
            "Critique your own approach. Identify possible errors or leaps.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "self_assessment": <float in [0,1]>\n'
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
            AIMessage(content=f"Your previous ANSWER:\n{answer}"),
            AIMessage(content=critique_instructions),
        ]

    def ensemble_prompt(
        self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """Build the prompt asking the model to pick the best of several candidates.

        Args:
            messages: Conversation so far.
            possible_thought_and_answers: ``(thought, answer)`` pair per candidate.

        Returns:
            ``messages`` followed by one message per candidate (numbered from 0)
            and the JSON-format instructions.
        """
        instructions = (
            "You have multiple candidate solutions. Compare carefully and pick the best.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "why you chose this final answer",\n'
            ' "answer": "the best consolidated answer",\n'
            ' "confidence": 0.0 ~ 1.0\n'
            "}\n"
        )
        reasonings: list[BaseMessage] = [
            AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---")
            for idx, (thought, ans) in enumerate(possible_thought_and_answers)
        ]
        return messages + reasonings + [AIMessage(content=instructions)]

    def devils_advocate_prompt(
        self, messages: list[BaseMessage], question: str, existing_answer: str
    ) -> list[BaseMessage]:
        """Build the prompt asking the model to challenge an existing answer.

        Args:
            messages: Conversation so far.
            question: The question under discussion.
            existing_answer: The answer the model should argue against.

        Returns:
            ``messages`` followed by the question/answer context and the JSON-format instructions.
        """
        instructions = (
            "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
            "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
            "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
            "But here, let's keep it in a new dedicated structure:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
            AIMessage(content=instructions),
        ]
|
251
|
-
|
252
|
-
|
253
|
-
# ---------------------------------------------------------------------------------
|
254
|
-
# 2) Strict Typed Steps for Pipeline
|
255
|
-
# ---------------------------------------------------------------------------------
|
256
|
-
|
257
|
-
|
258
|
-
class StepName(StrEnum):
|
259
|
-
"""Enum for step names in the pipeline."""
|
260
|
-
|
261
|
-
DOMAIN_DETECTION = "DomainDetection"
|
262
|
-
DECOMPOSITION = "Decomposition"
|
263
|
-
DECOMPOSITION_CRITIQUE = "DecompositionCritique"
|
264
|
-
CONTRACTED_QUESTION = "ContractedQuestion"
|
265
|
-
CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
|
266
|
-
CONTRACT_CRITIQUE = "ContractCritique"
|
267
|
-
BEST_APPROACH_DECISION = "BestApproachDecision"
|
268
|
-
ENSEMBLE = "Ensemble"
|
269
|
-
FINAL_ANSWER = "FinalAnswer"
|
270
|
-
|
271
|
-
DEVILS_ADVOCATE = "DevilsAdvocate"
|
272
|
-
DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
|
273
|
-
|
274
|
-
|
275
|
-
class StepRelation(StrEnum):
|
276
|
-
"""Enum for relationship types in the reasoning graph."""
|
277
|
-
|
278
|
-
CRITIQUES = "CRITIQUES"
|
279
|
-
SELECTS = "SELECTS"
|
280
|
-
RESULT_OF = "RESULT_OF"
|
281
|
-
SPLIT_INTO = "SPLIT_INTO"
|
282
|
-
DEPEND_ON = "DEPEND_ON"
|
283
|
-
PRECEDES = "PRECEDES"
|
284
|
-
DECOMPOSED_BY = "DECOMPOSED_BY"
|
285
|
-
|
286
|
-
|
287
|
-
class StepRecord(BaseModel):
    """Typed log entry describing one step executed by the pipeline."""

    step_name: StepName
    domain: Optional[str] = None
    score: Optional[float] = None
    used: Optional[StepName] = None
    sub_questions: Optional[list[SubQuestionNode]] = None
    parent_decomp_step_idx: Optional[int] = None
    parent_subq_idx: Optional[int] = None
    question: Optional[str] = None
    thought: Optional[str] = None
    answer: Optional[str] = None

    def as_properties(self) -> dict[str, str | float | int | None]:
        """Return the record's scalar fields as a flat property dict.

        ``score`` is included whenever it is not ``None`` (so ``0.0`` is kept).
        ``domain``, ``question``, ``thought`` and ``answer`` are included only
        when truthy, i.e. neither ``None`` nor an empty string.
        """
        props: dict[str, str | float | int | None] = {}
        if self.score is not None:
            props["score"] = self.score
        for key in ("domain", "question", "thought", "answer"):
            text = getattr(self, key)
            if text:
                props[key] = text
        return props
|
315
|
-
|
316
|
-
|
317
|
-
# ---------------------------------------------------------------------------------
|
318
|
-
# 3) Logging Setup
|
319
|
-
# ---------------------------------------------------------------------------------
|
320
|
-
|
321
|
-
|
322
|
-
class SimpleColorFormatter(logging.Formatter):
    """Logging formatter that wraps each formatted record in an ANSI colour code."""

    # ANSI escape sequences.
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    RESET = "\033[0m"

    # Colour chosen per standard logging level.
    LEVEL_COLORS = {
        logging.DEBUG: BLUE,
        logging.INFO: GREEN,
        logging.WARNING: YELLOW,
        logging.ERROR: RED,
        logging.CRITICAL: RED,
    }

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record`` normally, then surround it with its level colour and a reset.

        Levels missing from ``LEVEL_COLORS`` are prefixed with ``RESET`` instead.
        """
        rendered = super().format(record)
        color = self.LEVEL_COLORS.get(record.levelno, self.RESET)
        return color + rendered + self.RESET
|
342
|
-
|
343
|
-
|
344
|
-
# Console handler that colours each record according to its level.
handler = logging.StreamHandler()
handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))

# Module logger: INFO and above, written only through the handler above.
logger = logging.getLogger("AoT")
logger.setLevel(logging.INFO)
logger.handlers = [handler]  # replaces any handlers already attached
logger.propagate = False  # do not pass records on to ancestor loggers
|
350
|
-
|
351
|
-
|
352
|
-
# ---------------------------------------------------------------------------------
|
353
|
-
# 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
|
354
|
-
# ---------------------------------------------------------------------------------
|
355
|
-
|
356
|
-
# Every response model that may be requested through a ``Type[T]`` parameter.
_ResponseModel = (
    EnsembleResponse
    | ContractQuestionResponse
    | LabelResponse
    | CritiqueResponse
    | RecursiveDecomposeResponse
    | DevilsAdvocateResponse
)

T = TypeVar("T", bound=_ResponseModel)
|
365
|
-
|
366
|
-
|
367
|
-
@dataclass
|
368
|
-
class AoTPipeline:
|
369
|
-
"""
|
370
|
-
The pipeline orchestrates:
|
371
|
-
1) Recursive decomposition
|
372
|
-
2) For each sub-question, it tries a main approach + a devil's advocate approach
|
373
|
-
3) Merges sub-answers using an ensemble
|
374
|
-
4) Contracts the question
|
375
|
-
5) Possibly does a direct approach on the contracted question
|
376
|
-
6) Ensembling the final answers
|
377
|
-
"""
|
378
|
-
|
379
|
-
chatterer: Chatterer
|
380
|
-
max_depth: int = 2
|
381
|
-
max_retries: int = 2
|
382
|
-
steps_history: list[StepRecord] = field(default_factory=list)
|
383
|
-
prompter: AoTPrompter = field(default_factory=AoTPrompter)
|
384
|
-
|
385
|
-
# 4.1) Utility for calling the LLM with Pydantic parsing
|
386
|
-
async def _ainvoke_pydantic(
    self,
    messages: list[BaseMessage],
    model_cls: Type[T],
    fallback: str = "<None>",
) -> T:
    """Ask the LLM for a ``model_cls`` instance, retrying on validation errors.

    Args:
        messages: Prompt to send.
        model_cls: Response model the output must parse into.
        fallback: Text placed in the string fields of the placeholder
            response returned when every attempt fails validation.

    Returns:
        The parsed response, or a placeholder ``model_cls`` instance (zero
        score / empty sub-question list) if all attempts raise ``ValidationError``.
    """
    # Always try at least once: with max_retries <= 0 the loop would otherwise
    # never run and the model would not be called at all.
    total_attempts = max(1, self.max_retries)
    for attempt in range(1, total_attempts + 1):
        try:
            return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
        except ValidationError as e:
            logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")

    # Every attempt failed validation: build a placeholder of the requested type.
    if issubclass(model_cls, EnsembleResponse):
        return model_cls(thought=fallback, answer=fallback, confidence=0.0)  # type: ignore
    if issubclass(model_cls, ContractQuestionResponse):
        return model_cls(thought=fallback, question=fallback)  # type: ignore
    if issubclass(model_cls, LabelResponse):
        return model_cls(thought=fallback, sub_questions=[])  # type: ignore
    if issubclass(model_cls, CritiqueResponse):
        return model_cls(thought=fallback, self_assessment=0.0)  # type: ignore
    # DevilsAdvocateResponse and RecursiveDecomposeResponse share the same fields.
    return model_cls(thought=fallback, final_answer=fallback, sub_questions=[])  # type: ignore
|
416
|
-
|
417
|
-
# 4.2) Helper method for self-critique
|
418
|
-
async def _ainvoke_critique(
    self,
    messages: list[BaseMessage],
    thought: str,
    answer: str,
) -> CritiqueResponse:
    """Have the LLM critique ``thought`` and ``answer``.

    Builds the critique prompt with the configured prompter and parses the
    reply into a ``CritiqueResponse``.
    """
    critique_messages = self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer)
    return await self._ainvoke_pydantic(messages=critique_messages, model_cls=CritiqueResponse)
|
431
|
-
|
432
|
-
# 4.3) Helper method for devil's advocate approach
|
433
|
-
async def _ainvoke_devils_advocate(
    self,
    messages: list[BaseMessage],
    question: str,
    existing_answer: str,
) -> DevilsAdvocateResponse:
    """Have the LLM argue against ``existing_answer`` for ``question``.

    Builds the devil's-advocate prompt with the configured prompter and parses
    the reply into a ``DevilsAdvocateResponse``.
    """
    challenge_messages = self.prompter.devils_advocate_prompt(
        messages, question=question, existing_answer=existing_answer
    )
    return await self._ainvoke_pydantic(messages=challenge_messages, model_cls=DevilsAdvocateResponse)
|
446
|
-
|
447
|
-
# 4.4) The main function that recursively decomposes a question and calls sub-steps
|
448
|
-
async def _arecursive_decompose_question(
    self,
    messages: list[BaseMessage],
    question: str,
    depth: int,
    parent_decomp_step_idx: Optional[int] = None,
    parent_subq_idx: Optional[int] = None,
) -> RecursiveDecomposeResponse:
    """Decompose ``question`` and, while depth remains, resolve its sub-questions.

    Steps:
        1) Ask the LLM for a decomposition.
        2) If it produced sub-questions, ask for a labeling pass to refine them.
        3) Record the decomposition in ``steps_history``.
        4) If ``depth > 0`` and sub-questions exist, resolve them via
           ``_aresolve_sub_questions`` and ask for a refined final answer
           that takes the sub-answers into account.

    Args:
        messages: Conversation so far.
        question: The question to decompose.
        depth: Remaining recursion budget; a negative value returns a placeholder.
        parent_decomp_step_idx: ``steps_history`` index of the parent decomposition, if any.
        parent_subq_idx: Index of this question among the parent's sub-questions, if any.

    Returns:
        The decomposition response, with ``final_answer`` and ``sub_questions``
        updated when sub-questions were resolved.
    """
    if depth < 0:
        logger.info("Max depth reached, returning unknown.")
        return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])

    # Step 1: Perform the decomposition
    decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        messages=self.prompter.recursive_decompose_prompt(messages=messages, question=question, sub_answers=[]),
        model_cls=RecursiveDecomposeResponse,
    )

    # Step 2: Label / refine sub-questions (dependencies, ordering)
    if decompose_resp.sub_questions:
        label_resp: LabelResponse = await self._ainvoke_pydantic(
            messages=self.prompter.label_prompt(messages, question, decompose_resp),
            model_cls=LabelResponse,
        )
        if label_resp.sub_questions:
            decompose_resp.sub_questions = label_resp.sub_questions
        else:
            # An empty labeling result is also what _ainvoke_pydantic returns
            # when validation keeps failing; do not let that silently discard
            # the decomposition we already have.
            logger.warning("Labeling returned no sub-questions; keeping the unrefined decomposition.")

    # Save a pipeline record for this decomposition step
    current_decomp_step_idx = self._record_decomposition_step(
        question=question,
        final_answer=decompose_resp.final_answer,
        sub_questions=decompose_resp.sub_questions,
        parent_decomp_step_idx=parent_decomp_step_idx,
        parent_subq_idx=parent_subq_idx,
    )

    # Step 3: If sub-questions exist and depth remains, solve them + do devil's advocate
    if depth > 0 and decompose_resp.sub_questions:
        solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
            messages=messages,
            sub_questions=decompose_resp.sub_questions,
            depth=depth,
            parent_decomp_step_idx=current_decomp_step_idx,
        )
        # Second pass: ask again, this time showing the resolved sub-answers,
        # and adopt the refined final answer.
        refined_prompt = self.prompter.recursive_decompose_prompt(
            messages=messages,
            question=question,
            sub_answers=[(sq.question, sq.answer or UNKNOWN) for sq in solved_subs],
        )
        refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
            refined_prompt, RecursiveDecomposeResponse
        )
        decompose_resp.final_answer = refined_resp.final_answer
        decompose_resp.sub_questions = solved_subs

        # Update pipeline record
        self.steps_history[current_decomp_step_idx].answer = refined_resp.final_answer
        self.steps_history[current_decomp_step_idx].sub_questions = solved_subs

    return decompose_resp
|
515
|
-
|
516
|
-
def _record_decomposition_step(
    self,
    question: str,
    final_answer: str,
    sub_questions: list[SubQuestionNode],
    parent_decomp_step_idx: Optional[int],
    parent_subq_idx: Optional[int],
) -> int:
    """Append a DECOMPOSITION record to ``steps_history`` and return its index."""
    # The new record lands at the current end of the history.
    new_idx = len(self.steps_history)
    self.steps_history.append(
        StepRecord(
            step_name=StepName.DECOMPOSITION,
            question=question,
            answer=final_answer,
            sub_questions=sub_questions,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=parent_subq_idx,
        )
    )
    return new_idx
|
537
|
-
|
538
|
-
async def _aresolve_sub_questions(
    self,
    messages: list[BaseMessage],
    sub_questions: list[SubQuestionNode],
    depth: int,
    parent_decomp_step_idx: Optional[int],
) -> list[SubQuestionNode]:
    """Resolve every sub-question and return them in their original order.

    For each sub-question:
        1) Recursively decompose it (main approach) with ``depth - 1``.
        2) Obtain a devil's-advocate alternative to the main answer.
        3) Ensemble the two and store the winner in the node's ``answer``.
        4) Record the devil's-advocate step and a critique of it.

    Tasks are created in topological order of the ``depend`` indices but are
    awaited together with ``asyncio.gather``, so they run concurrently rather
    than strictly one after another.

    Args:
        messages: Conversation so far.
        sub_questions: Nodes to resolve; each node's ``answer`` is updated in place.
        depth: Current depth; children are decomposed with ``depth - 1``.
        parent_decomp_step_idx: ``steps_history`` index of the owning decomposition.

    Returns:
        One resolved node per input node, in input order.
    """
    from collections import deque

    n = len(sub_questions)
    in_degree = [0] * n
    graph: list[list[int]] = [[] for _ in range(n)]
    for i, sq in enumerate(sub_questions):
        for dep in sq.depend:
            # Dependency indices come from the LLM; ignore out-of-range ones.
            if 0 <= dep < n:
                in_degree[i] += 1
                graph[dep].append(i)

    # Kahn's algorithm for topological order
    queue = deque(i for i in range(n) if in_degree[i] == 0)
    topo_order: list[int] = []

    while queue:
        node = queue.popleft()
        topo_order.append(node)
        for nxt in graph[node]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                queue.append(nxt)

    # Nodes on a dependency cycle (including a node that lists itself) never
    # reach in-degree 0. Resolve them anyway, after the ordered ones, so that
    # every index ends up in final_subs.
    if len(topo_order) < n:
        ordered = set(topo_order)
        leftover = [i for i in range(n) if i not in ordered]
        logger.warning(f"Cyclic sub-question dependencies detected for indices {leftover}; resolving them anyway.")
        topo_order.extend(leftover)

    # We'll store the resolved sub-questions
    final_subs: dict[int, SubQuestionNode] = {}

    async def _resolve_one_subq(idx: int):
        sq = sub_questions[idx]
        # 1) Main approach
        main_resp = await self._arecursive_decompose_question(
            messages=messages,
            question=sq.question,
            depth=depth - 1,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=idx,
        )

        main_answer = main_resp.final_answer

        # 2) Devil's Advocate approach
        devils_resp = await self._ainvoke_devils_advocate(
            messages=messages, question=sq.question, existing_answer=main_answer
        )
        # 3) Ensemble to combine main_answer + devils_alternative
        ensemble_sub = await self._ainvoke_pydantic(
            self.prompter.ensemble_prompt(
                messages=messages,
                possible_thought_and_answers=[
                    (main_resp.thought, main_answer),
                    (devils_resp.thought, devils_resp.final_answer),
                ],
            ),
            EnsembleResponse,
        )
        sub_best_answer = ensemble_sub.answer

        # Store final subq answer
        sq.answer = sub_best_answer
        final_subs[idx] = sq

        # Record pipeline steps for devil's advocate
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE,
                question=sq.question,
                answer=devils_resp.final_answer,
                thought=devils_resp.thought,
                sub_questions=devils_resp.sub_questions,
            )
        )
        # Critique the devil's advocate result and record the score
        dev_adv_crit = await self._ainvoke_critique(
            messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
        )
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
                thought=dev_adv_crit.thought,
                score=dev_adv_crit.self_assessment,
            )
        )

    # Create the tasks in topological order and await them together
    tasks = [_resolve_one_subq(i) for i in topo_order]
    await asyncio.gather(*tasks, return_exceptions=False)

    return [final_subs[i] for i in range(n)]
|
638
|
-
|
639
|
-
# 4.5) The primary pipeline method
|
640
|
-
async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
    """Run the whole pipeline on the last message of ``messages``.

    Stages, each recorded in ``steps_history`` (which is cleared first):
        1) Recursive decomposition of the original question.
        2) Self-critique of that decomposition.
        3) Devil's advocate on the main answer, plus a critique of it.
        4) Contraction of the top-level sub-answers into one question.
        5) Direct answer to the contracted question, plus a critique of it.
        6) Ensemble over the main, devil's-advocate and contracted answers.
        7) Recording and returning the ensemble's answer.

    Returns:
        The answer chosen by the final ensemble.
    """
    history = self.steps_history
    history.clear()

    original_question: str = messages[-1].text()

    # 1) Recursive decomposition
    main = await self._arecursive_decompose_question(
        messages=messages,
        question=original_question,
        depth=self.max_depth,
    )
    logger.info(f"[Main Decomposition] final_answer={main.final_answer}")

    # 2) Self-critique of main decomposition
    main_critique = await self._ainvoke_critique(
        messages=messages, thought=main.thought, answer=main.final_answer
    )
    history.append(
        StepRecord(
            step_name=StepName.DECOMPOSITION_CRITIQUE,
            thought=main_critique.thought,
            score=main_critique.self_assessment,
        )
    )

    # 3) Devil's advocate on the entire main answer
    contrarian = await self._ainvoke_devils_advocate(
        messages=messages, question=original_question, existing_answer=main.final_answer
    )
    history.append(
        StepRecord(
            step_name=StepName.DEVILS_ADVOCATE,
            question=original_question,
            answer=contrarian.final_answer,
            thought=contrarian.thought,
            sub_questions=contrarian.sub_questions,
        )
    )
    contrarian_critique = await self._ainvoke_critique(
        messages=messages, thought=contrarian.thought, answer=contrarian.final_answer
    )
    history.append(
        StepRecord(
            step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
            thought=contrarian_critique.thought,
            score=contrarian_critique.self_assessment,
        )
    )

    # 4) Contract sub-answers from main decomposition:
    #    take the most recent DECOMPOSITION record that has no parent.
    top_record: Optional[StepRecord] = None
    for candidate in reversed(history):
        if candidate.step_name == StepName.DECOMPOSITION and candidate.parent_decomp_step_idx is None:
            top_record = candidate
            break

    sub_answers: list[tuple[str, str]] = []
    if top_record and top_record.sub_questions:
        sub_answers = [(sq.question, sq.answer or UNKNOWN) for sq in top_record.sub_questions]

    contract_messages = self.prompter.contract_prompt(messages, sub_answers)
    contraction = await self._ainvoke_pydantic(
        messages=contract_messages,
        model_cls=ContractQuestionResponse,
    )
    contracted_question = contraction.question
    history.append(
        StepRecord(
            step_name=StepName.CONTRACTED_QUESTION,
            question=contracted_question,
            thought=contraction.thought,
        )
    )

    # 5) Attempt direct approach on contracted question
    direct_messages = self.prompter.contract_direct_prompt(messages, contracted_question)
    direct = await self._ainvoke_pydantic(
        messages=direct_messages,
        model_cls=RecursiveDecomposeResponse,
        fallback="No Contracted Direct Answer",
    )
    history.append(
        StepRecord(
            step_name=StepName.CONTRACTED_DIRECT_ANSWER,
            answer=direct.final_answer,
            thought=direct.thought,
        )
    )
    logger.info(f"[Contracted Direct] final_answer={direct.final_answer}")

    # 5.1) Critique the contracted direct approach
    direct_critique = await self._ainvoke_critique(
        messages=messages, thought=direct.thought, answer=direct.final_answer
    )
    history.append(
        StepRecord(
            step_name=StepName.CONTRACT_CRITIQUE,
            thought=direct_critique.thought,
            score=direct_critique.self_assessment,
        )
    )

    # 6) Ensemble of (Main decomposition, Devil's advocate on main, Contracted direct)
    candidates = [
        (main.thought, main.final_answer),
        (contrarian.thought, contrarian.final_answer),
        (direct.thought, direct.final_answer),
    ]
    ensemble_messages = self.prompter.ensemble_prompt(
        messages=messages,
        possible_thought_and_answers=candidates,
    )
    verdict = await self._ainvoke_pydantic(ensemble_messages, EnsembleResponse)
    final_answer = verdict.answer
    approach_used = StepName.ENSEMBLE
    history.append(StepRecord(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used))
    logger.info(f"[Best Approach Decision] => {approach_used}")

    # 7) Final answer
    history.append(
        StepRecord(step_name=StepName.FINAL_ANSWER, answer=final_answer, score=verdict.confidence)
    )
    logger.info(f"[Final Answer] => {final_answer}")

    return final_answer
|
775
|
-
|
776
|
-
def run_pipeline(self, messages: list[BaseMessage]) -> str:
    """Blocking entry point: drive ``arun_pipeline`` to completion and return its answer.

    The coroutine is executed with ``asyncio.run``.
    """
    pipeline_coro = self.arun_pipeline(messages)
    final_answer = asyncio.run(pipeline_coro)
    return final_answer
|
779
|
-
|
780
|
-
# ---------------------------------------------------------------------------------
|
781
|
-
# 4.6) Build or export a reasoning graph
|
782
|
-
# ---------------------------------------------------------------------------------
|
783
|
-
|
784
|
-
def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
    """
    Build a graph (``neo4j_extension.Graph``) from the recorded pipeline steps.

    Every entry of ``self.steps_history`` becomes a node labelled with its step
    name, and every sub-question attached to a step becomes a ``SubQuestion``
    node. The following relationships are added:

    - ``SPLIT_INTO``: step -> each of its sub-questions
    - ``DEPEND_ON``: sub-question -> the sibling sub-question it depends on
    - ``PRECEDES``: step ``i`` -> step ``i + 1`` (recording order)
    - ``CRITIQUES``: critique step -> the step recorded immediately before it
    - ``SELECTS``: best-approach decision -> first step whose name equals ``used``
    - ``RESULT_OF``: first final-answer step -> last best-approach decision

    Args:
        global_id_prefix: Prefix used for every node and relationship ``globalId``.

    Returns:
        The populated graph object.
    """
    from neo4j_extension import Graph, Node, Relationship

    g = Graph()
    step_nodes: dict[int, Node] = {}
    subq_nodes: dict[str, Node] = {}

    def subq_id(step_idx: int, sq_idx: int) -> str:
        # Single place that defines how sub-question node ids are spelled.
        return f"{global_id_prefix}_decomp_{step_idx}_sub_{sq_idx}"

    # Step A: one node per recorded pipeline step, keyed by its position in the history.
    for i, record in enumerate(self.steps_history):
        step_node = Node(
            properties=record.as_properties(), labels={record.step_name}, globalId=f"{global_id_prefix}_step_{i}"
        )
        g.add_node(step_node)
        step_nodes[i] = step_node

    # Step B: one node per sub-question attached to any step.
    for i, record in enumerate(self.steps_history):
        for sq_idx, sq in enumerate(record.sub_questions or []):
            sq_id = subq_id(i, sq_idx)
            n_subq = Node(
                properties={
                    "question": sq.question,
                    "answer": sq.answer or "",
                },
                labels={"SubQuestion"},
                globalId=sq_id,
            )
            g.add_node(n_subq)
            subq_nodes[sq_id] = n_subq

    # Step C: link each step to its sub-questions, and sub-questions to their dependencies.
    for i, record in enumerate(self.steps_history):
        if not record.sub_questions:
            continue
        start_node = step_nodes[i]
        for sq_idx, sq in enumerate(record.sub_questions):
            end_node = subq_nodes[subq_id(i, sq_idx)]
            g.add_relationship(
                Relationship(
                    properties={},
                    rel_type=StepRelation.SPLIT_INTO,
                    start_node=start_node,
                    end_node=end_node,
                    globalId=f"{global_id_prefix}_split_{i}_{sq_idx}",
                )
            )
            for dep in sq.depend:
                # Dependency indices refer to siblings of the same step; ignore out-of-range ones.
                if not 0 <= dep < len(record.sub_questions):
                    continue
                g.add_relationship(
                    Relationship(
                        properties={},
                        rel_type=StepRelation.DEPEND_ON,
                        start_node=end_node,
                        end_node=subq_nodes[subq_id(i, dep)],
                        globalId=f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
                    )
                )

    # Step D: chain the steps in the order they were recorded.
    for i in range(len(self.steps_history) - 1):
        g.add_relationship(
            Relationship(
                properties={},
                rel_type=StepRelation.PRECEDES,
                start_node=step_nodes[i],
                end_node=step_nodes[i + 1],
                globalId=f"{global_id_prefix}_precede_{i}_to_{i + 1}",
            )
        )

    # Step E: a critique step critiques the step recorded right before it.
    # StepName values are CamelCase (e.g. "DecompositionCritique"), so the match must be
    # case-insensitive; an upper-case "CRITIQUE" substring test never matches them.
    for i, record in enumerate(self.steps_history):
        if i > 0 and "critique" in record.step_name.lower():
            g.add_relationship(
                Relationship(
                    properties={},
                    rel_type=StepRelation.CRITIQUES,
                    start_node=step_nodes[i],
                    end_node=step_nodes[i - 1],
                    globalId=f"{global_id_prefix}_crit_{i}",
                )
            )

    # Link every BEST_APPROACH_DECISION step to the first step carrying the name it selected.
    best_decision_idx = None
    for i, record in enumerate(self.steps_history):
        if record.step_name == StepName.BEST_APPROACH_DECISION and record.used:
            best_decision_idx = i
            used_step_idx = next((j for j in step_nodes if self.steps_history[j].step_name == record.used), None)
            if used_step_idx is not None:
                g.add_relationship(
                    Relationship(
                        properties={},
                        rel_type=StepRelation.SELECTS,
                        start_node=step_nodes[i],
                        end_node=step_nodes[used_step_idx],
                        globalId=f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
                    )
                )

    # Link the final answer to the (last) best-approach decision.
    final_answer_idx = next(
        (i for i, r in enumerate(self.steps_history) if r.step_name == StepName.FINAL_ANSWER), None
    )
    if final_answer_idx is not None and best_decision_idx is not None:
        g.add_relationship(
            Relationship(
                properties={},
                rel_type=StepRelation.RESULT_OF,
                start_node=step_nodes[final_answer_idx],
                end_node=step_nodes[best_decision_idx],
                globalId=f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
            )
        )

    return g
|
921
|
-
|
922
|
-
|
923
|
-
# ---------------------------------------------------------------------------------
|
924
|
-
# 5) AoTStrategy class that uses the pipeline
|
925
|
-
# ---------------------------------------------------------------------------------
|
926
|
-
|
927
|
-
|
928
|
-
@dataclass
class AoTStrategy(BaseStrategy):
    """
    Strategy that answers by delegating to an ``AoTPipeline`` and can expose its reasoning graph.
    """

    pipeline: AoTPipeline

    def _to_message_list(self, messages: LanguageModelInput):
        """Normalize arbitrary language-model input into a message list via the pipeline's client."""
        prompt_value = self.pipeline.chatterer.client._convert_input(messages)  # type: ignore
        return prompt_value.to_messages()  # type: ignore

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """Convert the input to messages and await the pipeline's final answer."""
        return await self.pipeline.arun_pipeline(self._to_message_list(messages))

    def invoke(self, messages: LanguageModelInput) -> str:
        """Convert the input to messages and run the pipeline synchronously."""
        return self.pipeline.run_pipeline(self._to_message_list(messages))

    def get_reasoning_graph(self):
        """Return the reasoning graph built from the pipeline's step history (id prefix "AoT")."""
        return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
|
950
|
-
|
951
|
-
|
952
|
-
# ---------------------------------------------------------------------------------
|
953
|
-
# Example usage (pseudo-code)
|
954
|
-
# ---------------------------------------------------------------------------------
|
955
|
-
if __name__ == "__main__":
    from neo4j_extension import Neo4jConnection  # or your actual DB connector

    # Wire the pieces together: LLM backend -> pipeline -> strategy (pseudo-code backend choice).
    strategy = AoTStrategy(pipeline=AoTPipeline(chatterer=Chatterer.openai(), max_depth=3))

    # Ask a sample question and show the answer.
    print("Final Answer:", strategy.invoke("Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9."))

    # Turn the recorded steps into a graph and report its size.
    graph = strategy.get_reasoning_graph()
    print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")

    # Optionally persist the graph; note that clear_all() is called on the connection first.
    with Neo4jConnection() as conn:
        conn.clear_all()
        conn.upsert_graph(graph)
        print("Graph stored in Neo4j.")
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
from enum import StrEnum
|
7
|
+
from typing import Optional, Type, TypeVar
|
8
|
+
|
9
|
+
from pydantic import BaseModel, Field, ValidationError
|
10
|
+
|
11
|
+
from ..language_model import Chatterer, LanguageModelInput
|
12
|
+
from ..messages import AIMessage, BaseMessage, HumanMessage
|
13
|
+
from .base import BaseStrategy
|
14
|
+
|
15
|
+
# ---------------------------------------------------------------------------------
|
16
|
+
# 0) Enums and Basic Models
|
17
|
+
# ---------------------------------------------------------------------------------
|
18
|
+
|
19
|
+
# Format string for one question/answer pair; expects ``question`` and ``answer`` keys.
QA_TEMPLATE = "Q: {question}\nA: {answer}"
# Placeholder "thought" returned when recursion is asked to go below depth 0.
MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."
# Placeholder answer used when no real answer is available.
UNKNOWN = "Unknown"
|
22
|
+
|
23
|
+
|
24
|
+
class SubQuestionNode(BaseModel):
    """A single sub-question node in a decomposition tree."""

    # Text of the sub-question.
    question: str = Field(description="A sub-question string that arises from decomposition.")
    # May be None while unresolved; no default is declared, so a value must still be supplied.
    answer: Optional[str] = Field(description="Answer for this sub-question, if resolved.")
    # Positions (within the same sub-question list) of the nodes this one depends on.
    depend: list[int] = Field(description="Indices of sub-questions that this node depends on.")
|
30
|
+
|
31
|
+
|
32
|
+
class RecursiveDecomposeResponse(BaseModel):
    """The result of a recursive decomposition step."""

    # Free-form reasoning produced alongside the decomposition.
    thought: str = Field(description="Reasoning about decomposition.")
    # Answer to the question that was decomposed.
    final_answer: str = Field(description="Best answer to the main question.")
    # Sub-questions produced for this question; an empty list means none.
    sub_questions: list[SubQuestionNode] = Field(description="Root-level sub-questions.")
|
38
|
+
|
39
|
+
|
40
|
+
class ContractQuestionResponse(BaseModel):
    """The result of contracting (simplifying) a question."""

    # Reasoning about how the information was merged.
    thought: str = Field(description="Reasoning on how the question was compressed.")
    # The contracted question text.
    question: str = Field(description="New, simplified, self-contained question.")
|
45
|
+
|
46
|
+
|
47
|
+
class EnsembleResponse(BaseModel):
    """The ensemble process result."""

    # Why this answer was chosen among the candidates.
    thought: str = Field(description="Explanation for choosing the final answer.")
    # The selected / consolidated answer.
    answer: str = Field(description="Best final answer after ensemble.")
    # Clamped into [0.0, 1.0] by model_post_init below.
    confidence: float = Field(description="Confidence score in [0, 1].")

    def model_post_init(self, __context: object) -> None:
        # Force out-of-range confidence values back into the documented [0, 1] interval.
        self.confidence = max(0.0, min(1.0, self.confidence))
|
56
|
+
|
57
|
+
|
58
|
+
class LabelResponse(BaseModel):
    """Used to refine and reorder the sub-questions with corrected dependencies."""

    # Reasoning behind the relabeling.
    thought: str = Field(description="Explanation or reasoning about labeling.")
    # Replacement sub-question list; the decomposition step overwrites its own list with this one.
    sub_questions: list[SubQuestionNode] = Field(
        description="Refined list of sub-questions with corrected dependencies."
    )
|
65
|
+
|
66
|
+
|
67
|
+
class CritiqueResponse(BaseModel):
    """A response used for LLM to self-critique or question its own correctness."""

    # Critical reflection text returned by the model.
    thought: str = Field(description="Critical reflection on correctness.")
    # Clamped into [0.0, 1.0] by model_post_init below.
    self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")

    def model_post_init(self, __context: object) -> None:
        # The field is documented as a float in [0, 1]; clamp out-of-range model output
        # the same way EnsembleResponse clamps its confidence, so recorded scores stay comparable.
        self.self_assessment = max(0.0, min(1.0, self.self_assessment))
|
72
|
+
|
73
|
+
|
74
|
+
# ---------------------------------------------------------------------------------
|
75
|
+
# [NEW] Additional classes to incorporate a separate sub-question devil's advocate
|
76
|
+
# ---------------------------------------------------------------------------------
|
77
|
+
|
78
|
+
|
79
|
+
class DevilsAdvocateResponse(BaseModel):
    """
    A response for a 'devil's advocate' pass.
    We consider an alternative viewpoint or contradictory answer.
    """

    # Reasoning for the opposing view.
    thought: str = Field(description="Reasoning behind the contradictory viewpoint.")
    # The alternative answer offered as a challenge to the existing one.
    final_answer: str = Field(description="Alternative or conflicting answer to challenge the main one.")
    # Extra sub-questions raised by the contrarian pass; same field set as RecursiveDecomposeResponse.
    sub_questions: list[SubQuestionNode] = Field(
        description="Any additional sub-questions from the contrarian perspective."
    )
|
90
|
+
|
91
|
+
|
92
|
+
# ---------------------------------------------------------------------------------
|
93
|
+
# 1) Prompter Classes with Multi-Hop + Devil's Advocate
|
94
|
+
# ---------------------------------------------------------------------------------
|
95
|
+
|
96
|
+
|
97
|
+
class AoTPrompter:
    """Builds the message lists for each pipeline stage.

    Every method returns a new list: the caller's ``messages`` followed by the
    stage-specific context and an instruction message describing the JSON shape
    the model should return. The input list is never mutated.
    """

    def recursive_decompose_prompt(
        self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """
        Prompt for main decomposition.

        Args:
            messages: Conversation so far; copied, not modified.
            question: The question to decompose.
            sub_answers: ``(question, answer)`` pairs already resolved; may be empty.

        Returns:
            ``messages`` + the main question + (if any) the sub-answers so far + instructions.
        """
        decompose_instructions = (
            "First, restate the main question.\n"
            "Decide if sub-questions are needed. If so, list them.\n"
            "In the 'thought' field, show your chain-of-thought.\n"
            "Return valid JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )

        prompt: list[BaseMessage] = [HumanMessage(content=f"Main question:\n{question}")]
        content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
        # On the first pass there are no sub-answers yet; skip the message rather than
        # sending an assistant message with empty content.
        if content_sub_answers:
            prompt.append(AIMessage(content=content_sub_answers))
        prompt.append(AIMessage(content=decompose_instructions))
        return messages + prompt

    def label_prompt(
        self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
    ) -> list[BaseMessage]:
        """
        Prompt for refining the sub-questions and dependencies.

        The current sub-questions are embedded using their default string rendering.
        """
        label_instructions = (
            "Review each sub-question to ensure correctness and proper ordering.\n"
            "Return valid JSON in the form:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": "...", "depend": [...]},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Question: {question}"),
            AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
            AIMessage(content=label_instructions),
        ]

    def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
        """
        Prompt for merging sub-answers into one self-contained question.

        Args:
            messages: Conversation so far; copied, not modified.
            sub_answers: ``(question, answer)`` pairs to merge; may be empty.
        """
        contract_instructions = (
            "Please merge sub-answers into a single short question that is fully self-contained.\n"
            "In 'thought', show how you unify the information.\n"
            "Then produce JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "question": "a short but self-contained question"\n'
            "}\n"
        )
        prompt: list[BaseMessage] = [AIMessage(content="We have these sub-questions and answers:")]
        sub_q_content = "\n".join(f"Q: {q}\nA: {a}" for q, a in sub_answers)
        # Never emit an assistant message with empty content when nothing was resolved.
        if sub_q_content:
            prompt.append(AIMessage(content=sub_q_content))
        prompt.append(AIMessage(content=contract_instructions))
        return messages + prompt

    def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
        """
        Prompt for directly answering the contracted question thoroughly.
        """
        direct_instructions = (
            "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "..."\n'
            "}\n"
        )
        return messages + [
            HumanMessage(content=f"Simplified question: {contracted_question}"),
            AIMessage(content=direct_instructions),
        ]

    def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
        """
        Prompt for self-critique of a previous ``thought`` / ``answer`` pair.
        """
        critique_instructions = (
            "Critique your own approach. Identify possible errors or leaps.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "self_assessment": <float in [0,1]>\n'
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
            AIMessage(content=f"Your previous ANSWER:\n{answer}"),
            AIMessage(content=critique_instructions),
        ]

    def ensemble_prompt(
        self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """
        Show multiple candidate solutions and pick the best final answer with confidence.

        Each ``(thought, answer)`` pair becomes one numbered candidate message (numbering starts at 0).
        """
        instructions = (
            "You have multiple candidate solutions. Compare carefully and pick the best.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "why you chose this final answer",\n'
            ' "answer": "the best consolidated answer",\n'
            ' "confidence": 0.0 ~ 1.0\n'
            "}\n"
        )
        reasonings: list[BaseMessage] = [
            AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---")
            for idx, (thought, ans) in enumerate(possible_thought_and_answers)
        ]
        return messages + reasonings + [AIMessage(content=instructions)]

    def devils_advocate_prompt(
        self, messages: list[BaseMessage], question: str, existing_answer: str
    ) -> list[BaseMessage]:
        """
        Prompt for a devil's advocate approach to contradict or provide an alternative viewpoint.
        """
        instructions = (
            "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
            "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
            "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
            "But here, let's keep it in a new dedicated structure:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
            AIMessage(content=instructions),
        ]
|
251
|
+
|
252
|
+
|
253
|
+
# ---------------------------------------------------------------------------------
|
254
|
+
# 2) Strict Typed Steps for Pipeline
|
255
|
+
# ---------------------------------------------------------------------------------
|
256
|
+
|
257
|
+
|
258
|
+
class StepName(StrEnum):
    """Enum for step names in the pipeline."""

    # Values are CamelCase strings; as a StrEnum, each member behaves as that string.
    DOMAIN_DETECTION = "DomainDetection"
    DECOMPOSITION = "Decomposition"
    DECOMPOSITION_CRITIQUE = "DecompositionCritique"
    CONTRACTED_QUESTION = "ContractedQuestion"
    CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
    CONTRACT_CRITIQUE = "ContractCritique"
    BEST_APPROACH_DECISION = "BestApproachDecision"
    ENSEMBLE = "Ensemble"
    FINAL_ANSWER = "FinalAnswer"

    # Steps belonging to the devil's-advocate pass.
    DEVILS_ADVOCATE = "DevilsAdvocate"
    DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
|
273
|
+
|
274
|
+
|
275
|
+
class StepRelation(StrEnum):
    """Enum for relationship types in the reasoning graph."""

    # Each value is used verbatim as the relationship type string.
    CRITIQUES = "CRITIQUES"
    SELECTS = "SELECTS"
    RESULT_OF = "RESULT_OF"
    SPLIT_INTO = "SPLIT_INTO"
    DEPEND_ON = "DEPEND_ON"
    PRECEDES = "PRECEDES"
    DECOMPOSED_BY = "DECOMPOSED_BY"
|
285
|
+
|
286
|
+
|
287
|
+
class StepRecord(BaseModel):
    """A typed record for each pipeline step."""

    step_name: StepName
    domain: Optional[str] = None
    score: Optional[float] = None
    used: Optional[StepName] = None
    sub_questions: Optional[list[SubQuestionNode]] = None
    parent_decomp_step_idx: Optional[int] = None
    parent_subq_idx: Optional[int] = None
    question: Optional[str] = None
    thought: Optional[str] = None
    answer: Optional[str] = None

    def as_properties(self) -> dict[str, str | float | int | None]:
        """Return the populated scalar fields as a flat property mapping.

        ``score`` is included whenever it is not ``None`` (so ``0.0`` is kept);
        ``domain``, ``question``, ``thought`` and ``answer`` are included only
        when they are non-empty.
        """
        props: dict[str, str | float | int | None] = {}
        if self.score is not None:
            props["score"] = self.score
        for field_name in ("domain", "question", "thought", "answer"):
            text = getattr(self, field_name)
            if text:
                props[field_name] = text
        return props
|
315
|
+
|
316
|
+
|
317
|
+
# ---------------------------------------------------------------------------------
|
318
|
+
# 3) Logging Setup
|
319
|
+
# ---------------------------------------------------------------------------------
|
320
|
+
|
321
|
+
|
322
|
+
class SimpleColorFormatter(logging.Formatter):
    """Formatter that wraps each rendered record in an ANSI color chosen by log level."""

    RESET = "\033[0m"
    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    LEVEL_COLORS = {
        logging.DEBUG: BLUE,
        logging.INFO: GREEN,
        logging.WARNING: YELLOW,
        logging.ERROR: RED,
        logging.CRITICAL: RED,
    }

    def format(self, record: logging.LogRecord) -> str:
        """Render ``record`` with the base formatter, then surround it with color codes."""
        rendered = super().format(record)
        # Levels without a dedicated color get RESET as their prefix.
        prefix = self.LEVEL_COLORS.get(record.levelno, self.RESET)
        return "".join((prefix, rendered, self.RESET))
|
342
|
+
|
343
|
+
|
344
|
+
# Dedicated "AoT" logger; records below INFO are dropped.
logger = logging.getLogger("AoT")
logger.setLevel(logging.INFO)
# Stream handler that renders records as "<LEVEL>: <message>" through the color formatter.
handler = logging.StreamHandler()
handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))
# Assigning the list replaces whatever handlers were attached to this logger before.
logger.handlers = [handler]
# Keep records from also being passed to ancestor loggers' handlers.
logger.propagate = False
|
350
|
+
|
351
|
+
|
352
|
+
# ---------------------------------------------------------------------------------
|
353
|
+
# 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
|
354
|
+
# ---------------------------------------------------------------------------------
|
355
|
+
|
356
|
+
# Type variable for the response models handled by the pipeline's pydantic helper:
# it ties the requested model class to the type of the returned instance.
T = TypeVar(
    "T",
    bound=EnsembleResponse
    | ContractQuestionResponse
    | LabelResponse
    | CritiqueResponse
    | RecursiveDecomposeResponse
    | DevilsAdvocateResponse,
)
|
365
|
+
|
366
|
+
|
367
|
+
@dataclass
|
368
|
+
class AoTPipeline:
|
369
|
+
"""
|
370
|
+
The pipeline orchestrates:
|
371
|
+
1) Recursive decomposition
|
372
|
+
2) For each sub-question, it tries a main approach + a devil's advocate approach
|
373
|
+
3) Merges sub-answers using an ensemble
|
374
|
+
4) Contracts the question
|
375
|
+
5) Possibly does a direct approach on the contracted question
|
376
|
+
6) Ensembling the final answers
|
377
|
+
"""
|
378
|
+
|
379
|
+
chatterer: Chatterer
|
380
|
+
max_depth: int = 2
|
381
|
+
max_retries: int = 2
|
382
|
+
steps_history: list[StepRecord] = field(default_factory=list)
|
383
|
+
prompter: AoTPrompter = field(default_factory=AoTPrompter)
|
384
|
+
|
385
|
+
# 4.1) Utility for calling the LLM with Pydantic parsing
|
386
|
+
async def _ainvoke_pydantic(
|
387
|
+
self,
|
388
|
+
messages: list[BaseMessage],
|
389
|
+
model_cls: Type[T],
|
390
|
+
fallback: str = "<None>",
|
391
|
+
) -> T:
|
392
|
+
"""
|
393
|
+
Attempts up to max_retries to parse the model_cls from LLM output as JSON.
|
394
|
+
"""
|
395
|
+
for attempt in range(1, self.max_retries + 1):
|
396
|
+
try:
|
397
|
+
return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
|
398
|
+
except ValidationError as e:
|
399
|
+
logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")
|
400
|
+
if attempt == self.max_retries:
|
401
|
+
# Return a fallback version
|
402
|
+
if issubclass(model_cls, EnsembleResponse):
|
403
|
+
return model_cls(thought=fallback, answer=fallback, confidence=0.0) # type: ignore
|
404
|
+
elif issubclass(model_cls, ContractQuestionResponse):
|
405
|
+
return model_cls(thought=fallback, question=fallback) # type: ignore
|
406
|
+
elif issubclass(model_cls, LabelResponse):
|
407
|
+
return model_cls(thought=fallback, sub_questions=[]) # type: ignore
|
408
|
+
elif issubclass(model_cls, CritiqueResponse):
|
409
|
+
return model_cls(thought=fallback, self_assessment=0.0) # type: ignore
|
410
|
+
elif issubclass(model_cls, DevilsAdvocateResponse):
|
411
|
+
return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
|
412
|
+
else:
|
413
|
+
return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
|
414
|
+
# theoretically unreachable
|
415
|
+
raise RuntimeError("Unexpected error in _ainvoke_pydantic")
|
416
|
+
|
417
|
+
# 4.2) Helper method for self-critique
|
418
|
+
async def _ainvoke_critique(
    self,
    messages: list[BaseMessage],
    thought: str,
    answer: str,
) -> CritiqueResponse:
    """
    Have the LLM critique a ``thought`` / ``answer`` pair and return the parsed critique.
    """
    critique_messages = self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer)
    critique = await self._ainvoke_pydantic(messages=critique_messages, model_cls=CritiqueResponse)
    return critique
|
431
|
+
|
432
|
+
# 4.3) Helper method for devil's advocate approach
|
433
|
+
async def _ainvoke_devils_advocate(
    self,
    messages: list[BaseMessage],
    question: str,
    existing_answer: str,
) -> DevilsAdvocateResponse:
    """
    Have the LLM argue against ``existing_answer`` for ``question`` and return its alternative.
    """
    challenge_messages = self.prompter.devils_advocate_prompt(
        messages, question=question, existing_answer=existing_answer
    )
    challenge = await self._ainvoke_pydantic(messages=challenge_messages, model_cls=DevilsAdvocateResponse)
    return challenge
|
446
|
+
|
447
|
+
# 4.4) The main function that recursively decomposes a question and calls sub-steps
|
448
|
+
async def _arecursive_decompose_question(
    self,
    messages: list[BaseMessage],
    question: str,
    depth: int,
    parent_decomp_step_idx: Optional[int] = None,
    parent_subq_idx: Optional[int] = None,
) -> RecursiveDecomposeResponse:
    """
    Decompose ``question``, record the step, and (depth permitting) resolve its sub-questions.

    Args:
        messages: Conversation context passed to every prompt.
        question: The question to decompose.
        depth: Remaining recursion budget. Below 0 nothing is asked; at 0 the question is
            decomposed and recorded but its sub-questions are not resolved.
        parent_decomp_step_idx: Index in ``steps_history`` of the parent decomposition, if any.
        parent_subq_idx: Index of this question within the parent's sub-question list, if any.

    Returns:
        The decomposition response. When sub-questions were resolved, ``final_answer`` comes
        from a second decomposition call that sees the sub-answers, and ``sub_questions``
        holds the resolved nodes.
    """
    if depth < 0:
        # Recursion budget exhausted: return a placeholder without calling the LLM or recording a step.
        logger.info("Max depth reached, returning unknown.")
        return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])

    # Step 1: Perform the decomposition (no sub-answers are available yet)
    decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        messages=self.prompter.recursive_decompose_prompt(messages=messages, question=question, sub_answers=[]),
        model_cls=RecursiveDecomposeResponse,
    )

    # Step 2: Label / refine sub-questions (dependencies, ordering)
    if decompose_resp.sub_questions:
        label_resp: LabelResponse = await self._ainvoke_pydantic(
            messages=self.prompter.label_prompt(messages, question, decompose_resp),
            model_cls=LabelResponse,
        )
        # NOTE: if labeling falls back after validation failures, the fallback carries an
        # empty list, so the original sub-questions are dropped here.
        decompose_resp.sub_questions = label_resp.sub_questions

    # Save a pipeline record for this decomposition step; the index lets us update it later
    current_decomp_step_idx = self._record_decomposition_step(
        question=question,
        final_answer=decompose_resp.final_answer,
        sub_questions=decompose_resp.sub_questions,
        parent_decomp_step_idx=parent_decomp_step_idx,
        parent_subq_idx=parent_subq_idx,
    )

    # Step 3: If sub-questions exist and depth remains, resolve them
    if depth > 0 and decompose_resp.sub_questions:
        solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
            messages=messages,
            sub_questions=decompose_resp.sub_questions,
            depth=depth,
            parent_decomp_step_idx=current_decomp_step_idx,
        )
        # Second pass: ask again with the sub-answers attached (missing answers become UNKNOWN)
        # and take only the refined final answer from that response.
        refined_prompt = self.prompter.recursive_decompose_prompt(
            messages=messages,
            question=question,
            sub_answers=[(sq.question, sq.answer or UNKNOWN) for sq in solved_subs],
        )
        refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
            refined_prompt, RecursiveDecomposeResponse
        )
        decompose_resp.final_answer = refined_resp.final_answer
        decompose_resp.sub_questions = solved_subs

        # Update the pipeline record written above so the history reflects the refined result
        self.steps_history[current_decomp_step_idx].answer = refined_resp.final_answer
        self.steps_history[current_decomp_step_idx].sub_questions = solved_subs

    return decompose_resp
|
515
|
+
|
516
|
+
def _record_decomposition_step(
    self,
    question: str,
    final_answer: str,
    sub_questions: list[SubQuestionNode],
    parent_decomp_step_idx: Optional[int],
    parent_subq_idx: Optional[int],
) -> int:
    """
    Append a DECOMPOSITION record to ``steps_history`` and return the position it was stored at.
    """
    new_index = len(self.steps_history)
    self.steps_history.append(
        StepRecord(
            step_name=StepName.DECOMPOSITION,
            question=question,
            answer=final_answer,
            sub_questions=sub_questions,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=parent_subq_idx,
        )
    )
    return new_index
|
537
|
+
|
538
|
+
async def _aresolve_sub_questions(
    self,
    messages: list[BaseMessage],
    sub_questions: list[SubQuestionNode],
    depth: int,
    parent_decomp_step_idx: Optional[int],
) -> list[SubQuestionNode]:
    """
    Resolve every sub-question and return them in their original index order.

    Tasks are created in dependency order (Kahn's algorithm over each
    sub-question's ``depend`` indices) and then awaited together with
    ``asyncio.gather``, so they run concurrently rather than strictly one
    after another.

    For each sub-question:
      1) Recursively decompose (main approach) at ``depth - 1``.
      2) Acquire a devil's advocate alternative.
      3) Ensemble the main and devil's advocate answers.
      4) Store the ensembled answer on the sub-question (mutated in place).
      5) Critique the devil's advocate result and record both steps.

    Args:
        messages: Conversation messages forwarded to every prompt.
        sub_questions: Nodes to resolve; ``depend`` entries outside
            ``[0, len(sub_questions))`` are ignored.
        depth: Remaining recursion depth of the caller.
        parent_decomp_step_idx: Index of the decomposition record these
            sub-questions belong to.

    Returns:
        The same ``SubQuestionNode`` objects with ``answer`` filled in,
        ordered by their original index.
    """
    from collections import deque

    n = len(sub_questions)
    in_degree = [0] * n
    graph: list[list[int]] = [[] for _ in range(n)]
    for i, sq in enumerate(sub_questions):
        for dep in sq.depend:
            if 0 <= dep < n:
                in_degree[i] += 1
                graph[dep].append(i)

    # Kahn's algorithm for topological order (deque gives O(1) pops from the left).
    queue = deque(i for i in range(n) if in_degree[i] == 0)
    topo_order: list[int] = []

    while queue:
        node = queue.popleft()
        topo_order.append(node)
        for nxt in graph[node]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                queue.append(nxt)

    # Sub-questions caught in a dependency cycle (including a node that
    # depends on itself) never reach in-degree 0. Schedule them anyway, in
    # index order, so each one is resolved and the final lookup cannot
    # raise KeyError.
    if len(topo_order) < n:
        scheduled = set(topo_order)
        topo_order.extend(i for i in range(n) if i not in scheduled)

    # We'll store the resolved sub-questions
    final_subs: dict[int, SubQuestionNode] = {}

    async def _resolve_one_subq(idx: int):
        sq = sub_questions[idx]
        # 1) Main approach
        main_resp = await self._arecursive_decompose_question(
            messages=messages,
            question=sq.question,
            depth=depth - 1,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=idx,
        )

        main_answer = main_resp.final_answer

        # 2) Devil's Advocate approach
        devils_resp = await self._ainvoke_devils_advocate(
            messages=messages, question=sq.question, existing_answer=main_answer
        )
        # 3) Ensemble to combine main_answer + devils_alternative
        ensemble_sub = await self._ainvoke_pydantic(
            self.prompter.ensemble_prompt(
                messages=messages,
                possible_thought_and_answers=[
                    (main_resp.thought, main_answer),
                    (devils_resp.thought, devils_resp.final_answer),
                ],
            ),
            EnsembleResponse,
        )
        sub_best_answer = ensemble_sub.answer

        # Store final subq answer
        sq.answer = sub_best_answer
        final_subs[idx] = sq

        # Critique the devil's advocate result
        dev_adv_crit = await self._ainvoke_critique(
            messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
        )

        # Record the devil's advocate step and its critique back to back,
        # with no await in between. These tasks run concurrently, so an
        # await between the two appends would let another task's records
        # land between them, and get_reasoning_graph links each critique
        # to the record immediately before it.
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE,
                question=sq.question,
                answer=devils_resp.final_answer,
                thought=devils_resp.thought,
                sub_questions=devils_resp.sub_questions,
            )
        )
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
                thought=dev_adv_crit.thought,
                score=dev_adv_crit.self_assessment,
            )
        )

    # Launch sub-question tasks in topological order; the first exception propagates.
    tasks = [_resolve_one_subq(i) for i in topo_order]
    await asyncio.gather(*tasks, return_exceptions=False)

    return [final_subs[i] for i in range(n)]
|
638
|
+
|
639
|
+
# 4.5) The primary pipeline method
|
640
|
+
async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
    """
    Run the full pipeline for the last message in ``messages``.

    Stages, in order (each one appends to ``steps_history``):
      1) Recursive decomposition of the main question.
      2) Self-critique of that decomposition.
      3) Devil's advocate on the main answer, plus a critique of it.
      4) Contract the top-level sub-answers into a single question.
      5) Directly answer the contracted question, then critique it.
      6) Ensemble across main, devil's advocate and contracted-direct answers.
      7) Record and return the ensembled final answer.

    Args:
        messages: Conversation messages; the text of the last one is the question.

    Returns:
        The answer selected by the final ensemble.
    """
    self.steps_history.clear()
    original_question: str = messages[-1].text()

    # --- 1) Recursive decomposition -------------------------------------
    decomp_resp = await self._arecursive_decompose_question(
        messages=messages,
        question=original_question,
        depth=self.max_depth,
    )
    logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")

    # --- 2) Self-critique of the main decomposition ---------------------
    main_critique = await self._ainvoke_critique(
        messages=messages,
        thought=decomp_resp.thought,
        answer=decomp_resp.final_answer,
    )
    main_critique_record = StepRecord(
        step_name=StepName.DECOMPOSITION_CRITIQUE,
        thought=main_critique.thought,
        score=main_critique.self_assessment,
    )
    self.steps_history.append(main_critique_record)

    # --- 3) Devil's advocate on the whole main answer -------------------
    devils_on_main = await self._ainvoke_devils_advocate(
        messages=messages,
        question=original_question,
        existing_answer=decomp_resp.final_answer,
    )
    advocate_record = StepRecord(
        step_name=StepName.DEVILS_ADVOCATE,
        question=original_question,
        answer=devils_on_main.final_answer,
        thought=devils_on_main.thought,
        sub_questions=devils_on_main.sub_questions,
    )
    self.steps_history.append(advocate_record)

    advocate_critique = await self._ainvoke_critique(
        messages=messages,
        thought=devils_on_main.thought,
        answer=devils_on_main.final_answer,
    )
    advocate_critique_record = StepRecord(
        step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
        thought=advocate_critique.thought,
        score=advocate_critique.self_assessment,
    )
    self.steps_history.append(advocate_critique_record)

    # --- 4) Contract sub-answers from the main decomposition ------------
    # Walk the history backwards to the most recent DECOMPOSITION record
    # that has no parent decomposition.
    top_decomp_record: Optional[StepRecord] = None
    for candidate in reversed(self.steps_history):
        if candidate.step_name == StepName.DECOMPOSITION and candidate.parent_decomp_step_idx is None:
            top_decomp_record = candidate
            break

    sub_answers = []
    if top_decomp_record and top_decomp_record.sub_questions:
        sub_answers = [(sq.question, sq.answer or UNKNOWN) for sq in top_decomp_record.sub_questions]

    contract_prompt = self.prompter.contract_prompt(messages, sub_answers)
    contract_resp = await self._ainvoke_pydantic(
        messages=contract_prompt,
        model_cls=ContractQuestionResponse,
    )
    contracted_question = contract_resp.question
    contracted_record = StepRecord(
        step_name=StepName.CONTRACTED_QUESTION,
        question=contracted_question,
        thought=contract_resp.thought,
    )
    self.steps_history.append(contracted_record)

    # --- 5) Direct attempt on the contracted question -------------------
    direct_prompt = self.prompter.contract_direct_prompt(messages, contracted_question)
    contracted_direct = await self._ainvoke_pydantic(
        messages=direct_prompt,
        model_cls=RecursiveDecomposeResponse,
        fallback="No Contracted Direct Answer",
    )
    direct_record = StepRecord(
        step_name=StepName.CONTRACTED_DIRECT_ANSWER,
        answer=contracted_direct.final_answer,
        thought=contracted_direct.thought,
    )
    self.steps_history.append(direct_record)
    logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")

    # --- 5.1) Critique of the contracted direct approach ----------------
    direct_critique = await self._ainvoke_critique(
        messages=messages,
        thought=contracted_direct.thought,
        answer=contracted_direct.final_answer,
    )
    direct_critique_record = StepRecord(
        step_name=StepName.CONTRACT_CRITIQUE,
        thought=direct_critique.thought,
        score=direct_critique.self_assessment,
    )
    self.steps_history.append(direct_critique_record)

    # --- 6) Ensemble of main, devil's advocate and contracted direct ----
    candidates = [
        (decomp_resp.thought, decomp_resp.final_answer),
        (devils_on_main.thought, devils_on_main.final_answer),
        (contracted_direct.thought, contracted_direct.final_answer),
    ]
    ensemble_prompt = self.prompter.ensemble_prompt(
        messages=messages,
        possible_thought_and_answers=candidates,
    )
    ensemble_resp = await self._ainvoke_pydantic(ensemble_prompt, EnsembleResponse)
    best_approach_answer = ensemble_resp.answer
    approach_used = StepName.ENSEMBLE
    self.steps_history.append(StepRecord(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used))
    logger.info(f"[Best Approach Decision] => {approach_used}")

    # --- 7) Final answer -------------------------------------------------
    final_record = StepRecord(
        step_name=StepName.FINAL_ANSWER,
        answer=best_approach_answer,
        score=ensemble_resp.confidence,
    )
    self.steps_history.append(final_record)
    logger.info(f"[Final Answer] => {best_approach_answer}")

    return best_approach_answer
|
775
|
+
|
776
|
+
def run_pipeline(self, messages: list[BaseMessage]) -> str:
    """Blocking counterpart of ``arun_pipeline``: runs it to completion via ``asyncio.run``."""
    pipeline_coro = self.arun_pipeline(messages)
    return asyncio.run(pipeline_coro)
|
779
|
+
|
780
|
+
# ---------------------------------------------------------------------------------
|
781
|
+
# 4.6) Build or export a reasoning graph
|
782
|
+
# ---------------------------------------------------------------------------------
|
783
|
+
|
784
|
+
def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
    """
    Build a ``neo4j_extension`` ``Graph`` from ``steps_history``.

    Nodes: one per recorded step, plus one "SubQuestion" node for every
    sub-question attached to a step.

    Relationships, added in this order:
      - SPLIT_INTO: step -> each of its sub-questions.
      - DEPEND_ON: sub-question -> sibling sub-question listed in ``depend``.
      - PRECEDES: step i -> step i + 1.
      - CRITIQUES: any step whose name contains "CRITIQUE" -> the step before it.
      - SELECTS: BEST_APPROACH_DECISION -> first step named by its ``used`` field.
      - RESULT_OF: first FINAL_ANSWER -> last BEST_APPROACH_DECISION.

    Args:
        global_id_prefix: Prefix for every ``globalId`` generated here.

    Returns:
        The populated ``Graph``.
    """
    from neo4j_extension import Graph, Node, Relationship

    g = Graph()
    history = self.steps_history

    def _connect(rel_type, start_node, end_node, globalId):
        # Every relationship in this graph carries an empty property map.
        g.add_relationship(
            Relationship(
                properties={},
                rel_type=rel_type,
                start_node=start_node,
                end_node=end_node,
                globalId=globalId,
            )
        )

    # Step A: one node per pipeline step (nested decomposition steps are kept).
    step_nodes: dict[int, Node] = {}
    for i, record in enumerate(history):
        step_node = Node(
            properties=record.as_properties(),
            labels={record.step_name},
            globalId=f"{global_id_prefix}_step_{i}",
        )
        g.add_node(step_node)
        step_nodes[i] = step_node

    # Step B: one node per sub-question of any step that carries sub-questions.
    subq_nodes: dict[str, Node] = {}
    for i, record in enumerate(history):
        if not record.sub_questions:
            continue
        for sq_idx, sq in enumerate(record.sub_questions):
            sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
            subq_node = Node(
                properties={
                    "question": sq.question,
                    "answer": sq.answer or "",
                },
                labels={"SubQuestion"},
                globalId=sq_id,
            )
            g.add_node(subq_node)
            subq_nodes[sq_id] = subq_node

    # Step C: SPLIT_INTO from each step to its sub-questions, then DEPEND_ON
    # between sub-questions of that same step.
    for i, record in enumerate(history):
        if not record.sub_questions:
            continue
        start_node = step_nodes[i]
        sibling_count = len(record.sub_questions)
        for sq_idx, sq in enumerate(record.sub_questions):
            end_node = subq_nodes[f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"]
            _connect(
                StepRelation.SPLIT_INTO,
                start_node,
                end_node,
                f"{global_id_prefix}_split_{i}_{sq_idx}",
            )
            for dep in sq.depend:
                if not 0 <= dep < sibling_count:
                    continue
                dep_node = subq_nodes.get(f"{global_id_prefix}_decomp_{i}_sub_{dep}")
                if dep_node is None:
                    continue
                _connect(
                    StepRelation.DEPEND_ON,
                    end_node,
                    dep_node,
                    f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
                )

    # Step D: PRECEDES relationships as a linear chain over the recorded steps.
    for i in range(len(history) - 1):
        _connect(
            StepRelation.PRECEDES,
            step_nodes[i],
            step_nodes[i + 1],
            f"{global_id_prefix}_precede_{i}_to_{i + 1}",
        )

    # Step E: a step whose name contains "CRITIQUE" critiques the step before it.
    for i, record in enumerate(history):
        if "CRITIQUE" in record.step_name and i > 0:
            _connect(
                StepRelation.CRITIQUES,
                step_nodes[i],
                step_nodes[i - 1],
                f"{global_id_prefix}_crit_{i}",
            )

    # Link each BEST_APPROACH_DECISION step to the first step it names in `used`.
    best_decision_idx = None
    used_step_idx = None
    for i, record in enumerate(history):
        if not (record.step_name == StepName.BEST_APPROACH_DECISION and record.used):
            continue
        best_decision_idx = i
        used_step_idx = None
        for j in step_nodes:
            if history[j].step_name == record.used:
                used_step_idx = j
                break
        if used_step_idx is not None:
            _connect(
                StepRelation.SELECTS,
                step_nodes[i],
                step_nodes[used_step_idx],
                f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
            )

    # Link the first FINAL_ANSWER step to the last decision step seen above.
    final_answer_idx = None
    for i, r in enumerate(history):
        if r.step_name == StepName.FINAL_ANSWER:
            final_answer_idx = i
            break
    if final_answer_idx is not None and best_decision_idx is not None:
        _connect(
            StepRelation.RESULT_OF,
            step_nodes[final_answer_idx],
            step_nodes[best_decision_idx],
            f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
        )

    return g
|
921
|
+
|
922
|
+
|
923
|
+
# ---------------------------------------------------------------------------------
|
924
|
+
# 5) AoTStrategy class that uses the pipeline
|
925
|
+
# ---------------------------------------------------------------------------------
|
926
|
+
|
927
|
+
|
928
|
+
@dataclass
class AoTStrategy(BaseStrategy):
    """
    Strategy that delegates to an ``AoTPipeline`` and exposes its reasoning graph.
    """

    # Pipeline that performs the actual reasoning and holds the steps history.
    pipeline: AoTPipeline

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """Convert ``messages`` to a message list and await the pipeline on it."""
        runner = self.pipeline
        prompt_value = runner.chatterer.client._convert_input(messages)  # type: ignore
        msgs = prompt_value.to_messages()  # type: ignore
        return await runner.arun_pipeline(msgs)

    def invoke(self, messages: LanguageModelInput) -> str:
        """Convert ``messages`` to a message list and run the pipeline synchronously."""
        runner = self.pipeline
        prompt_value = runner.chatterer.client._convert_input(messages)  # type: ignore
        msgs = prompt_value.to_messages()  # type: ignore
        return runner.run_pipeline(msgs)

    def get_reasoning_graph(self):
        """Build the reasoning graph from the pipeline's steps history, using the "AoT" id prefix."""
        runner = self.pipeline
        return runner.get_reasoning_graph(global_id_prefix="AoT")
|
950
|
+
|
951
|
+
|
952
|
+
# ---------------------------------------------------------------------------------
|
953
|
+
# Example usage (pseudo-code)
|
954
|
+
# ---------------------------------------------------------------------------------
|
955
|
+
if __name__ == "__main__":

    def _run_example() -> None:
        """Example run: answer one question, build its reasoning graph, store it in Neo4j."""
        from neo4j_extension import Neo4jConnection  # or your actual DB connector

        # Create a Chatterer with your chosen LLM backend (OpenAI, etc.) -- pseudo-code
        strategy = AoTStrategy(
            pipeline=AoTPipeline(chatterer=Chatterer.openai(), max_depth=3),
        )

        answer = strategy.invoke("Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9.")
        print("Final Answer:", answer)

        # Build the reasoning graph
        graph = strategy.get_reasoning_graph()
        print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")

        # Optionally store in Neo4j
        with Neo4jConnection() as conn:
            conn.clear_all()
            conn.upsert_graph(graph)
            print("Graph stored in Neo4j.")

    _run_example()
|