chatterer 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterer/__init__.py +14 -8
- chatterer/language_model.py +73 -354
- chatterer/py.typed +0 -0
- chatterer/strategies/__init__.py +2 -8
- chatterer/strategies/atom_of_thoughts.py +808 -426
- chatterer/tools/__init__.py +15 -0
- chatterer/tools/convert_to_text.py +464 -0
- chatterer/tools/webpage_to_markdown/__init__.py +4 -0
- chatterer/tools/webpage_to_markdown/playwright_bot.py +631 -0
- chatterer/tools/webpage_to_markdown/utils.py +556 -0
- {chatterer-0.1.4.dist-info → chatterer-0.1.6.dist-info}/METADATA +21 -5
- chatterer-0.1.6.dist-info/RECORD +15 -0
- {chatterer-0.1.4.dist-info → chatterer-0.1.6.dist-info}/WHEEL +1 -1
- chatterer-0.1.4.dist-info/RECORD +0 -9
- {chatterer-0.1.4.dist-info → chatterer-0.1.6.dist-info}/top_level.txt +0 -0
chatterer/strategies/atom_of_thoughts.py
@@ -1,32 +1,25 @@
-# original source: https://github.com/qixucen/atom
-
 from __future__ import annotations
 
 import asyncio
 import logging
-from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import StrEnum
-from typing import
+from typing import Optional, Type, TypeVar
 
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from pydantic import BaseModel, Field, ValidationError
 
+# Import your Chatterer interface (do not remove)
 from ..language_model import Chatterer, LanguageModelInput
 from .base import BaseStrategy
 
-
-
-#
-
+# ---------------------------------------------------------------------------------
+# 0) Enums and Basic Models
+# ---------------------------------------------------------------------------------
 
-
-
-
-    GENERAL = "general"
-    MATH = "math"
-    CODING = "coding"
-    PHILOSOPHY = "philosophy"
-    MULTIHOP = "multihop"
+QA_TEMPLATE = "Q: {question}\nA: {answer}"
+MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."
+UNKNOWN = "Unknown"
 
 
 class SubQuestionNode(BaseModel):
@@ -45,13 +38,6 @@ class RecursiveDecomposeResponse(BaseModel):
     sub_questions: list[SubQuestionNode] = Field(description="Root-level sub-questions.")
 
 
-class DirectResponse(BaseModel):
-    """A direct response to a question."""
-
-    thought: str = Field(description="Short reasoning.")
-    answer: str = Field(description="Direct final answer.")
-
-
 class ContractQuestionResponse(BaseModel):
     """The result of contracting (simplifying) a question."""
 
@@ -66,359 +52,507 @@ class EnsembleResponse(BaseModel):
     answer: str = Field(description="Best final answer after ensemble.")
     confidence: float = Field(description="Confidence score in [0, 1].")
 
-    def model_post_init(self, __context) -> None:
-        # Clamp confidence to [0, 1]
+    def model_post_init(self, __context: object) -> None:
         self.confidence = max(0.0, min(1.0, self.confidence))
 
 
 class LabelResponse(BaseModel):
-    """
+    """Used to refine and reorder the sub-questions with corrected dependencies."""
 
     thought: str = Field(description="Explanation or reasoning about labeling.")
     sub_questions: list[SubQuestionNode] = Field(
         description="Refined list of sub-questions with corrected dependencies."
     )
-    # Some tasks also keep the final answer, but we focus on sub-questions.
 
 
-
+class CritiqueResponse(BaseModel):
+    """A response used for LLM to self-critique or question its own correctness."""
 
+    thought: str = Field(description="Critical reflection on correctness.")
+    self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")
 
-class BaseAoTPrompter(ABC):
-    """
-    Abstract base prompter that defines the required prompt methods.
-    """
 
-
-
-
-    ) -> str: ...
+# ---------------------------------------------------------------------------------
+# [NEW] Additional classes to incorporate a separate sub-question devil's advocate
+# ---------------------------------------------------------------------------------
 
-    @abstractmethod
-    def direct_prompt(self, question: str, context: Optional[str] = None) -> str: ...
 
-
-
-
-
-
-        Prompt used to 're-check' the sub-questions for correctness and dependencies.
-        """
+class DevilsAdvocateResponse(BaseModel):
+    """
+    A response for a 'devil's advocate' pass.
+    We consider an alternative viewpoint or contradictory answer.
+    """
 
-
-
+    thought: str = Field(description="Reasoning behind the contradictory viewpoint.")
+    final_answer: str = Field(description="Alternative or conflicting answer to challenge the main one.")
+    sub_questions: list[SubQuestionNode] = Field(
+        description="Any additional sub-questions from the contrarian perspective."
+    )
 
-    @abstractmethod
-    def ensemble_prompt(
-        self,
-        original_question: str,
-        direct_answer: str,
-        decompose_answer: str,
-        contracted_direct_answer: str,
-        context: Optional[str] = None,
-    ) -> str: ...
 
+# ---------------------------------------------------------------------------------
+# 1) Prompter Classes with Multi-Hop + Devil's Advocate
+# ---------------------------------------------------------------------------------
 
-
-
-    Generic prompter
-    """
+
+class AoTPrompter:
+    """Generic base prompter that defines the required prompt methods."""
 
     def recursive_decompose_prompt(
-        self, question: str, sub_answers:
-    ) ->
-
-
-
-
-
-            "
-            "
-            "
-
-
-            '
-            "
-            "
-            "
-            "
-
+        self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
+    ) -> list[BaseMessage]:
+        """
+        Prompt for main decomposition.
+        Encourages step-by-step reasoning and listing sub-questions as JSON.
+        """
+        decompose_instructions = (
+            "First, restate the main question.\n"
+            "Decide if sub-questions are needed. If so, list them.\n"
+            "In the 'thought' field, show your chain-of-thought.\n"
+            "Return valid JSON:\n"
+            "{\n"
+            ' "thought": "...",\n'
+            ' "final_answer": "...",\n'
+            ' "sub_questions": [\n'
+            ' {"question": "...", "answer": null, "depend": []},\n'
+            " ...\n"
+            " ]\n"
+            "}\n"
         )
 
-
-
-
-
-
-
-            "1. Return valid JSON:\n"
-            " {'thought': '...', 'answer': '...'}\n"
-            "2. 'thought': Offer a brief reasoning.\n"
-            "3. 'answer': Deliver a precise solution.\n\n"
-            f"QUESTION:\n{question}{context_str}"
-        )
+        content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
+        return messages + [
+            HumanMessage(content=f"Main question:\n{question}"),
+            AIMessage(content=content_sub_answers),
+            AIMessage(content=decompose_instructions),
+        ]
 
     def label_prompt(
-        self, question: str, decompose_response: RecursiveDecomposeResponse
-    ) ->
-
-
-
-
-            "
-            "
-            "
-            '
-            '
-            '
-            "
-            "
-            "
-            "2. 'thought': Provide reasoning about any changes.\n"
-            "3. 'sub_questions': Possibly updated sub-questions with correct 'depend' lists.\n\n"
-            f"ORIGINAL QUESTION:\n{question}\n"
-            f"CURRENT DECOMPOSITION:\n{decompose_response.model_dump_json(indent=2)}"
-            f"{context_str}"
+        self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
+    ) -> list[BaseMessage]:
+        """
+        Prompt for refining the sub-questions and dependencies.
+        """
+        label_instructions = (
+            "Review each sub-question to ensure correctness and proper ordering.\n"
+            "Return valid JSON in the form:\n"
+            "{\n"
+            ' "thought": "...",\n'
+            ' "sub_questions": [\n'
+            ' {"question": "...", "answer": "...", "depend": [...]},\n'
+            " ...\n"
+            " ]\n"
+            "}\n"
         )
+        return messages + [
+            AIMessage(content=f"Question: {question}"),
+            AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
+            AIMessage(content=label_instructions),
+        ]
 
-    def contract_prompt(self,
-
-
-
-
-            "
-            "
-            "
-            "
-
-
-
+    def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
+        """
+        Prompt for merging sub-answers into one self-contained question.
+        """
+        contract_instructions = (
+            "Please merge sub-answers into a single short question that is fully self-contained.\n"
+            "In 'thought', show how you unify the information.\n"
+            "Then produce JSON:\n"
+            "{\n"
+            ' "thought": "...",\n'
+            ' "question": "a short but self-contained question"\n'
+            "}\n"
         )
-
-
-
-
-
-
-
-
-
-
-
-
-            "
-
-
-
-            "
-            "
-            " {'thought': '...', 'answer': '...', 'confidence': <float in [0,1]>}\n"
-            "2. 'thought': Summarize how you decided.\n"
-            "3. 'answer': Final best answer.\n"
-            "4. 'confidence': Confidence score in [0,1].\n\n"
-            f"ORIGINAL QUESTION:\n{original_question}"
-            f"{context_str}"
+        sub_q_content = "\n".join(f"Q: {q}\nA: {a}" for q, a in sub_answers)
+        return messages + [
+            AIMessage(content="We have these sub-questions and answers:"),
+            AIMessage(content=sub_q_content),
+            AIMessage(content=contract_instructions),
+        ]
+
+    def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
+        """
+        Prompt for directly answering the contracted question thoroughly.
+        """
+        direct_instructions = (
+            "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
+            "Return JSON:\n"
+            "{\n"
+            ' "thought": "...",\n'
+            ' "final_answer": "..."\n'
+            "}\n"
         )
+        return messages + [
+            HumanMessage(content=f"Simplified question: {contracted_question}"),
+            AIMessage(content=direct_instructions),
+        ]
 
-
-
-
-
-
-
-
-
-
-
-
-
-    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
-        base = super().direct_prompt(question, context)
-        return base + "\nEnsure mathematical correctness in your reasoning.\n"
-
-    # label_prompt, contract_prompt, ensemble_prompt can also be overridden similarly if needed.
-
-
-class CodingAoTPrompter(GeneralAoTPrompter):
-    """
-    Specialized prompter for coding/algorithmic queries.
-    """
-
-    def recursive_decompose_prompt(
-        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-    ) -> str:
-        base = super().recursive_decompose_prompt(question, sub_answers, context)
-        return base + "\nBreak down into programming concepts or implementation steps.\n"
-
-
-class PhilosophyAoTPrompter(GeneralAoTPrompter):
-    """
-    Specialized prompter for philosophical discussions.
-    """
-
-    def recursive_decompose_prompt(
-        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-    ) -> str:
-        base = super().recursive_decompose_prompt(question, sub_answers, context)
-        return base + "\nConsider key philosophical theories and arguments.\n"
-
-
-class MultiHopAoTPrompter(GeneralAoTPrompter):
-    """
-    Specialized prompter for multi-hop Q&A with explicit context usage.
-    """
-
-    def recursive_decompose_prompt(
-        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-    ) -> str:
-        base = super().recursive_decompose_prompt(question, sub_answers, context)
-        return (
-            base + "\nTreat this as a multi-hop question. Use the provided context carefully.\n"
-            "Extract partial evidence from the context for each sub-question.\n"
+    def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
+        """
+        Prompt for self-critique.
+        """
+        critique_instructions = (
+            "Critique your own approach. Identify possible errors or leaps.\n"
+            "Return JSON:\n"
+            "{\n"
+            ' "thought": "...",\n'
+            ' "self_assessment": <float in [0,1]>\n'
+            "}\n"
         )
+        return messages + [
+            AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
+            AIMessage(content=f"Your previous ANSWER:\n{answer}"),
+            AIMessage(content=critique_instructions),
+        ]
 
-    def
-
-
-
-
-
-
-
+    def ensemble_prompt(
+        self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
+    ) -> list[BaseMessage]:
+        """
+        Show multiple candidate solutions and pick the best final answer with confidence.
+        """
+        instructions = (
+            "You have multiple candidate solutions. Compare carefully and pick the best.\n"
+            "Return JSON:\n"
+            "{\n"
+            ' "thought": "why you chose this final answer",\n'
+            ' "answer": "the best consolidated answer",\n'
+            ' "confidence": 0.0 ~ 1.0\n'
+            "}\n"
+        )
+        reasonings: list[BaseMessage] = []
+        for idx, (thought, ans) in enumerate(possible_thought_and_answers):
+            reasonings.append(AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---"))
+        return messages + reasonings + [AIMessage(content=instructions)]
+
+    def devils_advocate_prompt(
+        self, messages: list[BaseMessage], question: str, existing_answer: str
+    ) -> list[BaseMessage]:
+        """
+        Prompt for a devil's advocate approach to contradict or provide an alternative viewpoint.
+        """
+        instructions = (
+            "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
+            "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
+            "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
+            "But here, let's keep it in a new dedicated structure:\n"
+            "{\n"
+            ' "thought": "...",\n'
+            ' "final_answer": "...",\n'
+            ' "sub_questions": [\n'
+            ' {"question": "...", "answer": null, "depend": []},\n'
+            " ...\n"
+            " ]\n"
+            "}\n"
+        )
+        return messages + [
+            AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
+            AIMessage(content=instructions),
+        ]
+
+
+# ---------------------------------------------------------------------------------
+# 2) Strict Typed Steps for Pipeline
+# ---------------------------------------------------------------------------------
+
+
+class StepName(StrEnum):
+    """Enum for step names in the pipeline."""
+
+    DOMAIN_DETECTION = "DomainDetection"
+    DECOMPOSITION = "Decomposition"
+    DECOMPOSITION_CRITIQUE = "DecompositionCritique"
+    CONTRACTED_QUESTION = "ContractedQuestion"
+    CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
+    CONTRACT_CRITIQUE = "ContractCritique"
+    BEST_APPROACH_DECISION = "BestApproachDecision"
+    ENSEMBLE = "Ensemble"
+    FINAL_ANSWER = "FinalAnswer"
+
+    DEVILS_ADVOCATE = "DevilsAdvocate"
+    DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
+
+
+class StepRelation(StrEnum):
+    """Enum for relationship types in the reasoning graph."""
+
+    CRITIQUES = "CRITIQUES"
+    SELECTS = "SELECTS"
+    RESULT_OF = "RESULT_OF"
+    SPLIT_INTO = "SPLIT_INTO"
+    DEPEND_ON = "DEPEND_ON"
+    PRECEDES = "PRECEDES"
+    DECOMPOSED_BY = "DECOMPOSED_BY"
+
+
+class StepRecord(BaseModel):
+    """A typed record for each pipeline step."""
+
+    step_name: StepName
+    domain: Optional[str] = None
+    score: Optional[float] = None
+    used: Optional[StepName] = None
+    sub_questions: Optional[list[SubQuestionNode]] = None
+    parent_decomp_step_idx: Optional[int] = None
+    parent_subq_idx: Optional[int] = None
+    question: Optional[str] = None
+    thought: Optional[str] = None
+    answer: Optional[str] = None
+
+    def as_properties(self) -> dict[str, str | float | int | None]:
+        """Converts the StepRecord to a dictionary of properties."""
+        result: dict[str, str | float | int | None] = {}
+        if self.score is not None:
+            result["score"] = self.score
+        if self.domain:
+            result["domain"] = self.domain
+        if self.question:
+            result["question"] = self.question
+        if self.thought:
+            result["thought"] = self.thought
+        if self.answer:
+            result["answer"] = self.answer
+        return result
+
+
+# ---------------------------------------------------------------------------------
+# 3) Logging Setup
+# ---------------------------------------------------------------------------------
+
+
+class SimpleColorFormatter(logging.Formatter):
+    """Simple color-coded logging formatter for console output using ANSI escape codes."""
+
+    BLUE = "\033[94m"
+    GREEN = "\033[92m"
+    YELLOW = "\033[93m"
+    RED = "\033[91m"
+    RESET = "\033[0m"
+    LEVEL_COLORS = {
+        logging.DEBUG: BLUE,
+        logging.INFO: GREEN,
+        logging.WARNING: YELLOW,
+        logging.ERROR: RED,
+        logging.CRITICAL: RED,
+    }
+
+    def format(self, record: logging.LogRecord) -> str:
+        log_color = self.LEVEL_COLORS.get(record.levelno, self.RESET)
+        message = super().format(record)
+        return f"{log_color}{message}{self.RESET}"
+
+
+logger = logging.getLogger("AoT")
+logger.setLevel(logging.INFO)
+handler = logging.StreamHandler()
+handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))
+logger.handlers = [handler]
+logger.propagate = False
+
+
+# ---------------------------------------------------------------------------------
+# 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
+# ---------------------------------------------------------------------------------
+
+T = TypeVar(
+    "T",
+    bound=EnsembleResponse
+    | ContractQuestionResponse
+    | LabelResponse
+    | CritiqueResponse
+    | RecursiveDecomposeResponse
+    | DevilsAdvocateResponse,
+)
 
 
 @dataclass
 class AoTPipeline:
     """
-
-
-
-
-
-
-
-    7) Ensemble
-    8) (Optional) Score final answer if ground-truth is available
+    The pipeline orchestrates:
+    1) Recursive decomposition
+    2) For each sub-question, it tries a main approach + a devil's advocate approach
+    3) Merges sub-answers using an ensemble
+    4) Contracts the question
+    5) Possibly does a direct approach on the contracted question
+    6) Ensembling the final answers
     """
 
     chatterer: Chatterer
     max_depth: int = 2
     max_retries: int = 2
+    steps_history: list[StepRecord] = field(default_factory=list)
+    prompter: AoTPrompter = field(default_factory=AoTPrompter)
 
-
-
-
-
-
-
-
-        }
-    )
-
-    async def _ainvoke_pydantic(self, prompt: str, model_cls: Type[T], default_answer: str = "Unable to process") -> T:
+    # 4.1) Utility for calling the LLM with Pydantic parsing
+    async def _ainvoke_pydantic(
+        self,
+        messages: list[BaseMessage],
+        model_cls: Type[T],
+        fallback: str = "<None>",
+    ) -> T:
         """
-        Attempts up to max_retries to parse the model_cls from LLM output.
+        Attempts up to max_retries to parse the model_cls from LLM output as JSON.
         """
         for attempt in range(1, self.max_retries + 1):
            try:
-
-                    response_model=model_cls, messages=[{"role": "user", "content": prompt}]
-                )
-                logger.debug(f"[Validated attempt {attempt}] => {result.model_dump_json(indent=2)}")
-                return result
+                return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
            except ValidationError as e:
-                logger.warning(f"
+                logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")
                if attempt == self.max_retries:
-                    # Return
-                    if model_cls
-                        return model_cls(thought=
-                    elif model_cls
-                        return model_cls(thought=
-                    elif model_cls
-                        return model_cls(thought=
+                    # Return a fallback version
+                    if issubclass(model_cls, EnsembleResponse):
+                        return model_cls(thought=fallback, answer=fallback, confidence=0.0)  # type: ignore
+                    elif issubclass(model_cls, ContractQuestionResponse):
+                        return model_cls(thought=fallback, question=fallback)  # type: ignore
+                    elif issubclass(model_cls, LabelResponse):
+                        return model_cls(thought=fallback, sub_questions=[])  # type: ignore
+                    elif issubclass(model_cls, CritiqueResponse):
+                        return model_cls(thought=fallback, self_assessment=0.0)  # type: ignore
+                    elif issubclass(model_cls, DevilsAdvocateResponse):
+                        return model_cls(thought=fallback, final_answer=fallback, sub_questions=[])  # type: ignore
                    else:
-
-
+                        return model_cls(thought=fallback, final_answer=fallback, sub_questions=[])  # type: ignore
+        # theoretically unreachable
        raise RuntimeError("Unexpected error in _ainvoke_pydantic")
 
-
+    # 4.2) Helper method for self-critique
+    async def _ainvoke_critique(
+        self,
+        messages: list[BaseMessage],
+        thought: str,
+        answer: str,
+    ) -> CritiqueResponse:
         """
-
+        Instructs the LLM to critique the given thought & answer, returning CritiqueResponse.
         """
+        return await self._ainvoke_pydantic(
+            messages=self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer),
+            model_cls=CritiqueResponse,
+        )
 
-
-
-
-
-
-
-
-
-
+    # 4.3) Helper method for devil's advocate approach
+    async def _ainvoke_devils_advocate(
+        self,
+        messages: list[BaseMessage],
+        question: str,
+        existing_answer: str,
+    ) -> DevilsAdvocateResponse:
+        """
+        Instructs the LLM to challenge an existing answer with a devil's advocate approach.
+        """
+        return await self._ainvoke_pydantic(
+            messages=self.prompter.devils_advocate_prompt(messages, question=question, existing_answer=existing_answer),
+            model_cls=DevilsAdvocateResponse,
        )
-        try:
-            result: InferredDomain = await self.chatterer.agenerate_pydantic(
-                response_model=InferredDomain, messages=[{"role": "user", "content": domain_prompt}]
-            )
-            return result.domain
-        except ValidationError:
-            logger.warning("Failed domain detection, defaulting to general.")
-            return Domain.GENERAL
 
+    # 4.4) The main function that recursively decomposes a question and calls sub-steps
    async def _arecursive_decompose_question(
-        self,
+        self,
+        messages: list[BaseMessage],
+        question: str,
+        depth: int,
+        parent_decomp_step_idx: Optional[int] = None,
+        parent_subq_idx: Optional[int] = None,
    ) -> RecursiveDecomposeResponse:
        """
-        Recursively
+        Recursively decompose the given question. For each sub-question:
+        1) Recursively decompose that sub-question if we still have depth left
+        2) After getting a main sub-answer, do a devil's advocate pass
+        3) Combine main sub-answer + devil's advocate alternative via an ensemble
        """
        if depth < 0:
-
+            logger.info("Max depth reached, returning unknown.")
+            return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])
 
-
-
-
-
-
-        decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(prompt, RecursiveDecomposeResponse)
+        # Step 1: Perform the decomposition
+        decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
+            messages=self.prompter.recursive_decompose_prompt(messages=messages, question=question, sub_answers=[]),
+            model_cls=RecursiveDecomposeResponse,
+        )
 
-        # Step 2: Label
+        # Step 2: Label / refine sub-questions (dependencies, ordering)
        if decompose_resp.sub_questions:
-
-
-
+            label_resp: LabelResponse = await self._ainvoke_pydantic(
+                messages=self.prompter.label_prompt(messages, question, decompose_resp),
+                model_cls=LabelResponse,
+            )
            decompose_resp.sub_questions = label_resp.sub_questions
 
-        #
-
+        # Save a pipeline record for this decomposition step
+        current_decomp_step_idx = self._record_decomposition_step(
+            question=question,
+            final_answer=decompose_resp.final_answer,
+            sub_questions=decompose_resp.sub_questions,
+            parent_decomp_step_idx=parent_decomp_step_idx,
+            parent_subq_idx=parent_subq_idx,
+        )
+
+        # Step 3: If sub-questions exist and depth remains, solve them + do devil's advocate
        if depth > 0 and decompose_resp.sub_questions:
-
-
+            solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
+                messages=messages,
+                sub_questions=decompose_resp.sub_questions,
+                depth=depth,
+                parent_decomp_step_idx=current_decomp_step_idx,
            )
-            #
-
-
-
-
-
-
-
-
-
+            # Then we can refine the "final_answer" from those sub-answers
+            # or we do a secondary pass to refine the final answer
+            refined_prompt = self.prompter.recursive_decompose_prompt(
+                messages=messages,
+                question=question,
+                sub_answers=[(sq.question, sq.answer or UNKNOWN) for sq in solved_subs],
+            )
+            refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
+                refined_prompt, RecursiveDecomposeResponse
+            )
+            decompose_resp.final_answer = refined_resp.final_answer
+            decompose_resp.sub_questions = solved_subs
+
+            # Update pipeline record
+            self.steps_history[current_decomp_step_idx].answer = refined_resp.final_answer
+            self.steps_history[current_decomp_step_idx].sub_questions = solved_subs
 
        return decompose_resp
 
+    def _record_decomposition_step(
+        self,
+        question: str,
+        final_answer: str,
+        sub_questions: list[SubQuestionNode],
+        parent_decomp_step_idx: Optional[int],
+        parent_subq_idx: Optional[int],
+    ) -> int:
+        """
+        Save the decomposition step in steps_history, returning the index.
+        """
+        step_record = StepRecord(
+            step_name=StepName.DECOMPOSITION,
+            question=question,
+            answer=final_answer,
+            sub_questions=sub_questions,
+            parent_decomp_step_idx=parent_decomp_step_idx,
+            parent_subq_idx=parent_subq_idx,
+        )
+        self.steps_history.append(step_record)
+        return len(self.steps_history) - 1
+
    async def _aresolve_sub_questions(
-        self,
+        self,
+        messages: list[BaseMessage],
+        sub_questions: list[SubQuestionNode],
+        depth: int,
+        parent_decomp_step_idx: Optional[int],
    ) -> list[SubQuestionNode]:
        """
-
-
+        Resolve sub-questions in topological order.
+        For each sub-question:
+        1) Recursively decompose (main approach).
+        2) Acquire a devil's advocate alternative.
+        3) Critique or ensemble if needed.
+        4) Finalize sub-question answer.
        """
-        n
-
-
-        # Build adjacency
-        in_degree: list[int] = [0] * n
+        n = len(sub_questions)
+        in_degree = [0] * n
        graph: list[list[int]] = [[] for _ in range(n)]
        for i, sq in enumerate(sub_questions):
            for dep in sq.depend:
@@ -426,169 +560,417 @@ class AoTPipeline:
                 in_degree[i] += 1
                 graph[dep].append(i)
 
-        #
-        queue
-
+        # Kahn's algorithm for topological order
+        queue = [i for i in range(n) if in_degree[i] == 0]
+        topo_order: list[int] = []
+
         while queue:
             node = queue.pop(0)
-
+            topo_order.append(node)
             for nxt in graph[node]:
                 in_degree[nxt] -= 1
                 if in_degree[nxt] == 0:
                     queue.append(nxt)
 
-        #
-
-
-
-
-
-
+        # We'll store the resolved sub-questions
+        final_subs: dict[int, SubQuestionNode] = {}
+
+        async def _resolve_one_subq(idx: int):
+            sq = sub_questions[idx]
+            # 1) Main approach
+            main_resp = await self._arecursive_decompose_question(
+                messages=messages,
+                question=sq.question,
+                depth=depth - 1,
+                parent_decomp_step_idx=parent_decomp_step_idx,
+                parent_subq_idx=idx,
             )
-            sq.answer = sub_decomp.final_answer
-            resolved[idx] = sq
 
-
+            main_answer = main_resp.final_answer
 
-
-
+            # 2) Devil's Advocate approach
+            devils_resp = await self._ainvoke_devils_advocate(
+                messages=messages, question=sq.question, existing_answer=main_answer
+            )
+            # 3) Ensemble to combine main_answer + devils_alternative
+            ensemble_sub = await self._ainvoke_pydantic(
+                self.prompter.ensemble_prompt(
+                    messages=messages,
+                    possible_thought_and_answers=[
+                        (main_resp.thought, main_answer),
+                        (devils_resp.thought, devils_resp.final_answer),
+                    ],
+                ),
+                EnsembleResponse,
+            )
+            sub_best_answer = ensemble_sub.answer
+
+            # Store final subq answer
+            sq.answer = sub_best_answer
+            final_subs[idx] = sq
+
+            # Record pipeline steps for devil's advocate
+            self.steps_history.append(
+                StepRecord(
+                    step_name=StepName.DEVILS_ADVOCATE,
+                    question=sq.question,
+                    answer=devils_resp.final_answer,
+                    thought=devils_resp.thought,
+                    sub_questions=devils_resp.sub_questions,
+                )
+            )
+            # Possibly critique the devils advocate result
+            dev_adv_crit = await self._ainvoke_critique(
+                messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
+            )
+            self.steps_history.append(
+                StepRecord(
+                    step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
+                    thought=dev_adv_crit.thought,
+                    score=dev_adv_crit.self_assessment,
+                )
+            )
 
-
-
-
-        If ground_truth is None, returns -1.0 as 'no score possible'.
+        # Solve sub-questions in topological order
+        tasks = [_resolve_one_subq(i) for i in topo_order]
+        await asyncio.gather(*tasks, return_exceptions=False)
 
-
-        """
-        if ground_truth is None:
-            return -1.0
+        return [final_subs[i] for i in range(n)]
 
-
-
-        if domain == Domain.MATH:
-            try:
-                ans_val = float(answer.strip())
-                gt_val = float(ground_truth.strip())
-                return 1.0 if abs(ans_val - gt_val) < 1e-9 else 0.0
-            except ValueError:
-                return 0.0
-
-        # For anything else, do a naive exact-match check ignoring case
-        return 1.0 if answer.strip().lower() == ground_truth.strip().lower() else 0.0
-
-    async def arun_pipeline(
-        self, question: str, context: Optional[str] = None, ground_truth: Optional[str] = None
-    ) -> str:
+    # 4.5) The primary pipeline method
+    async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
        """
-
+        Execute the pipeline:
+        1) Decompose the main question (recursively).
+        2) Self-critique.
+        3) Provide a devil's advocate approach on the entire main result.
+        4) Contract sub-answers (optional).
+        5) Directly solve the contracted question.
+        6) Self-critique again.
+        7) Final ensemble across main vs devil's vs contracted direct answer.
+        8) Return final answer.
        """
-
-
-
-
-
-
-
-
-
-        logger.
-
-        #
-
-
-
-
-
-
-
-
+        self.steps_history.clear()
+
+        original_question: str = messages[-1].text()
+        # 1) Recursive decomposition
+        decomp_resp = await self._arecursive_decompose_question(
+            messages=messages,
+            question=original_question,
+            depth=self.max_depth,
+        )
+        logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")
+
+        # 2) Self-critique of main decomposition
+        decomp_critique = await self._ainvoke_critique(
+            messages=messages, thought=decomp_resp.thought, answer=decomp_resp.final_answer
+        )
+        self.steps_history.append(
+            StepRecord(
+                step_name=StepName.DECOMPOSITION_CRITIQUE,
+                thought=decomp_critique.thought,
+                score=decomp_critique.self_assessment,
+            )
+        )
+
+        # 3) Devil's advocate on the entire main answer
+        devils_on_main = await self._ainvoke_devils_advocate(
+            messages=messages, question=original_question, existing_answer=decomp_resp.final_answer
+        )
+        self.steps_history.append(
+            StepRecord(
+                step_name=StepName.DEVILS_ADVOCATE,
+                question=original_question,
+                answer=devils_on_main.final_answer,
+                thought=devils_on_main.thought,
+                sub_questions=devils_on_main.sub_questions,
+            )
+        )
+        devils_crit_main = await self._ainvoke_critique(
+            messages=messages, thought=devils_on_main.thought, answer=devils_on_main.final_answer
+        )
+        self.steps_history.append(
+            StepRecord(
+                step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
+                thought=devils_crit_main.thought,
+                score=devils_crit_main.self_assessment,
+            )
+        )
+
+        # 4) Contract sub-answers from main decomposition
+        top_decomp_record: Optional[StepRecord] = next(
+            (
+                s
+                for s in reversed(self.steps_history)
+                if s.step_name == StepName.DECOMPOSITION and s.parent_decomp_step_idx is None
+            ),
+            None,
+        )
+        if top_decomp_record and top_decomp_record.sub_questions:
+            sub_answers = [(sq.question, sq.answer or UNKNOWN) for sq in top_decomp_record.sub_questions]
+        else:
+            sub_answers = []
+
+        contract_resp = await self._ainvoke_pydantic(
+            messages=self.prompter.contract_prompt(messages, sub_answers),
+            model_cls=ContractQuestionResponse,
+        )
        contracted_question = contract_resp.question
-
-
-
-
-        contracted_direct_resp = await self._ainvoke_pydantic(contracted_direct_prompt, DirectResponse)
-        contracted_direct_answer = contracted_direct_resp.answer
-        logger.debug(f"Contracted direct answer => {contracted_direct_answer}")
-
-        # 6) Ensemble
-        ensemble_prompt = prompter.ensemble_prompt(
-            original_question=question,
-            direct_answer=direct_answer,
-            decompose_answer=decompose_answer,
-            contracted_direct_answer=contracted_direct_answer,
-            context=context,
+        self.steps_history.append(
+            StepRecord(
+                step_name=StepName.CONTRACTED_QUESTION, question=contracted_question, thought=contract_resp.thought
+            )
        )
-        ensemble_resp = await self._ainvoke_pydantic(ensemble_prompt, EnsembleResponse)
-        final_answer = ensemble_resp.answer
-        logger.debug(f"Ensemble final answer => {final_answer} (confidence={ensemble_resp.confidence})")
 
-        #
-
-
-
+        # 5) Attempt direct approach on contracted question
+        contracted_direct = await self._ainvoke_pydantic(
+            messages=self.prompter.contract_direct_prompt(messages, contracted_question),
+            model_cls=RecursiveDecomposeResponse,
+            fallback="No Contracted Direct Answer",
+        )
+        self.steps_history.append(
+            StepRecord(
+                step_name=StepName.CONTRACTED_DIRECT_ANSWER,
+                answer=contracted_direct.final_answer,
+                thought=contracted_direct.thought,
+            )
+        )
+        logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")
 
-
+        # 5.1) Critique the contracted direct approach
+        contract_critique = await self._ainvoke_critique(
+            messages=messages, thought=contracted_direct.thought, answer=contracted_direct.final_answer
+        )
+        self.steps_history.append(
+            StepRecord(
+                step_name=StepName.CONTRACT_CRITIQUE,
+                thought=contract_critique.thought,
+                score=contract_critique.self_assessment,
+            )
+        )
 
+        # 6) Ensemble of (Main decomposition, Devil's advocate on main, Contracted direct)
+        ensemble_resp = await self._ainvoke_pydantic(
+            self.prompter.ensemble_prompt(
+                messages=messages,
+                possible_thought_and_answers=[
+                    (decomp_resp.thought, decomp_resp.final_answer),
+                    (devils_on_main.thought, devils_on_main.final_answer),
+                    (contracted_direct.thought, contracted_direct.final_answer),
+                ],
+            ),
+            EnsembleResponse,
+        )
+        best_approach_answer = ensemble_resp.answer
+        approach_used = StepName.ENSEMBLE
+        self.steps_history.append(StepRecord(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used))
+        logger.info(f"[Best Approach Decision] => {approach_used}")
+
+        # 7) Final answer
+        self.steps_history.append(
+            StepRecord(step_name=StepName.FINAL_ANSWER, answer=best_approach_answer, score=ensemble_resp.confidence)
+        )
+        logger.info(f"[Final Answer] => {best_approach_answer}")
 
-
+        return best_approach_answer
 
+    def run_pipeline(self, messages: list[BaseMessage]) -> str:
+        """Synchronous wrapper around arun_pipeline."""
+        return asyncio.run(self.arun_pipeline(messages))
 
-
-
-
+    # ---------------------------------------------------------------------------------
+    # 4.6) Build or export a reasoning graph
+    # ---------------------------------------------------------------------------------
 
-
-
+    def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
+        """
+        Constructs a Graph object (from hypothetical `neo4j_extension`)
+        capturing the pipeline steps, including devil's advocate steps.
+        """
+        from neo4j_extension import Graph, Node, Relationship
+
+        g = Graph()
+        step_nodes: dict[int, Node] = {}
+        subq_nodes: dict[str, Node] = {}
+
+        # Step A: Create nodes for each pipeline step
+        for i, record in enumerate(self.steps_history):
+            # We'll skip nested Decomposition steps only if we want to flatten them.
+            # But let's keep them for clarity.
+            step_node = Node(
+                properties=record.as_properties(), labels={record.step_name}, globalId=f"{global_id_prefix}_step_{i}"
+            )
+            g.add_node(step_node)
+            step_nodes[i] = step_node
+
+        # Step B: Collect sub-questions from each DECOMPOSITION or DEVILS_ADVOCATE
+        all_sub_questions: dict[str, tuple[int, int, SubQuestionNode]] = {}
+        for i, record in enumerate(self.steps_history):
+            if record.sub_questions:
+                for sq_idx, sq in enumerate(record.sub_questions):
+                    sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
+                    all_sub_questions[sq_id] = (i, sq_idx, sq)
+
+        for sq_id, (i, sq_idx, sq) in all_sub_questions.items():
+            n_subq = Node(
+                properties={
+                    "question": sq.question,
+                    "answer": sq.answer or "",
+                },
+                labels={"SubQuestion"},
+                globalId=sq_id,
+            )
+            g.add_node(n_subq)
+            subq_nodes[sq_id] = n_subq
+
+        # Step C: Add relationships. We do a simple approach:
+        # - If StepRecord is DECOMPOSITION or DEVILS_ADVOCATE with sub_questions, link them via SPLIT_INTO.
+        for i, record in enumerate(self.steps_history):
+            if record.sub_questions:
+                start_node = step_nodes[i]
+                for sq_idx, sq in enumerate(record.sub_questions):
+                    sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
+                    end_node = subq_nodes[sq_id]
+                    rel = Relationship(
+                        properties={},
+                        rel_type=StepRelation.SPLIT_INTO,
+                        start_node=start_node,
+                        end_node=end_node,
+                        globalId=f"{global_id_prefix}_split_{i}_{sq_idx}",
+                    )
+                    g.add_relationship(rel)
+                    # Also add sub-question dependencies
+                    for dep in sq.depend:
+                        # The same record i -> sub-question subq
+                        if 0 <= dep < len(record.sub_questions):
+                            dep_id = f"{global_id_prefix}_decomp_{i}_sub_{dep}"
+                            if dep_id in subq_nodes:
+                                dep_node = subq_nodes[dep_id]
+                                rel_dep = Relationship(
+                                    properties={},
+                                    rel_type=StepRelation.DEPEND_ON,
+                                    start_node=end_node,
+                                    end_node=dep_node,
+                                    globalId=f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
+                                )
+                                g.add_relationship(rel_dep)
+
+        # Step D: We add PRECEDES relationships in a linear chain for the pipeline steps
+        for i in range(len(self.steps_history) - 1):
+            start_node = step_nodes[i]
+            end_node = step_nodes[i + 1]
+            rel = Relationship(
+                properties={},
+                rel_type=StepRelation.PRECEDES,
+                start_node=start_node,
+                end_node=end_node,
+                globalId=f"{global_id_prefix}_precede_{i}_to_{i + 1}",
+            )
+            g.add_relationship(rel)
+
+        # Step E: CRITIQUES, SELECTS, RESULT_OF can be similarly added:
+        # We'll do a simple pass:
+        # If step_name ends with CRITIQUE => it critiques the step before it
+        for i, record in enumerate(self.steps_history):
+            if "CRITIQUE" in record.step_name:
+                # Let it point to the preceding step
+                if i > 0:
+                    start_node = step_nodes[i]
+                    end_node = step_nodes[i - 1]
+                    rel = Relationship(
+                        properties={},
+                        rel_type=StepRelation.CRITIQUES,
+                        start_node=start_node,
+                        end_node=end_node,
+                        globalId=f"{global_id_prefix}_crit_{i}",
+                    )
+                    g.add_relationship(rel)
+
+        # If there's a BEST_APPROACH_DECISION step, link it to the step it uses
+        best_decision_idx = None
+        used_step_idx = None
+        for i, record in enumerate(self.steps_history):
+            if record.step_name == StepName.BEST_APPROACH_DECISION and record.used:
+                best_decision_idx = i
+                # find the step with that name
+                used_step_idx = next((j for j in step_nodes if self.steps_history[j].step_name == record.used), None)
+                if used_step_idx is not None:
+                    rel = Relationship(
+                        properties={},
+                        rel_type=StepRelation.SELECTS,
+                        start_node=step_nodes[i],
+                        end_node=step_nodes[used_step_idx],
+                        globalId=f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
+                    )
+                    g.add_relationship(rel)
+
+        # And link the final answer to the best approach
+        final_answer_idx = next(
+            (i for i, r in enumerate(self.steps_history) if r.step_name == StepName.FINAL_ANSWER), None
+        )
+        if final_answer_idx is not None and best_decision_idx is not None:
+            rel = Relationship(
+                properties={},
+                rel_type=StepRelation.RESULT_OF,
+                start_node=step_nodes[final_answer_idx],
+                end_node=step_nodes[best_decision_idx],
+                globalId=f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
+            )
+            g.add_relationship(rel)
 
-
-        input_string = input.to_string()
-        logger.debug(f"Extracted question: {input_string}")
-        return await self.pipeline.arun_pipeline(input_string)
+        return g
 
-    def invoke(self, messages: LanguageModelInput) -> str:
-        return asyncio.run(self.ainvoke(messages))
 
+# ---------------------------------------------------------------------------------
+# 5) AoTStrategy class that uses the pipeline
+# ---------------------------------------------------------------------------------
 
-# ------------------ 4) Example usage (main) ------------------
-if __name__ == "__main__":
-    from warnings import filterwarnings
 
-
+@dataclass
+class AoTStrategy(BaseStrategy):
+    """
+    Strategy using AoTPipeline with a reasoning graph and deeper devil's advocate.
+    """
 
-
+    pipeline: AoTPipeline
 
-
+    async def ainvoke(self, messages: LanguageModelInput) -> str:
+        """Asynchronously run the pipeline with the given messages."""
+        # Convert your custom input to list[BaseMessage] as needed:
+        msgs = self.pipeline.chatterer.client._convert_input(messages).to_messages()  # type: ignore
+        return await self.pipeline.arun_pipeline(msgs)
 
-
-
-
-
-        "WARNING": colorama.Fore.YELLOW,
-        "ERROR": colorama.Fore.RED,
-        "CRITICAL": colorama.Fore.RED + colorama.Style.BRIGHT,
-    }
+    def invoke(self, messages: LanguageModelInput) -> str:
+        """Synchronously run the pipeline with the given messages."""
+        msgs = self.pipeline.chatterer.client._convert_input(messages).to_messages()  # type: ignore
+        return self.pipeline.run_pipeline(msgs)
 
-
-
-
-        return f"{self.COLORS.get(levelname, colorama.Fore.WHITE)}{message}{colorama.Style.RESET_ALL}"
+    def get_reasoning_graph(self):
+        """Return the AoT reasoning graph from the pipeline’s steps history."""
+        return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
 
-    logger.setLevel(logging.DEBUG)
-    handler = logging.StreamHandler()
-    handler.setLevel(logging.DEBUG)
-    formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
 
-
-
-
-
+# ---------------------------------------------------------------------------------
+# Example usage (pseudo-code)
+# ---------------------------------------------------------------------------------
+if __name__ == "__main__":
+    from neo4j_extension import Neo4jConnection  # or your actual DB connector
 
-
+    # You would create a Chatterer with your chosen LLM backend (OpenAI, etc.)
+    chatterer = Chatterer.openai()  # pseudo-code
+    pipeline = AoTPipeline(chatterer=chatterer, max_depth=3)
     strategy = AoTStrategy(pipeline=pipeline)
 
-    question = "
+    question = "Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9."
    answer = strategy.invoke(question)
-    print(answer)
+    print("Final Answer:", answer)
+
+    # Build the reasoning graph
+    graph = strategy.get_reasoning_graph()
+    print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")
+
+    # Optionally store in Neo4j
+    with Neo4jConnection() as conn:
+        conn.clear_all()
+        conn.upsert_graph(graph)
+        print("Graph stored in Neo4j.")