chatterer 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,32 +1,25 @@
- # original source: https://github.com/qixucen/atom
-
  from __future__ import annotations
 
  import asyncio
  import logging
- from abc import ABC, abstractmethod
  from dataclasses import dataclass, field
  from enum import StrEnum
- from typing import LiteralString, Optional, Type, TypeVar
+ from typing import Optional, Type, TypeVar
 
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
  from pydantic import BaseModel, Field, ValidationError
 
+ # Import your Chatterer interface (do not remove)
  from ..language_model import Chatterer, LanguageModelInput
  from .base import BaseStrategy
 
- logger = logging.getLogger(__name__)
-
- # ------------------------- 0) Enums and Basic Models -------------------------
-
+ # ---------------------------------------------------------------------------------
+ # 0) Enums and Basic Models
+ # ---------------------------------------------------------------------------------
 
- class Domain(StrEnum):
-     """Defines the domain of a question for specialized handling."""
-
-     GENERAL = "general"
-     MATH = "math"
-     CODING = "coding"
-     PHILOSOPHY = "philosophy"
-     MULTIHOP = "multihop"
+ QA_TEMPLATE = "Q: {question}\nA: {answer}"
+ MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."
+ UNKNOWN = "Unknown"
 
 
  class SubQuestionNode(BaseModel):
@@ -45,13 +38,6 @@ class RecursiveDecomposeResponse(BaseModel):
      sub_questions: list[SubQuestionNode] = Field(description="Root-level sub-questions.")
 
 
- class DirectResponse(BaseModel):
-     """A direct response to a question."""
-
-     thought: str = Field(description="Short reasoning.")
-     answer: str = Field(description="Direct final answer.")
-
-
  class ContractQuestionResponse(BaseModel):
      """The result of contracting (simplifying) a question."""
 
@@ -66,359 +52,507 @@ class EnsembleResponse(BaseModel):
      answer: str = Field(description="Best final answer after ensemble.")
      confidence: float = Field(description="Confidence score in [0, 1].")
 
-     def model_post_init(self, __context) -> None:
-         # Clamp confidence to [0, 1]
+     def model_post_init(self, __context: object) -> None:
          self.confidence = max(0.0, min(1.0, self.confidence))
 
 
  class LabelResponse(BaseModel):
-     """A response used to refine sub-question dependencies and structure."""
+     """Used to refine and reorder the sub-questions with corrected dependencies."""
 
      thought: str = Field(description="Explanation or reasoning about labeling.")
      sub_questions: list[SubQuestionNode] = Field(
          description="Refined list of sub-questions with corrected dependencies."
      )
-     # Some tasks also keep the final answer, but we focus on sub-questions.
 
 
- # --------------- 1) Prompter Classes with multi-hop context ---------------
+ class CritiqueResponse(BaseModel):
+     """A response used for LLM to self-critique or question its own correctness."""
 
+     thought: str = Field(description="Critical reflection on correctness.")
+     self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")
 
- class BaseAoTPrompter(ABC):
-     """
-     Abstract base prompter that defines the required prompt methods.
-     """
 
-     @abstractmethod
-     def recursive_decompose_prompt(
-         self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-     ) -> str: ...
+ # ---------------------------------------------------------------------------------
+ # [NEW] Additional classes to incorporate a separate sub-question devil's advocate
+ # ---------------------------------------------------------------------------------
 
-     @abstractmethod
-     def direct_prompt(self, question: str, context: Optional[str] = None) -> str: ...
 
-     @abstractmethod
-     def label_prompt(
-         self, question: str, decompose_response: RecursiveDecomposeResponse, context: Optional[str] = None
-     ) -> str:
-         """
-         Prompt used to 're-check' the sub-questions for correctness and dependencies.
-         """
+ class DevilsAdvocateResponse(BaseModel):
+     """
+     A response for a 'devil's advocate' pass.
+     We consider an alternative viewpoint or contradictory answer.
+     """
 
-     @abstractmethod
-     def contract_prompt(self, question: str, sub_answers: str, context: Optional[str] = None) -> str: ...
+     thought: str = Field(description="Reasoning behind the contradictory viewpoint.")
+     final_answer: str = Field(description="Alternative or conflicting answer to challenge the main one.")
+     sub_questions: list[SubQuestionNode] = Field(
+         description="Any additional sub-questions from the contrarian perspective."
+     )
 
-     @abstractmethod
-     def ensemble_prompt(
-         self,
-         original_question: str,
-         direct_answer: str,
-         decompose_answer: str,
-         contracted_direct_answer: str,
-         context: Optional[str] = None,
-     ) -> str: ...
 
+ # ---------------------------------------------------------------------------------
+ # 1) Prompter Classes with Multi-Hop + Devil's Advocate
+ # ---------------------------------------------------------------------------------
 
- class GeneralAoTPrompter(BaseAoTPrompter):
-     """
-     Generic prompter for non-specialized or 'general' queries.
-     """
+
+ class AoTPrompter:
+     """Generic base prompter that defines the required prompt methods."""
 
      def recursive_decompose_prompt(
-         self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-     ) -> str:
-         sub_ans_str = f"\nSub-question answers:\n{sub_answers}" if sub_answers else ""
-         context_str = f"\nCONTEXT:\n{context}" if context else ""
-         return (
-             "You are a highly analytical assistant skilled in breaking down complex problems.\n"
-             "Decompose the question into sub-questions recursively.\n\n"
-             "REQUIREMENTS:\n"
-             "1. Return valid JSON:\n"
-             " {\n"
-             ' "thought": "...",\n'
-             ' "final_answer": "...",\n'
-             ' "sub_questions": [{"question": "...", "answer": null, "depend": []}, ...]\n'
-             " }\n"
-             "2. 'thought': Provide detailed reasoning.\n"
-             "3. 'final_answer': Integrate sub-answers if any.\n"
-             "4. 'sub_questions': Key sub-questions with potential dependencies.\n\n"
-             f"QUESTION:\n{question}{sub_ans_str}{context_str}"
+         self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
+     ) -> list[BaseMessage]:
+         """
+         Prompt for main decomposition.
+         Encourages step-by-step reasoning and listing sub-questions as JSON.
+         """
+         decompose_instructions = (
+             "First, restate the main question.\n"
+             "Decide if sub-questions are needed. If so, list them.\n"
+             "In the 'thought' field, show your chain-of-thought.\n"
+             "Return valid JSON:\n"
+             "{\n"
+             ' "thought": "...",\n'
+             ' "final_answer": "...",\n'
+             ' "sub_questions": [\n'
+             ' {"question": "...", "answer": null, "depend": []},\n'
+             " ...\n"
+             " ]\n"
+             "}\n"
          )
 
-     def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
-         context_str = f"\nCONTEXT:\n{context}" if context else ""
-         return (
-             "You are a concise and insightful assistant.\n"
-             "Provide a direct answer with a short reasoning.\n\n"
-             "REQUIREMENTS:\n"
-             "1. Return valid JSON:\n"
-             " {'thought': '...', 'answer': '...'}\n"
-             "2. 'thought': Offer a brief reasoning.\n"
-             "3. 'answer': Deliver a precise solution.\n\n"
-             f"QUESTION:\n{question}{context_str}"
-         )
+         content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
+         return messages + [
+             HumanMessage(content=f"Main question:\n{question}"),
+             AIMessage(content=content_sub_answers),
+             AIMessage(content=decompose_instructions),
+         ]
 
      def label_prompt(
-         self, question: str, decompose_response: RecursiveDecomposeResponse, context: Optional[str] = None
-     ) -> str:
-         context_str = f"\nCONTEXT:\n{context}" if context else ""
-         return (
-             "You have a set of sub-questions from a decomposition process.\n"
-             "We want to correct or refine the dependencies between sub-questions.\n\n"
-             "REQUIREMENTS:\n"
-             "1. Return valid JSON:\n"
-             " {\n"
-             ' "thought": "...",\n'
-             ' "sub_questions": [\n'
-             ' {"question":"...", "answer":"...", "depend":[...]},\n'
-             " ...\n"
-             " ]\n"
-             " }\n"
-             "2. 'thought': Provide reasoning about any changes.\n"
-             "3. 'sub_questions': Possibly updated sub-questions with correct 'depend' lists.\n\n"
-             f"ORIGINAL QUESTION:\n{question}\n"
-             f"CURRENT DECOMPOSITION:\n{decompose_response.model_dump_json(indent=2)}"
-             f"{context_str}"
+         self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
+     ) -> list[BaseMessage]:
+         """
+         Prompt for refining the sub-questions and dependencies.
+         """
+         label_instructions = (
+             "Review each sub-question to ensure correctness and proper ordering.\n"
+             "Return valid JSON in the form:\n"
+             "{\n"
+             ' "thought": "...",\n'
+             ' "sub_questions": [\n'
+             ' {"question": "...", "answer": "...", "depend": [...]},\n'
+             " ...\n"
+             " ]\n"
+             "}\n"
         )
+         return messages + [
+             AIMessage(content=f"Question: {question}"),
+             AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
+             AIMessage(content=label_instructions),
+         ]
 
-     def contract_prompt(self, question: str, sub_answers: str, context: Optional[str] = None) -> str:
-         context_str = f"\nCONTEXT:\n{context}" if context else ""
-         return (
-             "You are tasked with compressing or simplifying a complex question into a single self-contained one.\n\n"
-             "REQUIREMENTS:\n"
-             "1. Return valid JSON:\n"
-             " {'thought': '...', 'question': '...'}\n"
-             "2. 'thought': Explain your simplification.\n"
-             "3. 'question': The streamlined question.\n\n"
-             f"ORIGINAL QUESTION:\n{question}\n"
-             f"SUB-ANSWERS:\n{sub_answers}"
-             f"{context_str}"
+     def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
+         """
+         Prompt for merging sub-answers into one self-contained question.
+         """
+         contract_instructions = (
+             "Please merge sub-answers into a single short question that is fully self-contained.\n"
+             "In 'thought', show how you unify the information.\n"
+             "Then produce JSON:\n"
+             "{\n"
+             ' "thought": "...",\n'
+             ' "question": "a short but self-contained question"\n'
+             "}\n"
         )
-
-     def ensemble_prompt(
-         self,
-         original_question: str,
-         direct_answer: str,
-         decompose_answer: str,
-         contracted_direct_answer: str,
-         context: Optional[str] = None,
-     ) -> str:
-         context_str = f"\nCONTEXT:\n{context}" if context else ""
-         return (
-             "You are an expert at synthesizing multiple candidate answers.\n"
-             "Consider the following candidates:\n"
-             f"1) Direct: {direct_answer}\n"
-             f"2) Decomposition-based: {decompose_answer}\n"
-             f"3) Contracted Direct: {contracted_direct_answer}\n\n"
-             "REQUIREMENTS:\n"
-             "1. Return valid JSON:\n"
-             " {'thought': '...', 'answer': '...', 'confidence': <float in [0,1]>}\n"
-             "2. 'thought': Summarize how you decided.\n"
-             "3. 'answer': Final best answer.\n"
-             "4. 'confidence': Confidence score in [0,1].\n\n"
-             f"ORIGINAL QUESTION:\n{original_question}"
-             f"{context_str}"
+         sub_q_content = "\n".join(f"Q: {q}\nA: {a}" for q, a in sub_answers)
+         return messages + [
+             AIMessage(content="We have these sub-questions and answers:"),
+             AIMessage(content=sub_q_content),
+             AIMessage(content=contract_instructions),
+         ]
+
+     def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
+         """
+         Prompt for directly answering the contracted question thoroughly.
+         """
+         direct_instructions = (
+             "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
+             "Return JSON:\n"
+             "{\n"
+             ' "thought": "...",\n'
+             ' "final_answer": "..."\n'
+             "}\n"
         )
+         return messages + [
+             HumanMessage(content=f"Simplified question: {contracted_question}"),
+             AIMessage(content=direct_instructions),
+         ]
 
-
- class MathAoTPrompter(GeneralAoTPrompter):
-     """
-     Specialized prompter for math questions; includes domain-specific hints.
-     """
-
-     def recursive_decompose_prompt(
-         self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-     ) -> str:
-         base = super().recursive_decompose_prompt(question, sub_answers, context)
-         return base + "\nFocus on mathematical rigor and step-by-step derivations.\n"
-
-     def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
-         base = super().direct_prompt(question, context)
-         return base + "\nEnsure mathematical correctness in your reasoning.\n"
-
-     # label_prompt, contract_prompt, ensemble_prompt can also be overridden similarly if needed.
-
-
- class CodingAoTPrompter(GeneralAoTPrompter):
-     """
-     Specialized prompter for coding/algorithmic queries.
-     """
-
-     def recursive_decompose_prompt(
-         self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-     ) -> str:
-         base = super().recursive_decompose_prompt(question, sub_answers, context)
-         return base + "\nBreak down into programming concepts or implementation steps.\n"
-
-
- class PhilosophyAoTPrompter(GeneralAoTPrompter):
-     """
-     Specialized prompter for philosophical discussions.
-     """
-
-     def recursive_decompose_prompt(
-         self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-     ) -> str:
-         base = super().recursive_decompose_prompt(question, sub_answers, context)
-         return base + "\nConsider key philosophical theories and arguments.\n"
-
-
- class MultiHopAoTPrompter(GeneralAoTPrompter):
-     """
-     Specialized prompter for multi-hop Q&A with explicit context usage.
-     """
-
-     def recursive_decompose_prompt(
-         self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
-     ) -> str:
-         base = super().recursive_decompose_prompt(question, sub_answers, context)
-         return (
-             base + "\nTreat this as a multi-hop question. Use the provided context carefully.\n"
-             "Extract partial evidence from the context for each sub-question.\n"
+     def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
+         """
+         Prompt for self-critique.
+         """
+         critique_instructions = (
+             "Critique your own approach. Identify possible errors or leaps.\n"
+             "Return JSON:\n"
+             "{\n"
+             ' "thought": "...",\n'
+             ' "self_assessment": <float in [0,1]>\n'
+             "}\n"
         )
+         return messages + [
+             AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
+             AIMessage(content=f"Your previous ANSWER:\n{answer}"),
+             AIMessage(content=critique_instructions),
+         ]
 
-     def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
-         base = super().direct_prompt(question, context)
-         return base + "\nFor multi-hop, ensure each piece of reasoning uses only the relevant parts of the context.\n"
-
-
- # ----------------- 2) The AoTPipeline class with label + score + multi-hop -----------------
-
- T = TypeVar("T", bound=BaseModel)
+     def ensemble_prompt(
+         self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
+     ) -> list[BaseMessage]:
+         """
+         Show multiple candidate solutions and pick the best final answer with confidence.
+         """
+         instructions = (
+             "You have multiple candidate solutions. Compare carefully and pick the best.\n"
+             "Return JSON:\n"
+             "{\n"
+             ' "thought": "why you chose this final answer",\n'
+             ' "answer": "the best consolidated answer",\n'
+             ' "confidence": 0.0 ~ 1.0\n'
+             "}\n"
+         )
+         reasonings: list[BaseMessage] = []
+         for idx, (thought, ans) in enumerate(possible_thought_and_answers):
+             reasonings.append(AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---"))
+         return messages + reasonings + [AIMessage(content=instructions)]
+
+     def devils_advocate_prompt(
+         self, messages: list[BaseMessage], question: str, existing_answer: str
+     ) -> list[BaseMessage]:
+         """
+         Prompt for a devil's advocate approach to contradict or provide an alternative viewpoint.
+         """
+         instructions = (
+             "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
+             "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
+             "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
+             "But here, let's keep it in a new dedicated structure:\n"
+             "{\n"
+             ' "thought": "...",\n'
+             ' "final_answer": "...",\n'
+             ' "sub_questions": [\n'
+             ' {"question": "...", "answer": null, "depend": []},\n'
+             " ...\n"
+             " ]\n"
+             "}\n"
+         )
+         return messages + [
+             AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
+             AIMessage(content=instructions),
+         ]
+
+
+ # ---------------------------------------------------------------------------------
+ # 2) Strict Typed Steps for Pipeline
+ # ---------------------------------------------------------------------------------
+
+
+ class StepName(StrEnum):
+     """Enum for step names in the pipeline."""
+
+     DOMAIN_DETECTION = "DomainDetection"
+     DECOMPOSITION = "Decomposition"
+     DECOMPOSITION_CRITIQUE = "DecompositionCritique"
+     CONTRACTED_QUESTION = "ContractedQuestion"
+     CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
+     CONTRACT_CRITIQUE = "ContractCritique"
+     BEST_APPROACH_DECISION = "BestApproachDecision"
+     ENSEMBLE = "Ensemble"
+     FINAL_ANSWER = "FinalAnswer"
+
+     DEVILS_ADVOCATE = "DevilsAdvocate"
+     DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
+
+
+ class StepRelation(StrEnum):
+     """Enum for relationship types in the reasoning graph."""
+
+     CRITIQUES = "CRITIQUES"
+     SELECTS = "SELECTS"
+     RESULT_OF = "RESULT_OF"
+     SPLIT_INTO = "SPLIT_INTO"
+     DEPEND_ON = "DEPEND_ON"
+     PRECEDES = "PRECEDES"
+     DECOMPOSED_BY = "DECOMPOSED_BY"
+
+
+ class StepRecord(BaseModel):
+     """A typed record for each pipeline step."""
+
+     step_name: StepName
+     domain: Optional[str] = None
+     score: Optional[float] = None
+     used: Optional[StepName] = None
+     sub_questions: Optional[list[SubQuestionNode]] = None
+     parent_decomp_step_idx: Optional[int] = None
+     parent_subq_idx: Optional[int] = None
+     question: Optional[str] = None
+     thought: Optional[str] = None
+     answer: Optional[str] = None
+
+     def as_properties(self) -> dict[str, str | float | int | None]:
+         """Converts the StepRecord to a dictionary of properties."""
+         result: dict[str, str | float | int | None] = {}
+         if self.score is not None:
+             result["score"] = self.score
+         if self.domain:
+             result["domain"] = self.domain
+         if self.question:
+             result["question"] = self.question
+         if self.thought:
+             result["thought"] = self.thought
+         if self.answer:
+             result["answer"] = self.answer
+         return result
+
+
+ # ---------------------------------------------------------------------------------
+ # 3) Logging Setup
+ # ---------------------------------------------------------------------------------
+
+
+ class SimpleColorFormatter(logging.Formatter):
+     """Simple color-coded logging formatter for console output using ANSI escape codes."""
+
+     BLUE = "\033[94m"
+     GREEN = "\033[92m"
+     YELLOW = "\033[93m"
+     RED = "\033[91m"
+     RESET = "\033[0m"
+     LEVEL_COLORS = {
+         logging.DEBUG: BLUE,
+         logging.INFO: GREEN,
+         logging.WARNING: YELLOW,
+         logging.ERROR: RED,
+         logging.CRITICAL: RED,
+     }
+
+     def format(self, record: logging.LogRecord) -> str:
+         log_color = self.LEVEL_COLORS.get(record.levelno, self.RESET)
+         message = super().format(record)
+         return f"{log_color}{message}{self.RESET}"
+
+
+ logger = logging.getLogger("AoT")
+ logger.setLevel(logging.INFO)
+ handler = logging.StreamHandler()
+ handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))
+ logger.handlers = [handler]
+ logger.propagate = False
+
+
+ # ---------------------------------------------------------------------------------
+ # 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
+ # ---------------------------------------------------------------------------------
+
+ T = TypeVar(
+     "T",
+     bound=EnsembleResponse
+     | ContractQuestionResponse
+     | LabelResponse
+     | CritiqueResponse
+     | RecursiveDecomposeResponse
+     | DevilsAdvocateResponse,
+ )
 
 
  @dataclass
  class AoTPipeline:
      """
-     Implements an Atom-of-Thought pipeline with:
-     1) Domain detection
-     2) Direct solution
-     3) Recursive Decomposition
-     4) Label step to refine sub-questions
-     5) Contract question
-     6) Contracted direct solution
-     7) Ensemble
-     8) (Optional) Score final answer if ground-truth is available
+     The pipeline orchestrates:
+     1) Recursive decomposition
+     2) For each sub-question, it tries a main approach + a devil's advocate approach
+     3) Merges sub-answers using an ensemble
+     4) Contracts the question
+     5) Possibly does a direct approach on the contracted question
+     6) Ensembling the final answers
      """
 
      chatterer: Chatterer
      max_depth: int = 2
      max_retries: int = 2
+     steps_history: list[StepRecord] = field(default_factory=list)
+     prompter: AoTPrompter = field(default_factory=AoTPrompter)
 
-     prompter_map: dict[Domain, BaseAoTPrompter] = field(
-         default_factory=lambda: {
-             Domain.GENERAL: GeneralAoTPrompter(),
-             Domain.MATH: MathAoTPrompter(),
-             Domain.CODING: CodingAoTPrompter(),
-             Domain.PHILOSOPHY: PhilosophyAoTPrompter(),
-             Domain.MULTIHOP: MultiHopAoTPrompter(),
-         }
-     )
-
-     async def _ainvoke_pydantic(self, prompt: str, model_cls: Type[T], default_answer: str = "Unable to process") -> T:
+     # 4.1) Utility for calling the LLM with Pydantic parsing
+     async def _ainvoke_pydantic(
+         self,
+         messages: list[BaseMessage],
+         model_cls: Type[T],
+         fallback: str = "<None>",
+     ) -> T:
          """
-         Attempts up to max_retries to parse the model_cls from LLM output.
+         Attempts up to max_retries to parse the model_cls from LLM output as JSON.
          """
          for attempt in range(1, self.max_retries + 1):
              try:
-                 result = await self.chatterer.agenerate_pydantic(
-                     response_model=model_cls, messages=[{"role": "user", "content": prompt}]
-                 )
-                 logger.debug(f"[Validated attempt {attempt}] => {result.model_dump_json(indent=2)}")
-                 return result
+                 return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
              except ValidationError as e:
-                 logger.warning(f"Validation error on attempt {attempt}/{self.max_retries}: {str(e)}")
+                 logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")
                  if attempt == self.max_retries:
-                     # Return an empty or fallback
-                     if model_cls == EnsembleResponse:
-                         return model_cls(thought="Failed label parse", answer=default_answer, confidence=0.0)
-                     elif model_cls == ContractQuestionResponse:
-                         return model_cls(thought="Failed contract parse", question="Unknown")
-                     elif model_cls == LabelResponse:
-                         return model_cls(thought="Failed label parse", sub_questions=[])
+                     # Return a fallback version
+                     if issubclass(model_cls, EnsembleResponse):
+                         return model_cls(thought=fallback, answer=fallback, confidence=0.0) # type: ignore
+                     elif issubclass(model_cls, ContractQuestionResponse):
+                         return model_cls(thought=fallback, question=fallback) # type: ignore
+                     elif issubclass(model_cls, LabelResponse):
+                         return model_cls(thought=fallback, sub_questions=[]) # type: ignore
+                     elif issubclass(model_cls, CritiqueResponse):
+                         return model_cls(thought=fallback, self_assessment=0.0) # type: ignore
+                     elif issubclass(model_cls, DevilsAdvocateResponse):
+                         return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
                      else:
-                         # fallback for Direct/Decompose
-                         return model_cls(thought="Failed parse", answer=default_answer)
+                         return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
+         # theoretically unreachable
          raise RuntimeError("Unexpected error in _ainvoke_pydantic")
 
-     async def _adetect_domain(self, question: str, context: Optional[str] = None) -> Domain:
+     # 4.2) Helper method for self-critique
+     async def _ainvoke_critique(
+         self,
+         messages: list[BaseMessage],
+         thought: str,
+         answer: str,
+     ) -> CritiqueResponse:
          """
-         Queries an LLM to figure out which domain is best suited.
+         Instructs the LLM to critique the given thought & answer, returning CritiqueResponse.
          """
+         return await self._ainvoke_pydantic(
+             messages=self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer),
+             model_cls=CritiqueResponse,
+         )
 
-         class InferredDomain(BaseModel):
-             domain: Domain
-
-         ctx_str = f"\nCONTEXT:\n{context}" if context else ""
-         domain_prompt = (
-             "You are an expert domain classifier. "
-             "Possible domains: [general, math, coding, philosophy, multihop].\n\n"
-             "Return valid JSON: {'domain': '...'}.\n"
-             f"QUESTION:\n{question}{ctx_str}"
+     # 4.3) Helper method for devil's advocate approach
+     async def _ainvoke_devils_advocate(
+         self,
+         messages: list[BaseMessage],
+         question: str,
+         existing_answer: str,
+     ) -> DevilsAdvocateResponse:
+         """
+         Instructs the LLM to challenge an existing answer with a devil's advocate approach.
+         """
+         return await self._ainvoke_pydantic(
+             messages=self.prompter.devils_advocate_prompt(messages, question=question, existing_answer=existing_answer),
+             model_cls=DevilsAdvocateResponse,
         )
-         try:
-             result: InferredDomain = await self.chatterer.agenerate_pydantic(
-                 response_model=InferredDomain, messages=[{"role": "user", "content": domain_prompt}]
-             )
-             return result.domain
-         except ValidationError:
-             logger.warning("Failed domain detection, defaulting to general.")
-             return Domain.GENERAL
 
+     # 4.4) The main function that recursively decomposes a question and calls sub-steps
      async def _arecursive_decompose_question(
-         self, question: str, depth: int, prompter: BaseAoTPrompter, context: Optional[str] = None
+         self,
+         messages: list[BaseMessage],
+         question: str,
+         depth: int,
+         parent_decomp_step_idx: Optional[int] = None,
+         parent_subq_idx: Optional[int] = None,
      ) -> RecursiveDecomposeResponse:
          """
-         Recursively decomposes a question into sub-questions, applying an optional label step.
+         Recursively decompose the given question. For each sub-question:
+         1) Recursively decompose that sub-question if we still have depth left
+         2) After getting a main sub-answer, do a devil's advocate pass
+         3) Combine main sub-answer + devil's advocate alternative via an ensemble
          """
          if depth < 0:
-             return RecursiveDecomposeResponse(thought="Max depth reached", final_answer="Unknown", sub_questions=[])
+             logger.info("Max depth reached, returning unknown.")
+             return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])
 
-         indent: LiteralString = " " * (self.max_depth - depth)
-         logger.debug(f"{indent}Decomposing at depth {self.max_depth - depth}: {question}")
-
-         # Step 1: Base decomposition
-         prompt = prompter.recursive_decompose_prompt(question, context=context)
-         decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(prompt, RecursiveDecomposeResponse)
+         # Step 1: Perform the decomposition
+         decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
+             messages=self.prompter.recursive_decompose_prompt(messages=messages, question=question, sub_answers=[]),
+             model_cls=RecursiveDecomposeResponse,
+         )
 
-         # Step 2: Label step to refine sub-questions if any
+         # Step 2: Label / refine sub-questions (dependencies, ordering)
          if decompose_resp.sub_questions:
-             label_prompt: str = prompter.label_prompt(question, decompose_resp, context=context)
-             label_resp: LabelResponse = await self._ainvoke_pydantic(label_prompt, LabelResponse)
-             # Overwrite the sub-questions with refined ones
+             label_resp: LabelResponse = await self._ainvoke_pydantic(
+                 messages=self.prompter.label_prompt(messages, question, decompose_resp),
+                 model_cls=LabelResponse,
+             )
              decompose_resp.sub_questions = label_resp.sub_questions
 
-         # Step 3: If depth > 0, try to resolve sub-questions recursively
-         # so we can potentially update final_answer
+         # Save a pipeline record for this decomposition step
+         current_decomp_step_idx = self._record_decomposition_step(
+             question=question,
+             final_answer=decompose_resp.final_answer,
+             sub_questions=decompose_resp.sub_questions,
+             parent_decomp_step_idx=parent_decomp_step_idx,
+             parent_subq_idx=parent_subq_idx,
+         )
+
+         # Step 3: If sub-questions exist and depth remains, solve them + do devil's advocate
          if depth > 0 and decompose_resp.sub_questions:
-             resolved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
-                 decompose_resp.sub_questions, depth, prompter, context
+             solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
+                 messages=messages,
+                 sub_questions=decompose_resp.sub_questions,
+                 depth=depth,
+                 parent_decomp_step_idx=current_decomp_step_idx,
             )
-             # Step 4: Re-invoke decomposition with known sub-answers (sub_answers)
-             sub_answers_str: str = "\n".join(f"{sq.question}: {sq.answer}" for sq in resolved_subs if sq.answer)
-             if sub_answers_str:
-                 refine_prompt: str = prompter.recursive_decompose_prompt(question, sub_answers_str, context=context)
-                 refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
-                     refine_prompt, RecursiveDecomposeResponse
-                 )
-                 # Use the refined final answer, keep resolved sub-questions
-                 decompose_resp.final_answer = refined_resp.final_answer
-                 decompose_resp.sub_questions = resolved_subs
+             # Then we can refine the "final_answer" from those sub-answers
+             # or we do a secondary pass to refine the final answer
+             refined_prompt = self.prompter.recursive_decompose_prompt(
+                 messages=messages,
+                 question=question,
+                 sub_answers=[(sq.question, sq.answer or UNKNOWN) for sq in solved_subs],
+             )
+             refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
+                 refined_prompt, RecursiveDecomposeResponse
+             )
+             decompose_resp.final_answer = refined_resp.final_answer
+             decompose_resp.sub_questions = solved_subs
+
+             # Update pipeline record
+             self.steps_history[current_decomp_step_idx].answer = refined_resp.final_answer
+             self.steps_history[current_decomp_step_idx].sub_questions = solved_subs
 
          return decompose_resp
 
+     def _record_decomposition_step(
+         self,
+         question: str,
+         final_answer: str,
+         sub_questions: list[SubQuestionNode],
+         parent_decomp_step_idx: Optional[int],
+         parent_subq_idx: Optional[int],
+     ) -> int:
+         """
+         Save the decomposition step in steps_history, returning the index.
+         """
+         step_record = StepRecord(
+             step_name=StepName.DECOMPOSITION,
+             question=question,
+             answer=final_answer,
+             sub_questions=sub_questions,
+             parent_decomp_step_idx=parent_decomp_step_idx,
+             parent_subq_idx=parent_subq_idx,
+         )
+         self.steps_history.append(step_record)
+         return len(self.steps_history) - 1
+
      async def _aresolve_sub_questions(
-         self, sub_questions: list[SubQuestionNode], depth: int, prompter: BaseAoTPrompter, context: Optional[str] = None
+         self,
+         messages: list[BaseMessage],
+         sub_questions: list[SubQuestionNode],
+         depth: int,
+         parent_decomp_step_idx: Optional[int],
      ) -> list[SubQuestionNode]:
          """
-         Resolves each sub-question (potentially reusing the same decomposition approach)
-         in a topological order of dependencies.
+         Resolve sub-questions in topological order.
+         For each sub-question:
+         1) Recursively decompose (main approach).
+         2) Acquire a devil's advocate alternative.
+         3) Critique or ensemble if needed.
+         4) Finalize sub-question answer.
          """
-         n: int = len(sub_questions)
-         resolved: dict[int, SubQuestionNode] = {}
-
-         # Build adjacency
-         in_degree: list[int] = [0] * n
+         n = len(sub_questions)
+         in_degree = [0] * n
          graph: list[list[int]] = [[] for _ in range(n)]
          for i, sq in enumerate(sub_questions):
              for dep in sq.depend:
@@ -426,169 +560,417 @@ class AoTPipeline:
                  in_degree[i] += 1
                  graph[dep].append(i)
 
-         # Topological BFS
-         queue: list[int] = [i for i in range(n) if in_degree[i] == 0]
-         order: list[int] = []
+         # Kahn's algorithm for topological order
+         queue = [i for i in range(n) if in_degree[i] == 0]
+         topo_order: list[int] = []
+
         while queue:
             node = queue.pop(0)
-             order.append(node)
+             topo_order.append(node)
             for nxt in graph[node]:
                 in_degree[nxt] -= 1
                 if in_degree[nxt] == 0:
                     queue.append(nxt)
 
-         # If there's a cycle, some sub-questions won't appear in order
-         # but we'll attempt to resolve what we can
-         async def resolve_single_subq(idx: int) -> None:
-             sq: SubQuestionNode = sub_questions[idx]
-             # Attempt to answer this sub-question by decomposition if needed
-             sub_decomp: RecursiveDecomposeResponse = await self._arecursive_decompose_question(
-                 sq.question, depth - 1, prompter, context
+         # We'll store the resolved sub-questions
+         final_subs: dict[int, SubQuestionNode] = {}
+
+         async def _resolve_one_subq(idx: int):
+             sq = sub_questions[idx]
+             # 1) Main approach
+             main_resp = await self._arecursive_decompose_question(
+                 messages=messages,
+                 question=sq.question,
+                 depth=depth - 1,
+                 parent_decomp_step_idx=parent_decomp_step_idx,
+                 parent_subq_idx=idx,
             )
-             sq.answer = sub_decomp.final_answer
-             resolved[idx] = sq
 
-         await asyncio.gather(*(resolve_single_subq(i) for i in order))
+             main_answer = main_resp.final_answer
 
-         # Return only resolved sub-questions
-         return [resolved[i] for i in range(n) if i in resolved]
+             # 2) Devil's Advocate approach
+             devils_resp = await self._ainvoke_devils_advocate(
+                 messages=messages, question=sq.question, existing_answer=main_answer
+             )
+             # 3) Ensemble to combine main_answer + devils_alternative
+             ensemble_sub = await self._ainvoke_pydantic(
+                 self.prompter.ensemble_prompt(
+                     messages=messages,
+                     possible_thought_and_answers=[
+                         (main_resp.thought, main_answer),
+                         (devils_resp.thought, devils_resp.final_answer),
+                     ],
+                 ),
+                 EnsembleResponse,
+             )
+             sub_best_answer = ensemble_sub.answer
+
+             # Store final subq answer
+             sq.answer = sub_best_answer
+             final_subs[idx] = sq
+
+             # Record pipeline steps for devil's advocate
+             self.steps_history.append(
+                 StepRecord(
+                     step_name=StepName.DEVILS_ADVOCATE,
+                     question=sq.question,
+                     answer=devils_resp.final_answer,
+                     thought=devils_resp.thought,
+                     sub_questions=devils_resp.sub_questions,
+                 )
+             )
+             # Possibly critique the devils advocate result
+             dev_adv_crit = await self._ainvoke_critique(
+                 messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
+             )
+             self.steps_history.append(
+                 StepRecord(
+                     step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
+                     thought=dev_adv_crit.thought,
+                     score=dev_adv_crit.self_assessment,
+                 )
+             )
 
-     def _calculate_score(self, answer: str, ground_truth: Optional[str], domain: Domain) -> float:
-         """
-         Example scoring function. Real usage depends on having ground-truth.
-         If ground_truth is None, returns -1.0 as 'no score possible'.
+         # Solve sub-questions in topological order
+         tasks = [_resolve_one_subq(i) for i in topo_order]
+         await asyncio.gather(*tasks, return_exceptions=False)
 
-         This function is used for `post-hoc` scoring of the final answer.
-         """
-         if ground_truth is None:
-             return -1.0
+         return [final_subs[i] for i in range(n)]
 
-         # Very simplistic example:
-         # MATH: attempt numeric equality
-         if domain == Domain.MATH:
-             try:
-                 ans_val = float(answer.strip())
-                 gt_val = float(ground_truth.strip())
-                 return 1.0 if abs(ans_val - gt_val) < 1e-9 else 0.0
-             except ValueError:
-                 return 0.0
-
-         # For anything else, do a naive exact-match check ignoring case
-         return 1.0 if answer.strip().lower() == ground_truth.strip().lower() else 0.0
-
-     async def arun_pipeline(
-         self, question: str, context: Optional[str] = None, ground_truth: Optional[str] = None
-     ) -> str:
+     # 4.5) The primary pipeline method
+     async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
         """
-         Full AoT pipeline. If ground_truth is provided, we compute a final score.
+         Execute the pipeline:
+         1) Decompose the main question (recursively).
+         2) Self-critique.
+         3) Provide a devil's advocate approach on the entire main result.
+         4) Contract sub-answers (optional).
+         5) Directly solve the contracted question.
+         6) Self-critique again.
+         7) Final ensemble across main vs devil's vs contracted direct answer.
+         8) Return final answer.
         """
-         # 1) Domain detection
-         domain = await self._adetect_domain(question, context)
-         prompter = self.prompter_map.get(domain, GeneralAoTPrompter())
-         logger.debug(f"Detected domain: {domain}")
-
-         # 2) Direct approach
-         direct_prompt = prompter.direct_prompt(question, context)
-         direct_resp = await self._ainvoke_pydantic(direct_prompt, DirectResponse)
-         direct_answer = direct_resp.answer
-         logger.debug(f"Direct answer => {direct_answer}")
-
-         # 3) Recursive Decomposition + label
-         decomp_resp = await self._arecursive_decompose_question(question, self.max_depth, prompter, context)
-         decompose_answer = decomp_resp.final_answer
-         logger.debug(f"Decomposition answer => {decompose_answer}")
-
-         # 4) Contract question
-         sub_answers_str = "\n".join(f"{sq.question}: {sq.answer}" for sq in decomp_resp.sub_questions if sq.answer)
-         contract_prompt = prompter.contract_prompt(question, sub_answers_str, context)
-         contract_resp = await self._ainvoke_pydantic(contract_prompt, ContractQuestionResponse)
+         self.steps_history.clear()
+
+         original_question: str = messages[-1].text()
+         # 1) Recursive decomposition
+         decomp_resp = await self._arecursive_decompose_question(
+             messages=messages,
+             question=original_question,
+             depth=self.max_depth,
+         )
+         logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")
+
+         # 2) Self-critique of main decomposition
+         decomp_critique = await self._ainvoke_critique(
+             messages=messages, thought=decomp_resp.thought, answer=decomp_resp.final_answer
+         )
+         self.steps_history.append(
+             StepRecord(
+                 step_name=StepName.DECOMPOSITION_CRITIQUE,
+                 thought=decomp_critique.thought,
+                 score=decomp_critique.self_assessment,
+             )
+         )
+
+         # 3) Devil's advocate on the entire main answer
+         devils_on_main = await self._ainvoke_devils_advocate(
+             messages=messages, question=original_question, existing_answer=decomp_resp.final_answer
+         )
+         self.steps_history.append(
+             StepRecord(
+                 step_name=StepName.DEVILS_ADVOCATE,
+                 question=original_question,
+                 answer=devils_on_main.final_answer,
+                 thought=devils_on_main.thought,
+                 sub_questions=devils_on_main.sub_questions,
+             )
+         )
+         devils_crit_main = await self._ainvoke_critique(
+             messages=messages, thought=devils_on_main.thought, answer=devils_on_main.final_answer
+         )
+         self.steps_history.append(
+             StepRecord(
+                 step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
+                 thought=devils_crit_main.thought,
+                 score=devils_crit_main.self_assessment,
+             )
+         )
+
+         # 4) Contract sub-answers from main decomposition
+         top_decomp_record: Optional[StepRecord] = next(
+             (
+                 s
+                 for s in reversed(self.steps_history)
+                 if s.step_name == StepName.DECOMPOSITION and s.parent_decomp_step_idx is None
+             ),
+             None,
+         )
+         if top_decomp_record and top_decomp_record.sub_questions:
+             sub_answers = [(sq.question, sq.answer or UNKNOWN) for sq in top_decomp_record.sub_questions]
+         else:
+             sub_answers = []
+
+         contract_resp = await self._ainvoke_pydantic(
+             messages=self.prompter.contract_prompt(messages, sub_answers),
+             model_cls=ContractQuestionResponse,
+         )
         contracted_question = contract_resp.question
-         logger.debug(f"Contracted question => {contracted_question}")
-
-         # 5) Direct approach on contracted question
-         contracted_direct_prompt = prompter.direct_prompt(contracted_question, context)
-         contracted_direct_resp = await self._ainvoke_pydantic(contracted_direct_prompt, DirectResponse)
-         contracted_direct_answer = contracted_direct_resp.answer
-         logger.debug(f"Contracted direct answer => {contracted_direct_answer}")
-
-         # 6) Ensemble
-         ensemble_prompt = prompter.ensemble_prompt(
-             original_question=question,
-             direct_answer=direct_answer,
-             decompose_answer=decompose_answer,
-             contracted_direct_answer=contracted_direct_answer,
-             context=context,
+         self.steps_history.append(
+             StepRecord(
+                 step_name=StepName.CONTRACTED_QUESTION, question=contracted_question, thought=contract_resp.thought
+             )
         )
-         ensemble_resp = await self._ainvoke_pydantic(ensemble_prompt, EnsembleResponse)
-         final_answer = ensemble_resp.answer
-         logger.debug(f"Ensemble final answer => {final_answer} (confidence={ensemble_resp.confidence})")
 
-         # 7) Optional scoring
-         final_score = self._calculate_score(final_answer, ground_truth, domain)
-         if final_score >= 0.0:
-             logger.info(f"Final Score: {final_score:.3f} (domain={domain})")
+         # 5) Attempt direct approach on contracted question
+         contracted_direct = await self._ainvoke_pydantic(
+             messages=self.prompter.contract_direct_prompt(messages, contracted_question),
+             model_cls=RecursiveDecomposeResponse,
+             fallback="No Contracted Direct Answer",
+         )
+         self.steps_history.append(
+             StepRecord(
+                 step_name=StepName.CONTRACTED_DIRECT_ANSWER,
+                 answer=contracted_direct.final_answer,
+                 thought=contracted_direct.thought,
+             )
+         )
+         logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")
 
-         return final_answer
+         # 5.1) Critique the contracted direct approach
+         contract_critique = await self._ainvoke_critique(
+             messages=messages, thought=contracted_direct.thought, answer=contracted_direct.final_answer
+         )
+         self.steps_history.append(
+             StepRecord(
+                 step_name=StepName.CONTRACT_CRITIQUE,
+                 thought=contract_critique.thought,
+                 score=contract_critique.self_assessment,
+             )
+         )
 
+         # 6) Ensemble of (Main decomposition, Devil's advocate on main, Contracted direct)
+         ensemble_resp = await self._ainvoke_pydantic(
+             self.prompter.ensemble_prompt(
+                 messages=messages,
+                 possible_thought_and_answers=[
+                     (decomp_resp.thought, decomp_resp.final_answer),
+                     (devils_on_main.thought, devils_on_main.final_answer),
+                     (contracted_direct.thought, contracted_direct.final_answer),
+                 ],
+             ),
+             EnsembleResponse,
+         )
+         best_approach_answer = ensemble_resp.answer
+         approach_used = StepName.ENSEMBLE
+         self.steps_history.append(StepRecord(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used))
+         logger.info(f"[Best Approach Decision] => {approach_used}")
+
+         # 7) Final answer
+         self.steps_history.append(
+             StepRecord(step_name=StepName.FINAL_ANSWER, answer=best_approach_answer, score=ensemble_resp.confidence)
+         )
+         logger.info(f"[Final Answer] => {best_approach_answer}")
 
- # ------------------ 3) AoTStrategy that uses the pipeline ------------------
+         return best_approach_answer
 
+     def run_pipeline(self, messages: list[BaseMessage]) -> str:
+         """Synchronous wrapper around arun_pipeline."""
+         return asyncio.run(self.arun_pipeline(messages))
 
- @dataclass(kw_only=True)
- class AoTStrategy(BaseStrategy):
-     pipeline: AoTPipeline
+     # ---------------------------------------------------------------------------------
+     # 4.6) Build or export a reasoning graph
+     # ---------------------------------------------------------------------------------
 
-     async def ainvoke(self, messages: LanguageModelInput) -> str:
-         logger.debug(f"Invoking with messages: {messages}")
+     def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
+         """
+         Constructs a Graph object (from hypothetical `neo4j_extension`)
+         capturing the pipeline steps, including devil's advocate steps.
+         """
+         from neo4j_extension import Graph, Node, Relationship
+
+         g = Graph()
+         step_nodes: dict[int, Node] = {}
+         subq_nodes: dict[str, Node] = {}
+
+         # Step A: Create nodes for each pipeline step
+         for i, record in enumerate(self.steps_history):
+             # We'll skip nested Decomposition steps only if we want to flatten them.
+             # But let's keep them for clarity.
+             step_node = Node(
+                 properties=record.as_properties(), labels={record.step_name}, globalId=f"{global_id_prefix}_step_{i}"
+             )
+             g.add_node(step_node)
+             step_nodes[i] = step_node
+
+         # Step B: Collect sub-questions from each DECOMPOSITION or DEVILS_ADVOCATE
+         all_sub_questions: dict[str, tuple[int, int, SubQuestionNode]] = {}
+         for i, record in enumerate(self.steps_history):
+             if record.sub_questions:
+                 for sq_idx, sq in enumerate(record.sub_questions):
+                     sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
+                     all_sub_questions[sq_id] = (i, sq_idx, sq)
+
+         for sq_id, (i, sq_idx, sq) in all_sub_questions.items():
+             n_subq = Node(
+                 properties={
+                     "question": sq.question,
+                     "answer": sq.answer or "",
+                 },
+                 labels={"SubQuestion"},
+                 globalId=sq_id,
+             )
+             g.add_node(n_subq)
+             subq_nodes[sq_id] = n_subq
+
+         # Step C: Add relationships. We do a simple approach:
+         # - If StepRecord is DECOMPOSITION or DEVILS_ADVOCATE with sub_questions, link them via SPLIT_INTO.
+         for i, record in enumerate(self.steps_history):
+             if record.sub_questions:
+                 start_node = step_nodes[i]
+                 for sq_idx, sq in enumerate(record.sub_questions):
+                     sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
+                     end_node = subq_nodes[sq_id]
+                     rel = Relationship(
+                         properties={},
+                         rel_type=StepRelation.SPLIT_INTO,
+                         start_node=start_node,
+                         end_node=end_node,
+                         globalId=f"{global_id_prefix}_split_{i}_{sq_idx}",
+                     )
+                     g.add_relationship(rel)
+                     # Also add sub-question dependencies
+                     for dep in sq.depend:
+                         # The same record i -> sub-question subq
+                         if 0 <= dep < len(record.sub_questions):
+                             dep_id = f"{global_id_prefix}_decomp_{i}_sub_{dep}"
+                             if dep_id in subq_nodes:
+                                 dep_node = subq_nodes[dep_id]
+                                 rel_dep = Relationship(
+                                     properties={},
+                                     rel_type=StepRelation.DEPEND_ON,
+                                     start_node=end_node,
+                                     end_node=dep_node,
+                                     globalId=f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
+                                 )
+                                 g.add_relationship(rel_dep)
+
+         # Step D: We add PRECEDES relationships in a linear chain for the pipeline steps
+         for i in range(len(self.steps_history) - 1):
+             start_node = step_nodes[i]
+             end_node = step_nodes[i + 1]
+             rel = Relationship(
+                 properties={},
+                 rel_type=StepRelation.PRECEDES,
+                 start_node=start_node,
+                 end_node=end_node,
+                 globalId=f"{global_id_prefix}_precede_{i}_to_{i + 1}",
+             )
+             g.add_relationship(rel)
+
+         # Step E: CRITIQUES, SELECTS, RESULT_OF can be similarly added:
+         # We'll do a simple pass:
+         # If step_name ends with CRITIQUE => it critiques the step before it
+         for i, record in enumerate(self.steps_history):
+             if "CRITIQUE" in record.step_name:
+                 # Let it point to the preceding step
+                 if i > 0:
+                     start_node = step_nodes[i]
+                     end_node = step_nodes[i - 1]
+                     rel = Relationship(
+                         properties={},
+                         rel_type=StepRelation.CRITIQUES,
+                         start_node=start_node,
+                         end_node=end_node,
+                         globalId=f"{global_id_prefix}_crit_{i}",
+                     )
+                     g.add_relationship(rel)
+
+         # If there's a BEST_APPROACH_DECISION step, link it to the step it uses
+         best_decision_idx = None
+         used_step_idx = None
+         for i, record in enumerate(self.steps_history):
+             if record.step_name == StepName.BEST_APPROACH_DECISION and record.used:
+                 best_decision_idx = i
+                 # find the step with that name
+                 used_step_idx = next((j for j in step_nodes if self.steps_history[j].step_name == record.used), None)
+                 if used_step_idx is not None:
+                     rel = Relationship(
+                         properties={},
+                         rel_type=StepRelation.SELECTS,
+                         start_node=step_nodes[i],
+                         end_node=step_nodes[used_step_idx],
+                         globalId=f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
+                     )
+                     g.add_relationship(rel)
+
+         # And link the final answer to the best approach
+         final_answer_idx = next(
+             (i for i, r in enumerate(self.steps_history) if r.step_name == StepName.FINAL_ANSWER), None
+         )
+         if final_answer_idx is not None and best_decision_idx is not None:
+             rel = Relationship(
+                 properties={},
+                 rel_type=StepRelation.RESULT_OF,
+                 start_node=step_nodes[final_answer_idx],
+                 end_node=step_nodes[best_decision_idx],
+                 globalId=f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
+             )
+             g.add_relationship(rel)
 
-         input = self.pipeline.chatterer.client._convert_input(messages)
-         input_string = input.to_string()
-         logger.debug(f"Extracted question: {input_string}")
-         return await self.pipeline.arun_pipeline(input_string)
+         return g
 
-     def invoke(self, messages: LanguageModelInput) -> str:
-         return asyncio.run(self.ainvoke(messages))
 
+ # ---------------------------------------------------------------------------------
+ # 5) AoTStrategy class that uses the pipeline
+ # ---------------------------------------------------------------------------------
 
- # ------------------ 4) Example usage (main) ------------------
- if __name__ == "__main__":
-     from warnings import filterwarnings
 
-     import colorama
+ @dataclass
+ class AoTStrategy(BaseStrategy):
+     """
+     Strategy using AoTPipeline with a reasoning graph and deeper devil's advocate.
+     """
 
-     filterwarnings("ignore", category=UserWarning)
+     pipeline: AoTPipeline
 
-     colorama.init(autoreset=True)
+     async def ainvoke(self, messages: LanguageModelInput) -> str:
+         """Asynchronously run the pipeline with the given messages."""
+         # Convert your custom input to list[BaseMessage] as needed:
+         msgs = self.pipeline.chatterer.client._convert_input(messages).to_messages() # type: ignore
+         return await self.pipeline.arun_pipeline(msgs)
 
-     class ColoredFormatter(logging.Formatter):
-         COLORS = {
-             "DEBUG": colorama.Fore.CYAN,
-             "INFO": colorama.Fore.GREEN,
-             "WARNING": colorama.Fore.YELLOW,
-             "ERROR": colorama.Fore.RED,
-             "CRITICAL": colorama.Fore.RED + colorama.Style.BRIGHT,
-         }
+     def invoke(self, messages: LanguageModelInput) -> str:
+         """Synchronously run the pipeline with the given messages."""
+         msgs = self.pipeline.chatterer.client._convert_input(messages).to_messages() # type: ignore
+         return self.pipeline.run_pipeline(msgs)
 
-         def format(self, record):
-             levelname = record.levelname
-             message = super().format(record)
-             return f"{self.COLORS.get(levelname, colorama.Fore.WHITE)}{message}{colorama.Style.RESET_ALL}"
+     def get_reasoning_graph(self):
+         """Return the AoT reasoning graph from the pipeline’s steps history."""
+         return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
 
-     logger.setLevel(logging.DEBUG)
-     handler = logging.StreamHandler()
-     handler.setLevel(logging.DEBUG)
-     formatter = ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-     handler.setFormatter(formatter)
-     logger.addHandler(handler)
 
-     file_handler = logging.FileHandler("atom_of_thoughts.log", encoding="utf-8", mode="w")
-     file_handler.setLevel(logging.DEBUG)
-     logger.addHandler(file_handler)
-     logger.propagate = False
+ # ---------------------------------------------------------------------------------
+ # Example usage (pseudo-code)
+ # ---------------------------------------------------------------------------------
+ if __name__ == "__main__":
+     from neo4j_extension import Neo4jConnection # or your actual DB connector
 
-     pipeline = AoTPipeline(chatterer=Chatterer.openai(), max_depth=2)
+     # You would create a Chatterer with your chosen LLM backend (OpenAI, etc.)
+     chatterer = Chatterer.openai() # pseudo-code
+     pipeline = AoTPipeline(chatterer=chatterer, max_depth=3)
     strategy = AoTStrategy(pipeline=pipeline)
 
-     question = "What would Newton discover if hit by an apple falling from 100 meters?"
+     question = "Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9."
     answer = strategy.invoke(question)
-     print(answer)
+     print("Final Answer:", answer)
+
+     # Build the reasoning graph
+     graph = strategy.get_reasoning_graph()
+     print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")
+
+     # Optionally store in Neo4j
+     with Neo4jConnection() as conn:
+         conn.clear_all()
+         conn.upsert_graph(graph)
+         print("Graph stored in Neo4j.")