chatterer 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. chatterer/__init__.py +93 -93
  2. chatterer/common_types/__init__.py +21 -21
  3. chatterer/common_types/io.py +19 -19
  4. chatterer/examples/__init__.py +0 -0
  5. chatterer/examples/anything_to_markdown.py +85 -91
  6. chatterer/examples/get_code_snippets.py +55 -62
  7. chatterer/examples/login_with_playwright.py +156 -167
  8. chatterer/examples/make_ppt.py +488 -497
  9. chatterer/examples/pdf_to_markdown.py +100 -107
  10. chatterer/examples/pdf_to_text.py +54 -56
  11. chatterer/examples/transcription_api.py +112 -123
  12. chatterer/examples/upstage_parser.py +89 -100
  13. chatterer/examples/webpage_to_markdown.py +70 -79
  14. chatterer/interactive.py +354 -354
  15. chatterer/language_model.py +533 -533
  16. chatterer/messages.py +21 -21
  17. chatterer/strategies/__init__.py +13 -13
  18. chatterer/strategies/atom_of_thoughts.py +975 -975
  19. chatterer/strategies/base.py +14 -14
  20. chatterer/tools/__init__.py +46 -46
  21. chatterer/tools/caption_markdown_images.py +384 -384
  22. chatterer/tools/citation_chunking/__init__.py +3 -3
  23. chatterer/tools/citation_chunking/chunks.py +53 -53
  24. chatterer/tools/citation_chunking/citation_chunker.py +118 -118
  25. chatterer/tools/citation_chunking/citations.py +285 -285
  26. chatterer/tools/citation_chunking/prompt.py +157 -157
  27. chatterer/tools/citation_chunking/reference.py +26 -26
  28. chatterer/tools/citation_chunking/utils.py +138 -138
  29. chatterer/tools/convert_pdf_to_markdown.py +393 -302
  30. chatterer/tools/convert_to_text.py +446 -447
  31. chatterer/tools/upstage_document_parser.py +705 -705
  32. chatterer/tools/webpage_to_markdown.py +739 -739
  33. chatterer/tools/youtube.py +146 -146
  34. chatterer/utils/__init__.py +15 -15
  35. chatterer/utils/base64_image.py +285 -285
  36. chatterer/utils/bytesio.py +59 -59
  37. chatterer/utils/code_agent.py +237 -237
  38. chatterer/utils/imghdr.py +148 -148
  39. {chatterer-0.1.18.dist-info → chatterer-0.1.20.dist-info}/METADATA +392 -392
  40. chatterer-0.1.20.dist-info/RECORD +44 -0
  41. {chatterer-0.1.18.dist-info → chatterer-0.1.20.dist-info}/WHEEL +1 -1
  42. chatterer-0.1.20.dist-info/entry_points.txt +10 -0
  43. chatterer-0.1.18.dist-info/RECORD +0 -42
  44. {chatterer-0.1.18.dist-info → chatterer-0.1.20.dist-info}/top_level.txt +0 -0
@@ -1,975 +1,975 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import logging
5
- from dataclasses import dataclass, field
6
- from enum import StrEnum
7
- from typing import Optional, Type, TypeVar
8
-
9
- from pydantic import BaseModel, Field, ValidationError
10
-
11
- from ..language_model import Chatterer, LanguageModelInput
12
- from ..messages import AIMessage, BaseMessage, HumanMessage
13
- from .base import BaseStrategy
14
-
15
- # ---------------------------------------------------------------------------------
16
- # 0) Enums and Basic Models
17
- # ---------------------------------------------------------------------------------
18
-
19
# Format string that renders one question/answer pair on two lines.
QA_TEMPLATE = "Q: {question}\nA: {answer}"

# "thought" text placed on the response returned when the recursion depth is exhausted.
MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."

# Stand-in answer used wherever a real answer is not available.
UNKNOWN = "Unknown"
22
-
23
-
24
class SubQuestionNode(BaseModel):
    """A single sub-question node in a decomposition tree."""

    # The class docstring and Field descriptions are kept verbatim: pydantic
    # copies them into the model's JSON schema.

    # Text of the sub-question.
    question: str = Field(
        description="A sub-question string that arises from decomposition.",
    )
    # Answer text, or None while unresolved. No default is declared, so a value
    # must always be supplied explicitly.
    answer: Optional[str] = Field(
        description="Answer for this sub-question, if resolved.",
    )
    # Positions (within the same sub-question list) of prerequisite sub-questions.
    depend: list[int] = Field(
        description="Indices of sub-questions that this node depends on.",
    )
30
-
31
-
32
class RecursiveDecomposeResponse(BaseModel):
    """The result of a recursive decomposition step."""

    # Docstring and descriptions are kept verbatim because pydantic exposes them
    # through the model's JSON schema.

    # Free-form reasoning produced alongside the decomposition.
    thought: str = Field(
        description="Reasoning about decomposition.",
    )
    # Answer to the question that was decomposed.
    final_answer: str = Field(
        description="Best answer to the main question.",
    )
    # Sub-questions produced for this question (may be an empty list).
    sub_questions: list[SubQuestionNode] = Field(
        description="Root-level sub-questions.",
    )
38
-
39
-
40
class ContractQuestionResponse(BaseModel):
    """The result of contracting (simplifying) a question."""

    # Docstring and descriptions are kept verbatim because pydantic exposes them
    # through the model's JSON schema.

    # Reasoning that accompanies the contracted question.
    thought: str = Field(
        description="Reasoning on how the question was compressed.",
    )
    # The contracted question itself.
    question: str = Field(
        description="New, simplified, self-contained question.",
    )
45
-
46
-
47
class EnsembleResponse(BaseModel):
    """The ensemble process result."""

    # Docstring and descriptions are kept verbatim because pydantic exposes them
    # through the model's JSON schema.

    # Why the final answer was chosen.
    thought: str = Field(description="Explanation for choosing the final answer.")
    # The selected / consolidated answer.
    answer: str = Field(description="Best final answer after ensemble.")
    # Confidence score; normalised into [0, 1] by model_post_init below.
    confidence: float = Field(description="Confidence score in [0, 1].")

    def model_post_init(self, __context: object) -> None:
        """Normalise ``confidence`` into the closed interval [0, 1].

        Values below 0 become 0.0 and values above 1 become 1.0. NaN is mapped
        to 0.0: a plain ``max(0.0, min(1.0, nan))`` evaluates to 1.0, which
        would turn an unusable score into maximum confidence.
        """
        import math  # local import: only needed for the NaN check

        value = self.confidence
        if math.isnan(value):
            value = 0.0
        self.confidence = max(0.0, min(1.0, value))
56
-
57
-
58
class LabelResponse(BaseModel):
    """Used to refine and reorder the sub-questions with corrected dependencies."""

    # Docstring and descriptions are kept verbatim because pydantic exposes them
    # through the model's JSON schema.

    # Reasoning that accompanies the relabelled list.
    thought: str = Field(description="Explanation or reasoning about labeling.")
    # The relabelled sub-question list.
    sub_questions: list[SubQuestionNode] = Field(description="Refined list of sub-questions with corrected dependencies.")
65
-
66
-
67
class CritiqueResponse(BaseModel):
    """A response used for LLM to self-critique or question its own correctness."""

    # Docstring and descriptions are kept verbatim because pydantic exposes them
    # through the model's JSON schema.

    # Critical reflection text.
    thought: str = Field(description="Critical reflection on correctness.")
    # Self-assessed score; normalised into [0, 1] by model_post_init below.
    self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")

    def model_post_init(self, __context: object) -> None:
        """Normalise ``self_assessment`` into [0, 1].

        Mirrors the clamping that ``EnsembleResponse`` applies to its
        ``confidence`` field, so every score recorded by the pipeline shares
        the documented range. NaN is mapped to 0.0 rather than being clamped
        up to 1.0.
        """
        value = self.self_assessment
        if value != value:  # NaN is the only float that is not equal to itself
            value = 0.0
        self.self_assessment = max(0.0, min(1.0, value))
72
-
73
-
74
- # ---------------------------------------------------------------------------------
75
- # [NEW] Additional classes to incorporate a separate sub-question devil's advocate
76
- # ---------------------------------------------------------------------------------
77
-
78
-
79
class DevilsAdvocateResponse(BaseModel):
    """
    A response for a 'devil's advocate' pass.
    We consider an alternative viewpoint or contradictory answer.
    """

    # Docstring and descriptions are kept verbatim because pydantic exposes them
    # through the model's JSON schema. The field set matches
    # RecursiveDecomposeResponse: thought / final_answer / sub_questions.

    # Reasoning for the opposing view.
    thought: str = Field(
        description="Reasoning behind the contradictory viewpoint.",
    )
    # The competing answer.
    final_answer: str = Field(
        description="Alternative or conflicting answer to challenge the main one.",
    )
    # Extra sub-questions raised by the opposing view (may be empty).
    sub_questions: list[SubQuestionNode] = Field(
        description="Any additional sub-questions from the contrarian perspective.",
    )
90
-
91
-
92
- # ---------------------------------------------------------------------------------
93
- # 1) Prompter Classes with Multi-Hop + Devil's Advocate
94
- # ---------------------------------------------------------------------------------
95
-
96
-
97
class AoTPrompter:
    """Generic base prompter that defines the required prompt methods.

    Each method returns a NEW list made of the caller's ``messages`` followed by
    step-specific context and instruction messages; the input list is never
    mutated. Instruction text is emitted as ``AIMessage`` objects, exactly as
    before.

    NOTE(review): every prompt ends with an ``AIMessage``. Whether the backing
    chat model accepts a conversation that ends on an assistant turn depends on
    the provider - confirm against the models used with ``Chatterer``.
    """

    def recursive_decompose_prompt(
        self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """Build the prompt for the main decomposition step.

        Args:
            messages: Conversation so far; copied, not modified.
            question: The question to decompose.
            sub_answers: ``(question, answer)`` pairs already resolved. May be
                empty, which is the case on the first decomposition pass.

        Returns:
            ``messages`` + the main question + (only when ``sub_answers`` is
            non-empty) one message listing the sub-answers + the JSON
            instructions.
        """
        decompose_instructions = (
            "First, restate the main question.\n"
            "Decide if sub-questions are needed. If so, list them.\n"
            "In the 'thought' field, show your chain-of-thought.\n"
            "Return valid JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )

        prompt: list[BaseMessage] = [*messages, HumanMessage(content=f"Main question:\n{question}")]
        if sub_answers:
            # Skip this message when there is nothing to report: an AIMessage
            # with empty content carries no information and may be rejected by
            # some chat APIs.
            prompt.append(AIMessage(content="\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)))
        prompt.append(AIMessage(content=decompose_instructions))
        return prompt

    def label_prompt(
        self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
    ) -> list[BaseMessage]:
        """Build the prompt for refining sub-questions and their dependencies.

        Args:
            messages: Conversation so far; copied, not modified.
            question: The question whose decomposition is being refined.
            decompose_response: Decomposition whose ``sub_questions`` list is
                embedded in the prompt via its string representation.

        Returns:
            ``messages`` + the question + the current sub-questions + the JSON
            instructions.
        """
        label_instructions = (
            "Review each sub-question to ensure correctness and proper ordering.\n"
            "Return valid JSON in the form:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": "...", "depend": [...]},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Question: {question}"),
            AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
            AIMessage(content=label_instructions),
        ]

    def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
        """Build the prompt for merging sub-answers into one self-contained question.

        Args:
            messages: Conversation so far; copied, not modified.
            sub_answers: ``(question, answer)`` pairs to merge. May be empty.

        Returns:
            ``messages`` + a header + (only when ``sub_answers`` is non-empty)
            the Q/A listing + the JSON instructions.
        """
        contract_instructions = (
            "Please merge sub-answers into a single short question that is fully self-contained.\n"
            "In 'thought', show how you unify the information.\n"
            "Then produce JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "question": "a short but self-contained question"\n'
            "}\n"
        )

        prompt: list[BaseMessage] = [*messages, AIMessage(content="We have these sub-questions and answers:")]
        if sub_answers:
            # Render with the module-level QA_TEMPLATE ("Q: ...\nA: ...") instead
            # of re-hardcoding the same format. The empty case is skipped so no
            # empty-content message is produced.
            listing = "\n".join(QA_TEMPLATE.format(question=q, answer=a) for q, a in sub_answers)
            prompt.append(AIMessage(content=listing))
        prompt.append(AIMessage(content=contract_instructions))
        return prompt

    def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
        """Build the prompt for answering the contracted question directly.

        Args:
            messages: Conversation so far; copied, not modified.
            contracted_question: The simplified question to answer.

        Returns:
            ``messages`` + the simplified question + the JSON instructions.
        """
        direct_instructions = (
            "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "..."\n'
            "}\n"
        )
        return messages + [
            HumanMessage(content=f"Simplified question: {contracted_question}"),
            AIMessage(content=direct_instructions),
        ]

    def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
        """Build the self-critique prompt.

        Args:
            messages: Conversation so far; copied, not modified.
            thought: The reasoning to be critiqued.
            answer: The answer to be critiqued.

        Returns:
            ``messages`` + the previous thought + the previous answer + the
            JSON instructions.
        """
        critique_instructions = (
            "Critique your own approach. Identify possible errors or leaps.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "self_assessment": <float in [0,1]>\n'
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
            AIMessage(content=f"Your previous ANSWER:\n{answer}"),
            AIMessage(content=critique_instructions),
        ]

    def ensemble_prompt(
        self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """Build the prompt that shows candidate solutions and asks for the best one.

        Args:
            messages: Conversation so far; copied, not modified.
            possible_thought_and_answers: ``(thought, answer)`` candidates;
                each becomes one message, numbered from 0.

        Returns:
            ``messages`` + one message per candidate + the JSON instructions.
        """
        instructions = (
            "You have multiple candidate solutions. Compare carefully and pick the best.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "why you chose this final answer",\n'
            ' "answer": "the best consolidated answer",\n'
            ' "confidence": 0.0 ~ 1.0\n'
            "}\n"
        )
        reasonings: list[BaseMessage] = [
            AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---")
            for idx, (thought, ans) in enumerate(possible_thought_and_answers)
        ]
        return messages + reasonings + [AIMessage(content=instructions)]

    def devils_advocate_prompt(
        self, messages: list[BaseMessage], question: str, existing_answer: str
    ) -> list[BaseMessage]:
        """Build the devil's-advocate prompt that challenges an existing answer.

        Args:
            messages: Conversation so far; copied, not modified.
            question: The question under discussion.
            existing_answer: The answer the model is asked to challenge.

        Returns:
            ``messages`` + the question/answer being challenged + the JSON
            instructions.
        """
        instructions = (
            "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
            "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
            "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
            "But here, let's keep it in a new dedicated structure:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
            AIMessage(content=instructions),
        ]
251
-
252
-
253
- # ---------------------------------------------------------------------------------
254
- # 2) Strict Typed Steps for Pipeline
255
- # ---------------------------------------------------------------------------------
256
-
257
-
258
- class StepName(StrEnum):
259
- """Enum for step names in the pipeline."""
260
-
261
- DOMAIN_DETECTION = "DomainDetection"
262
- DECOMPOSITION = "Decomposition"
263
- DECOMPOSITION_CRITIQUE = "DecompositionCritique"
264
- CONTRACTED_QUESTION = "ContractedQuestion"
265
- CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
266
- CONTRACT_CRITIQUE = "ContractCritique"
267
- BEST_APPROACH_DECISION = "BestApproachDecision"
268
- ENSEMBLE = "Ensemble"
269
- FINAL_ANSWER = "FinalAnswer"
270
-
271
- DEVILS_ADVOCATE = "DevilsAdvocate"
272
- DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
273
-
274
-
275
- class StepRelation(StrEnum):
276
- """Enum for relationship types in the reasoning graph."""
277
-
278
- CRITIQUES = "CRITIQUES"
279
- SELECTS = "SELECTS"
280
- RESULT_OF = "RESULT_OF"
281
- SPLIT_INTO = "SPLIT_INTO"
282
- DEPEND_ON = "DEPEND_ON"
283
- PRECEDES = "PRECEDES"
284
- DECOMPOSED_BY = "DECOMPOSED_BY"
285
-
286
-
287
class StepRecord(BaseModel):
    """A typed record for each pipeline step."""

    # Only step_name is required; every other field defaults to None and is
    # filled in by whichever step produces it.
    step_name: StepName
    domain: Optional[str] = None
    score: Optional[float] = None
    used: Optional[StepName] = None
    sub_questions: Optional[list[SubQuestionNode]] = None
    parent_decomp_step_idx: Optional[int] = None
    parent_subq_idx: Optional[int] = None
    question: Optional[str] = None
    thought: Optional[str] = None
    answer: Optional[str] = None

    def as_properties(self) -> dict[str, str | float | int | None]:
        """Return the scalar fields of this record as a flat property dict.

        ``score`` is included whenever it is not None (so 0.0 is kept). The
        text fields ``domain``, ``question``, ``thought`` and ``answer`` are
        included only when truthy, so None and empty strings are left out.
        Keys appear in the order score, domain, question, thought, answer.
        """
        properties: dict[str, str | float | int | None] = {}
        if self.score is not None:
            properties["score"] = self.score
        for key in ("domain", "question", "thought", "answer"):
            text = getattr(self, key)
            if text:
                properties[key] = text
        return properties
315
-
316
-
317
- # ---------------------------------------------------------------------------------
318
- # 3) Logging Setup
319
- # ---------------------------------------------------------------------------------
320
-
321
-
322
class SimpleColorFormatter(logging.Formatter):
    """Logging formatter that wraps each formatted record in an ANSI colour code.

    The colour is chosen from the record's numeric level. Levels without an
    entry in ``LEVEL_COLORS`` are prefixed with the reset code instead. The
    reset code is always appended.
    """

    # ANSI escape sequences.
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    RESET = "\033[0m"

    # Numeric logging level -> escape sequence.
    LEVEL_COLORS = {
        logging.DEBUG: BLUE,
        logging.INFO: GREEN,
        logging.WARNING: YELLOW,
        logging.ERROR: RED,
        logging.CRITICAL: RED,
    }

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record`` with the base formatter, then colourise the result."""
        level = record.levelno
        prefix = self.LEVEL_COLORS[level] if level in self.LEVEL_COLORS else self.RESET
        rendered = super().format(record)
        return "".join((prefix, rendered, self.RESET))
342
-
343
-
344
# Module logger "AoT": INFO level, a single colourised stream handler, and no
# propagation to ancestor loggers. Any handlers already attached to the "AoT"
# logger are replaced.
logger = logging.getLogger("AoT")

handler = logging.StreamHandler()
handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))

logger.setLevel(logging.INFO)
logger.propagate = False
logger.handlers = [handler]
350
-
351
-
352
- # ---------------------------------------------------------------------------------
353
- # 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
354
- # ---------------------------------------------------------------------------------
355
-
356
# Type variable for the response models that AoTPipeline._ainvoke_pydantic can
# return; it is bound to the union of those model classes.
T = TypeVar(
    "T",
    bound=(
        EnsembleResponse
        | ContractQuestionResponse
        | LabelResponse
        | CritiqueResponse
        | RecursiveDecomposeResponse
        | DevilsAdvocateResponse
    ),
)
365
-
366
-
367
- @dataclass
368
- class AoTPipeline:
369
- """
370
- The pipeline orchestrates:
371
- 1) Recursive decomposition
372
- 2) For each sub-question, it tries a main approach + a devil's advocate approach
373
- 3) Merges sub-answers using an ensemble
374
- 4) Contracts the question
375
- 5) Possibly does a direct approach on the contracted question
376
- 6) Ensembling the final answers
377
- """
378
-
379
- chatterer: Chatterer
380
- max_depth: int = 2
381
- max_retries: int = 2
382
- steps_history: list[StepRecord] = field(default_factory=list)
383
- prompter: AoTPrompter = field(default_factory=AoTPrompter)
384
-
385
- # 4.1) Utility for calling the LLM with Pydantic parsing
386
async def _ainvoke_pydantic(
    self,
    messages: list[BaseMessage],
    model_cls: Type[T],
    fallback: str = "<None>",
) -> T:
    """Ask the LLM for a ``model_cls`` instance, retrying on validation errors.

    Args:
        messages: Prompt messages passed to ``chatterer.agenerate_pydantic``.
        model_cls: The response model to produce.
        fallback: Text used for every string field of the placeholder instance
            returned when all attempts fail.

    Returns:
        The parsed model, or - after the last ``ValidationError`` - a
        placeholder ``model_cls`` instance whose string fields are ``fallback``,
        numeric scores are 0.0 and ``sub_questions`` is empty.

    Only ``ValidationError`` is handled; any other exception propagates.
    """
    # Always make at least one attempt. Previously max_retries <= 0 skipped the
    # loop entirely and raised RuntimeError without ever calling the model.
    attempts = max(1, self.max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
        except ValidationError as e:
            logger.warning("ValidationError on attempt %d for %s: %s", attempt, model_cls.__name__, e)

    # Every attempt failed validation: build a placeholder of the requested type.
    if issubclass(model_cls, EnsembleResponse):
        return model_cls(thought=fallback, answer=fallback, confidence=0.0)  # type: ignore
    if issubclass(model_cls, ContractQuestionResponse):
        return model_cls(thought=fallback, question=fallback)  # type: ignore
    if issubclass(model_cls, LabelResponse):
        return model_cls(thought=fallback, sub_questions=[])  # type: ignore
    if issubclass(model_cls, CritiqueResponse):
        return model_cls(thought=fallback, self_assessment=0.0)  # type: ignore
    # DevilsAdvocateResponse and RecursiveDecomposeResponse share the same
    # fields (thought / final_answer / sub_questions), so one branch covers both.
    return model_cls(thought=fallback, final_answer=fallback, sub_questions=[])  # type: ignore
416
-
417
- # 4.2) Helper method for self-critique
418
async def _ainvoke_critique(
    self,
    messages: list[BaseMessage],
    thought: str,
    answer: str,
) -> CritiqueResponse:
    """Ask the LLM to critique ``thought``/``answer`` and return its CritiqueResponse.

    The prompt comes from ``self.prompter.critique_prompt``; parsing and
    retries are delegated to ``_ainvoke_pydantic``.
    """
    critique_messages = self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer)
    critique = await self._ainvoke_pydantic(messages=critique_messages, model_cls=CritiqueResponse)
    return critique
431
-
432
- # 4.3) Helper method for devil's advocate approach
433
async def _ainvoke_devils_advocate(
    self,
    messages: list[BaseMessage],
    question: str,
    existing_answer: str,
) -> DevilsAdvocateResponse:
    """Ask the LLM to challenge ``existing_answer`` and return its DevilsAdvocateResponse.

    The prompt comes from ``self.prompter.devils_advocate_prompt``; parsing and
    retries are delegated to ``_ainvoke_pydantic``.
    """
    challenge_messages = self.prompter.devils_advocate_prompt(
        messages,
        question=question,
        existing_answer=existing_answer,
    )
    challenge = await self._ainvoke_pydantic(messages=challenge_messages, model_cls=DevilsAdvocateResponse)
    return challenge
446
-
447
- # 4.4) The main function that recursively decomposes a question and calls sub-steps
448
async def _arecursive_decompose_question(
    self,
    messages: list[BaseMessage],
    question: str,
    depth: int,
    parent_decomp_step_idx: Optional[int] = None,
    parent_subq_idx: Optional[int] = None,
) -> RecursiveDecomposeResponse:
    """Decompose ``question``, resolve its sub-questions, and refine the answer.

    Flow:
      1. A negative ``depth`` short-circuits with an UNKNOWN answer.
      2. The LLM decomposes the question; if it produced sub-questions they are
         relabelled by a second LLM call.
      3. The step is recorded in ``steps_history``.
      4. If ``depth`` > 0 and sub-questions exist, they are resolved via
         ``_aresolve_sub_questions`` and the decomposition prompt is re-run with
         their answers to refine ``final_answer``. Both the returned response
         and the recorded step are updated with the refined answer and the
         resolved sub-questions.
    """
    if depth < 0:
        logger.info("Max depth reached, returning unknown.")
        return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])

    # First pass: plain decomposition with no sub-answers available yet.
    first_pass_prompt = self.prompter.recursive_decompose_prompt(
        messages=messages, question=question, sub_answers=[]
    )
    decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        messages=first_pass_prompt,
        model_cls=RecursiveDecomposeResponse,
    )

    # Relabel (dependencies / ordering) only when there is something to relabel.
    if decompose_resp.sub_questions:
        relabelled: LabelResponse = await self._ainvoke_pydantic(
            messages=self.prompter.label_prompt(messages, question, decompose_resp),
            model_cls=LabelResponse,
        )
        decompose_resp.sub_questions = relabelled.sub_questions

    step_idx = self._record_decomposition_step(
        question=question,
        final_answer=decompose_resp.final_answer,
        sub_questions=decompose_resp.sub_questions,
        parent_decomp_step_idx=parent_decomp_step_idx,
        parent_subq_idx=parent_subq_idx,
    )

    # Nothing further to do without remaining depth or without sub-questions.
    if depth <= 0 or not decompose_resp.sub_questions:
        return decompose_resp

    solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
        messages=messages,
        sub_questions=decompose_resp.sub_questions,
        depth=depth,
        parent_decomp_step_idx=step_idx,
    )

    # Second pass: same prompt, now carrying the resolved sub-answers.
    answered_pairs = [(node.question, node.answer or UNKNOWN) for node in solved_subs]
    second_pass_prompt = self.prompter.recursive_decompose_prompt(
        messages=messages,
        question=question,
        sub_answers=answered_pairs,
    )
    refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        second_pass_prompt, RecursiveDecomposeResponse
    )

    decompose_resp.final_answer = refined_resp.final_answer
    decompose_resp.sub_questions = solved_subs

    # Keep the recorded step in sync with the refined result.
    recorded = self.steps_history[step_idx]
    recorded.answer = refined_resp.final_answer
    recorded.sub_questions = solved_subs

    return decompose_resp
515
-
516
def _record_decomposition_step(
    self,
    question: str,
    final_answer: str,
    sub_questions: list[SubQuestionNode],
    parent_decomp_step_idx: Optional[int],
    parent_subq_idx: Optional[int],
) -> int:
    """Append a DECOMPOSITION record to ``steps_history`` and return its index."""
    # The new record lands at the current end of the list.
    new_index = len(self.steps_history)
    self.steps_history.append(
        StepRecord(
            step_name=StepName.DECOMPOSITION,
            question=question,
            answer=final_answer,
            sub_questions=sub_questions,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=parent_subq_idx,
        )
    )
    return new_index
537
-
538
@staticmethod
def _dependency_waves(sub_questions: list[SubQuestionNode]) -> list[list[int]]:
    """Group sub-question indices into waves that respect ``depend``.

    Every index appears in exactly one wave. All in-range dependencies of an
    index are in earlier waves, except when the dependencies contain a cycle:
    the remaining indices are then placed together in one final wave (with a
    warning) so that every sub-question is still resolved. Out-of-range and
    self-referencing dependency indices are ignored.
    """
    n = len(sub_questions)
    deps: list[set[int]] = [
        {d for d in sq.depend if 0 <= d < n and d != i} for i, sq in enumerate(sub_questions)
    ]
    done: set[int] = set()
    pending: list[int] = list(range(n))
    waves: list[list[int]] = []
    while pending:
        ready = [i for i in pending if deps[i] <= done]
        if not ready:
            logger.warning("Cyclic sub-question dependencies among indices %s; resolving them together.", pending)
            ready = list(pending)
        waves.append(ready)
        done.update(ready)
        pending = [i for i in pending if i not in done]
    return waves

async def _aresolve_sub_questions(
    self,
    messages: list[BaseMessage],
    sub_questions: list[SubQuestionNode],
    depth: int,
    parent_decomp_step_idx: Optional[int],
) -> list[SubQuestionNode]:
    """Resolve every sub-question, honouring the ``depend`` ordering.

    Sub-questions are processed in dependency waves: members of a wave run
    concurrently, and a wave starts only after all earlier waves have finished.
    For each sub-question:
      1) Recursively decompose it (main approach) with ``depth - 1``.
      2) Obtain a devil's-advocate alternative.
      3) Ensemble both candidates and store the winner in ``sq.answer``.
      4) Record DEVILS_ADVOCATE and DEVILS_ADVOCATE_CRITIQUE steps.

    Answers of already-resolved dependencies are appended to the conversation
    of the dependent sub-question as ``QA_TEMPLATE``-formatted messages.

    Previously the topological order was computed but all tasks were launched
    at once, and sub-questions on a dependency cycle were never resolved, which
    made the final list construction raise ``KeyError``.

    Args:
        messages: Conversation context shared by all sub-questions.
        sub_questions: Nodes to resolve; each node's ``answer`` is set in place.
        depth: Remaining depth of the caller; children receive ``depth - 1``.
        parent_decomp_step_idx: Index of the parent decomposition record.

    Returns:
        The same nodes, in their original order, with ``answer`` filled in.
    """
    resolved: dict[int, SubQuestionNode] = {}

    def _dependency_context(idx: int) -> list[BaseMessage]:
        # Q/A pairs for dependencies that are already resolved, in index order.
        return [
            AIMessage(
                content=QA_TEMPLATE.format(
                    question=resolved[dep].question, answer=resolved[dep].answer or UNKNOWN
                )
            )
            for dep in sorted(set(sub_questions[idx].depend))
            if dep != idx and dep in resolved
        ]

    async def _resolve_one_subq(idx: int, sub_messages: list[BaseMessage]) -> None:
        sq = sub_questions[idx]

        # 1) Main approach
        main_resp = await self._arecursive_decompose_question(
            messages=sub_messages,
            question=sq.question,
            depth=depth - 1,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=idx,
        )
        main_answer = main_resp.final_answer

        # 2) Devil's advocate approach
        devils_resp = await self._ainvoke_devils_advocate(
            messages=sub_messages, question=sq.question, existing_answer=main_answer
        )

        # 3) Ensemble of the main answer and the devil's-advocate alternative
        ensemble_sub = await self._ainvoke_pydantic(
            self.prompter.ensemble_prompt(
                messages=sub_messages,
                possible_thought_and_answers=[
                    (main_resp.thought, main_answer),
                    (devils_resp.thought, devils_resp.final_answer),
                ],
            ),
            EnsembleResponse,
        )

        # Store the final sub-question answer (mutates the caller's node).
        sq.answer = ensemble_sub.answer
        resolved[idx] = sq

        # Record the devil's-advocate step and its critique.
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE,
                question=sq.question,
                answer=devils_resp.final_answer,
                thought=devils_resp.thought,
                sub_questions=devils_resp.sub_questions,
            )
        )
        dev_adv_crit = await self._ainvoke_critique(
            messages=sub_messages, thought=devils_resp.thought, answer=devils_resp.final_answer
        )
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
                thought=dev_adv_crit.thought,
                score=dev_adv_crit.self_assessment,
            )
        )

    for wave in self._dependency_waves(sub_questions):
        # Build each member's context BEFORE launching the wave, so it depends
        # only on earlier waves and not on task scheduling within this wave.
        wave_messages = {idx: messages + _dependency_context(idx) for idx in wave}
        await asyncio.gather(*(_resolve_one_subq(idx, wave_messages[idx]) for idx in wave))

    return [resolved[i] for i in range(len(sub_questions))]
638
-
639
- # 4.5) The primary pipeline method
640
async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
    """Run the full pipeline for the question in the last message.

    Sequence (each stage appends its record(s) to ``steps_history``, which is
    cleared first):
      1) Recursive decomposition of the question.
      2) Self-critique of the decomposition.
      3) Devil's advocate on the main answer, followed by its critique.
      4) Contraction of the top-level sub-answers into a single question.
      5) Direct answer to the contracted question, followed by its critique.
      6) Ensemble over the main, devil's-advocate and contracted-direct answers.
      7) Recording and returning the ensemble's answer.

    Returns:
        The answer chosen by the final ensemble.
    """
    history = self.steps_history
    history.clear()

    def _log_step(**fields) -> None:
        # Append one StepRecord built from the given fields.
        history.append(StepRecord(**fields))

    original_question: str = messages[-1].text()

    # 1) Recursive decomposition
    decomp_resp = await self._arecursive_decompose_question(
        messages=messages,
        question=original_question,
        depth=self.max_depth,
    )
    logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")

    # 2) Self-critique of the main decomposition
    decomp_critique = await self._ainvoke_critique(
        messages=messages, thought=decomp_resp.thought, answer=decomp_resp.final_answer
    )
    _log_step(
        step_name=StepName.DECOMPOSITION_CRITIQUE,
        thought=decomp_critique.thought,
        score=decomp_critique.self_assessment,
    )

    # 3) Devil's advocate on the whole main answer, then critique it
    devils_on_main = await self._ainvoke_devils_advocate(
        messages=messages, question=original_question, existing_answer=decomp_resp.final_answer
    )
    _log_step(
        step_name=StepName.DEVILS_ADVOCATE,
        question=original_question,
        answer=devils_on_main.final_answer,
        thought=devils_on_main.thought,
        sub_questions=devils_on_main.sub_questions,
    )
    devils_crit_main = await self._ainvoke_critique(
        messages=messages, thought=devils_on_main.thought, answer=devils_on_main.final_answer
    )
    _log_step(
        step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
        thought=devils_crit_main.thought,
        score=devils_crit_main.self_assessment,
    )

    # 4) Contract the sub-answers of the most recent top-level decomposition
    top_decomp_record: Optional[StepRecord] = None
    for candidate in reversed(history):
        if candidate.step_name == StepName.DECOMPOSITION and candidate.parent_decomp_step_idx is None:
            top_decomp_record = candidate
            break
    sub_answers = (
        [(sq.question, sq.answer or UNKNOWN) for sq in top_decomp_record.sub_questions]
        if top_decomp_record and top_decomp_record.sub_questions
        else []
    )

    contract_resp = await self._ainvoke_pydantic(
        messages=self.prompter.contract_prompt(messages, sub_answers),
        model_cls=ContractQuestionResponse,
    )
    contracted_question = contract_resp.question
    _log_step(
        step_name=StepName.CONTRACTED_QUESTION,
        question=contracted_question,
        thought=contract_resp.thought,
    )

    # 5) Direct approach on the contracted question
    contracted_direct = await self._ainvoke_pydantic(
        messages=self.prompter.contract_direct_prompt(messages, contracted_question),
        model_cls=RecursiveDecomposeResponse,
        fallback="No Contracted Direct Answer",
    )
    _log_step(
        step_name=StepName.CONTRACTED_DIRECT_ANSWER,
        answer=contracted_direct.final_answer,
        thought=contracted_direct.thought,
    )
    logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")

    # 5.1) Critique the contracted direct approach
    contract_critique = await self._ainvoke_critique(
        messages=messages, thought=contracted_direct.thought, answer=contracted_direct.final_answer
    )
    _log_step(
        step_name=StepName.CONTRACT_CRITIQUE,
        thought=contract_critique.thought,
        score=contract_critique.self_assessment,
    )

    # 6) Ensemble of (main decomposition, devil's advocate on main, contracted direct)
    candidates = [
        (decomp_resp.thought, decomp_resp.final_answer),
        (devils_on_main.thought, devils_on_main.final_answer),
        (contracted_direct.thought, contracted_direct.final_answer),
    ]
    ensemble_resp = await self._ainvoke_pydantic(
        self.prompter.ensemble_prompt(messages=messages, possible_thought_and_answers=candidates),
        EnsembleResponse,
    )
    best_approach_answer = ensemble_resp.answer
    approach_used = StepName.ENSEMBLE
    _log_step(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used)
    logger.info(f"[Best Approach Decision] => {approach_used}")

    # 7) Final answer
    _log_step(step_name=StepName.FINAL_ANSWER, answer=best_approach_answer, score=ensemble_resp.confidence)
    logger.info(f"[Final Answer] => {best_approach_answer}")

    return best_approach_answer
775
-
776
def run_pipeline(self, messages: list[BaseMessage]) -> str:
    """Run ``arun_pipeline`` to completion via ``asyncio.run`` and return its answer."""
    pipeline_coro = self.arun_pipeline(messages)
    return asyncio.run(pipeline_coro)
779
-
780
- # ---------------------------------------------------------------------------------
781
- # 4.6) Build or export a reasoning graph
782
- # ---------------------------------------------------------------------------------
783
-
784
def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
    """
    Build a ``neo4j_extension.Graph`` from ``self.steps_history``.

    The graph contains:
      * one node per recorded step (labelled with its step name),
      * one ``SubQuestion`` node per sub-question attached to a step,
      * ``SPLIT_INTO`` edges from a step to each of its sub-questions,
      * ``DEPEND_ON`` edges between sub-questions of the same step,
      * ``PRECEDES`` edges chaining consecutive steps,
      * ``CRITIQUES`` edges from each critique step to the step recorded just before it,
      * ``SELECTS`` / ``RESULT_OF`` edges around the best-approach decision and final answer.

    Args:
        global_id_prefix: Prefix used for every node/relationship ``globalId``.

    Returns:
        The populated ``Graph`` instance.
    """
    from neo4j_extension import Graph, Node, Relationship

    g = Graph()
    step_nodes: dict[int, Node] = {}
    subq_nodes: dict[str, Node] = {}

    # Step A: one node per pipeline step.
    for i, record in enumerate(self.steps_history):
        step_node = Node(
            properties=record.as_properties(),
            labels={record.step_name},
            globalId=f"{global_id_prefix}_step_{i}",
        )
        g.add_node(step_node)
        step_nodes[i] = step_node

    # Step B: one node per sub-question of any step that carries sub-questions.
    for i, record in enumerate(self.steps_history):
        for sq_idx, sq in enumerate(record.sub_questions or []):
            sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
            n_subq = Node(
                properties={
                    "question": sq.question,
                    "answer": sq.answer or "",
                },
                labels={"SubQuestion"},
                globalId=sq_id,
            )
            g.add_node(n_subq)
            subq_nodes[sq_id] = n_subq

    # Step C: SPLIT_INTO (step -> sub-question) and DEPEND_ON (sub-question -> sub-question).
    for i, record in enumerate(self.steps_history):
        if not record.sub_questions:
            continue
        start_node = step_nodes[i]
        for sq_idx, sq in enumerate(record.sub_questions):
            end_node = subq_nodes[f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"]
            g.add_relationship(
                Relationship(
                    properties={},
                    rel_type=StepRelation.SPLIT_INTO,
                    start_node=start_node,
                    end_node=end_node,
                    globalId=f"{global_id_prefix}_split_{i}_{sq_idx}",
                )
            )
            for dep in sq.depend:
                # Dependencies are indices into the same step's sub-question list;
                # out-of-range indices are ignored.
                if not 0 <= dep < len(record.sub_questions):
                    continue
                dep_node = subq_nodes.get(f"{global_id_prefix}_decomp_{i}_sub_{dep}")
                if dep_node is None:
                    continue
                g.add_relationship(
                    Relationship(
                        properties={},
                        rel_type=StepRelation.DEPEND_ON,
                        start_node=end_node,
                        end_node=dep_node,
                        globalId=f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
                    )
                )

    # Step D: PRECEDES edges in a linear chain over the recorded steps.
    for i in range(len(self.steps_history) - 1):
        g.add_relationship(
            Relationship(
                properties={},
                rel_type=StepRelation.PRECEDES,
                start_node=step_nodes[i],
                end_node=step_nodes[i + 1],
                globalId=f"{global_id_prefix}_precede_{i}_to_{i + 1}",
            )
        )

    # Step E: a critique step critiques the step recorded immediately before it.
    # StepName values are CamelCase ("DecompositionCritique", ...), so the match
    # must be case-insensitive; an upper-case "CRITIQUE" test never matches.
    for i, record in enumerate(self.steps_history):
        if i > 0 and "critique" in record.step_name.casefold():
            g.add_relationship(
                Relationship(
                    properties={},
                    rel_type=StepRelation.CRITIQUES,
                    start_node=step_nodes[i],
                    end_node=step_nodes[i - 1],
                    globalId=f"{global_id_prefix}_crit_{i}",
                )
            )

    # Link each BEST_APPROACH_DECISION step to the first step whose name equals `used`.
    best_decision_idx = None
    for i, record in enumerate(self.steps_history):
        if record.step_name == StepName.BEST_APPROACH_DECISION and record.used:
            best_decision_idx = i
            used_step_idx = next(
                (j for j in step_nodes if self.steps_history[j].step_name == record.used), None
            )
            if used_step_idx is not None:
                g.add_relationship(
                    Relationship(
                        properties={},
                        rel_type=StepRelation.SELECTS,
                        start_node=step_nodes[i],
                        end_node=step_nodes[used_step_idx],
                        globalId=f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
                    )
                )

    # Link the first FINAL_ANSWER step to the last best-approach decision found above.
    final_answer_idx = next(
        (i for i, r in enumerate(self.steps_history) if r.step_name == StepName.FINAL_ANSWER), None
    )
    if final_answer_idx is not None and best_decision_idx is not None:
        g.add_relationship(
            Relationship(
                properties={},
                rel_type=StepRelation.RESULT_OF,
                start_node=step_nodes[final_answer_idx],
                end_node=step_nodes[best_decision_idx],
                globalId=f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
            )
        )

    return g
921
-
922
-
923
- # ---------------------------------------------------------------------------------
924
- # 5) AoTStrategy class that uses the pipeline
925
- # ---------------------------------------------------------------------------------
926
-
927
-
928
@dataclass
class AoTStrategy(BaseStrategy):
    """
    Strategy that delegates to an AoTPipeline and exposes its reasoning graph.
    """

    pipeline: AoTPipeline

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """Convert the input to a message list and await the pipeline on it."""
        converted = self.pipeline.chatterer.client._convert_input(messages)  # type: ignore
        return await self.pipeline.arun_pipeline(converted.to_messages())

    def invoke(self, messages: LanguageModelInput) -> str:
        """Convert the input to a message list and run the pipeline synchronously."""
        converted = self.pipeline.chatterer.client._convert_input(messages)  # type: ignore
        return self.pipeline.run_pipeline(converted.to_messages())

    def get_reasoning_graph(self):
        """Return the graph the pipeline builds from its steps history (prefix "AoT")."""
        return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
950
-
951
-
952
- # ---------------------------------------------------------------------------------
953
- # Example usage (pseudo-code)
954
- # ---------------------------------------------------------------------------------
955
if __name__ == "__main__":
    from neo4j_extension import Neo4jConnection  # or your actual DB connector

    # Wire up LLM client -> pipeline -> strategy (pseudo-code: pick your own backend).
    chatterer = Chatterer.openai()
    pipeline = AoTPipeline(chatterer=chatterer, max_depth=3)
    strategy = AoTStrategy(pipeline=pipeline)

    # Ask a question synchronously and show the result.
    question = "Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9."
    answer = strategy.invoke(question)
    print("Final Answer:", answer)

    # Turn the recorded steps into a graph and report its size.
    graph = strategy.get_reasoning_graph()
    node_count = len(graph.nodes)
    relationship_count = len(graph.relationships)
    print(f"\nGraph has {node_count} nodes and {relationship_count} relationships.")

    # Optionally persist the graph (this wipes the database first).
    with Neo4jConnection() as conn:
        conn.clear_all()
        conn.upsert_graph(graph)
        print("Graph stored in Neo4j.")
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from dataclasses import dataclass, field
6
+ from enum import StrEnum
7
+ from typing import Optional, Type, TypeVar
8
+
9
+ from pydantic import BaseModel, Field, ValidationError
10
+
11
+ from ..language_model import Chatterer, LanguageModelInput
12
+ from ..messages import AIMessage, BaseMessage, HumanMessage
13
+ from .base import BaseStrategy
14
+
15
+ # ---------------------------------------------------------------------------------
16
+ # 0) Enums and Basic Models
17
+ # ---------------------------------------------------------------------------------
18
+
19
# Format string for one question/answer pair (fields: ``question``, ``answer``).
QA_TEMPLATE = "Q: {question}\nA: {answer}"
# "thought" text returned when recursive decomposition is called with depth < 0.
MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."
# Placeholder answer used when no real answer is available.
UNKNOWN = "Unknown"
22
+
23
+
24
class SubQuestionNode(BaseModel):
    """A single sub-question node in a decomposition tree."""

    question: str = Field(
        description="A sub-question string that arises from decomposition.",
    )
    answer: Optional[str] = Field(
        description="Answer for this sub-question, if resolved.",
    )
    # Entries are positions within the list of sibling sub-questions.
    depend: list[int] = Field(
        description="Indices of sub-questions that this node depends on.",
    )
30
+
31
+
32
class RecursiveDecomposeResponse(BaseModel):
    """The result of a recursive decomposition step."""

    thought: str = Field(
        description="Reasoning about decomposition.",
    )
    final_answer: str = Field(
        description="Best answer to the main question.",
    )
    sub_questions: list[SubQuestionNode] = Field(
        description="Root-level sub-questions.",
    )
38
+
39
+
40
class ContractQuestionResponse(BaseModel):
    """The result of contracting (simplifying) a question."""

    thought: str = Field(
        description="Reasoning on how the question was compressed.",
    )
    question: str = Field(
        description="New, simplified, self-contained question.",
    )
45
+
46
+
47
class EnsembleResponse(BaseModel):
    """The ensemble process result."""

    thought: str = Field(
        description="Explanation for choosing the final answer.",
    )
    answer: str = Field(
        description="Best final answer after ensemble.",
    )
    confidence: float = Field(
        description="Confidence score in [0, 1].",
    )

    def model_post_init(self, __context: object) -> None:
        # Clamp after validation: cap at 1.0 first, then floor at 0.0.
        capped = min(1.0, self.confidence)
        self.confidence = max(0.0, capped)
56
+
57
+
58
class LabelResponse(BaseModel):
    """Used to refine and reorder the sub-questions with corrected dependencies."""

    thought: str = Field(
        description="Explanation or reasoning about labeling.",
    )
    sub_questions: list[SubQuestionNode] = Field(
        description="Refined list of sub-questions with corrected dependencies.",
    )
65
+
66
+
67
class CritiqueResponse(BaseModel):
    """A response used for LLM to self-critique or question its own correctness."""

    thought: str = Field(description="Critical reflection on correctness.")
    self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")

    def model_post_init(self, __context: object) -> None:
        # The field is documented as a value in [0, 1]; clamp out-of-range model
        # output the same way EnsembleResponse clamps its confidence.
        self.self_assessment = max(0.0, min(1.0, self.self_assessment))
72
+
73
+
74
+ # ---------------------------------------------------------------------------------
75
+ # [NEW] Additional classes to incorporate a separate sub-question devil's advocate
76
+ # ---------------------------------------------------------------------------------
77
+
78
+
79
class DevilsAdvocateResponse(BaseModel):
    """
    A response for a 'devil's advocate' pass.
    We consider an alternative viewpoint or contradictory answer.
    """

    # Same three fields as RecursiveDecomposeResponse, with contrarian descriptions.
    thought: str = Field(
        description="Reasoning behind the contradictory viewpoint.",
    )
    final_answer: str = Field(
        description="Alternative or conflicting answer to challenge the main one.",
    )
    sub_questions: list[SubQuestionNode] = Field(
        description="Any additional sub-questions from the contrarian perspective.",
    )
90
+
91
+
92
+ # ---------------------------------------------------------------------------------
93
+ # 1) Prompter Classes with Multi-Hop + Devil's Advocate
94
+ # ---------------------------------------------------------------------------------
95
+
96
+
97
class AoTPrompter:
    """Generic base prompter that defines the required prompt methods.

    Every method returns a new list made of the caller's ``messages`` followed by
    step-specific context messages and a final message describing the JSON shape
    expected back. The input list is never mutated.
    """

    def recursive_decompose_prompt(
        self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """
        Prompt for main decomposition.
        Encourages step-by-step reasoning and listing sub-questions as JSON.

        Args:
            messages: Conversation so far.
            question: The question to decompose.
            sub_answers: ``(question, answer)`` pairs resolved so far; may be empty.
        """
        decompose_instructions = (
            "First, restate the main question.\n"
            "Decide if sub-questions are needed. If so, list them.\n"
            "In the 'thought' field, show your chain-of-thought.\n"
            "Return valid JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )

        prompt: list[BaseMessage] = messages + [HumanMessage(content=f"Main question:\n{question}")]
        # Only add the sub-answer message when there is something to report;
        # otherwise an empty-content message would be sent.
        if sub_answers:
            content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
            prompt.append(AIMessage(content=content_sub_answers))
        prompt.append(AIMessage(content=decompose_instructions))
        return prompt

    def label_prompt(
        self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
    ) -> list[BaseMessage]:
        """
        Prompt for refining the sub-questions and dependencies.

        Args:
            messages: Conversation so far.
            question: The question that was decomposed.
            decompose_response: Decomposition whose ``sub_questions`` are shown to the model.
        """
        label_instructions = (
            "Review each sub-question to ensure correctness and proper ordering.\n"
            "Return valid JSON in the form:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": "...", "depend": [...]},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Question: {question}"),
            AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
            AIMessage(content=label_instructions),
        ]

    def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
        """
        Prompt for merging sub-answers into one self-contained question.

        Args:
            messages: Conversation so far.
            sub_answers: ``(question, answer)`` pairs to merge; may be empty.
        """
        contract_instructions = (
            "Please merge sub-answers into a single short question that is fully self-contained.\n"
            "In 'thought', show how you unify the information.\n"
            "Then produce JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "question": "a short but self-contained question"\n'
            "}\n"
        )
        prompt: list[BaseMessage] = messages + [AIMessage(content="We have these sub-questions and answers:")]
        # Skip the listing message when there are no pairs, to avoid an empty-content message.
        if sub_answers:
            sub_q_content = "\n".join(QA_TEMPLATE.format(question=q, answer=a) for q, a in sub_answers)
            prompt.append(AIMessage(content=sub_q_content))
        prompt.append(AIMessage(content=contract_instructions))
        return prompt

    def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
        """
        Prompt for directly answering the contracted question thoroughly.

        Args:
            messages: Conversation so far.
            contracted_question: The simplified question to answer.
        """
        direct_instructions = (
            "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "..."\n'
            "}\n"
        )
        return messages + [
            HumanMessage(content=f"Simplified question: {contracted_question}"),
            AIMessage(content=direct_instructions),
        ]

    def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
        """
        Prompt for self-critique.

        Args:
            messages: Conversation so far.
            thought: The reasoning to be critiqued.
            answer: The answer to be critiqued.
        """
        critique_instructions = (
            "Critique your own approach. Identify possible errors or leaps.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "self_assessment": <float in [0,1]>\n'
            "}\n"
        )
        return messages + [
            AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
            AIMessage(content=f"Your previous ANSWER:\n{answer}"),
            AIMessage(content=critique_instructions),
        ]

    def ensemble_prompt(
        self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
    ) -> list[BaseMessage]:
        """
        Show multiple candidate solutions and pick the best final answer with confidence.

        Args:
            messages: Conversation so far.
            possible_thought_and_answers: ``(thought, answer)`` pairs, one per candidate.
        """
        instructions = (
            "You have multiple candidate solutions. Compare carefully and pick the best.\n"
            "Return JSON:\n"
            "{\n"
            ' "thought": "why you chose this final answer",\n'
            ' "answer": "the best consolidated answer",\n'
            ' "confidence": 0.0 ~ 1.0\n'
            "}\n"
        )
        reasonings: list[BaseMessage] = [
            AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---")
            for idx, (thought, ans) in enumerate(possible_thought_and_answers)
        ]
        return messages + reasonings + [AIMessage(content=instructions)]

    def devils_advocate_prompt(
        self, messages: list[BaseMessage], question: str, existing_answer: str
    ) -> list[BaseMessage]:
        """
        Prompt for a devil's advocate approach to contradict or provide an alternative viewpoint.

        Args:
            messages: Conversation so far.
            question: The question under discussion.
            existing_answer: The answer the model is asked to challenge.
        """
        instructions = (
            "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
            "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
            "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
            "But here, let's keep it in a new dedicated structure:\n"
            "{\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [\n'
            ' {"question": "...", "answer": null, "depend": []},\n'
            " ...\n"
            " ]\n"
            "}\n"
        )
        return messages + [
            AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
            AIMessage(content=instructions),
        ]
251
+
252
+
253
+ # ---------------------------------------------------------------------------------
254
+ # 2) Strict Typed Steps for Pipeline
255
+ # ---------------------------------------------------------------------------------
256
+
257
+
258
+ class StepName(StrEnum):
259
+ """Enum for step names in the pipeline."""
260
+
261
+ DOMAIN_DETECTION = "DomainDetection"
262
+ DECOMPOSITION = "Decomposition"
263
+ DECOMPOSITION_CRITIQUE = "DecompositionCritique"
264
+ CONTRACTED_QUESTION = "ContractedQuestion"
265
+ CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
266
+ CONTRACT_CRITIQUE = "ContractCritique"
267
+ BEST_APPROACH_DECISION = "BestApproachDecision"
268
+ ENSEMBLE = "Ensemble"
269
+ FINAL_ANSWER = "FinalAnswer"
270
+
271
+ DEVILS_ADVOCATE = "DevilsAdvocate"
272
+ DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
273
+
274
+
275
+ class StepRelation(StrEnum):
276
+ """Enum for relationship types in the reasoning graph."""
277
+
278
+ CRITIQUES = "CRITIQUES"
279
+ SELECTS = "SELECTS"
280
+ RESULT_OF = "RESULT_OF"
281
+ SPLIT_INTO = "SPLIT_INTO"
282
+ DEPEND_ON = "DEPEND_ON"
283
+ PRECEDES = "PRECEDES"
284
+ DECOMPOSED_BY = "DECOMPOSED_BY"
285
+
286
+
287
class StepRecord(BaseModel):
    """A typed record for each pipeline step."""

    step_name: StepName
    domain: Optional[str] = None
    score: Optional[float] = None
    used: Optional[StepName] = None
    sub_questions: Optional[list[SubQuestionNode]] = None
    parent_decomp_step_idx: Optional[int] = None
    parent_subq_idx: Optional[int] = None
    question: Optional[str] = None
    thought: Optional[str] = None
    answer: Optional[str] = None

    def as_properties(self) -> dict[str, str | float | int | None]:
        """Return the populated score/domain/question/thought/answer values as a dict.

        ``score`` is included whenever it is not ``None`` (so ``0.0`` is kept);
        the text fields are included only when non-empty.
        """
        props: dict[str, str | float | int | None] = {}
        if self.score is not None:
            props["score"] = self.score
        for key in ("domain", "question", "thought", "answer"):
            value = getattr(self, key)
            if value:
                props[key] = value
        return props
315
+
316
+
317
+ # ---------------------------------------------------------------------------------
318
+ # 3) Logging Setup
319
+ # ---------------------------------------------------------------------------------
320
+
321
+
322
class SimpleColorFormatter(logging.Formatter):
    """Formatter that wraps each formatted record in an ANSI color chosen by level."""

    # ANSI escape sequences
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    RESET = "\033[0m"
    # Level number -> color prefix; levels not listed fall back to RESET.
    LEVEL_COLORS = {
        logging.DEBUG: BLUE,
        logging.INFO: GREEN,
        logging.WARNING: YELLOW,
        logging.ERROR: RED,
        logging.CRITICAL: RED,
    }

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record`` with the base formatter, then add color prefix and reset suffix."""
        rendered = super().format(record)
        prefix = self.LEVEL_COLORS.get(record.levelno, self.RESET)
        return "".join((prefix, rendered, self.RESET))
342
+
343
+
344
# Console handler that renders records through the color formatter.
handler = logging.StreamHandler()
handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))

# The "AoT" logger emits INFO and above through that handler only; its handler
# list is replaced outright and records are not propagated to ancestor loggers.
logger = logging.getLogger("AoT")
logger.setLevel(logging.INFO)
logger.handlers = [handler]
logger.propagate = False
350
+
351
+
352
+ # ---------------------------------------------------------------------------------
353
+ # 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
354
+ # ---------------------------------------------------------------------------------
355
+
356
# Type variable for `_ainvoke_pydantic`: bound to the union of the response models.
T = TypeVar(
    "T",
    bound=(
        EnsembleResponse
        | ContractQuestionResponse
        | LabelResponse
        | CritiqueResponse
        | RecursiveDecomposeResponse
        | DevilsAdvocateResponse
    ),
)
365
+
366
+
367
+ @dataclass
368
+ class AoTPipeline:
369
+ """
370
+ The pipeline orchestrates:
371
+ 1) Recursive decomposition
372
+ 2) For each sub-question, it tries a main approach + a devil's advocate approach
373
+ 3) Merges sub-answers using an ensemble
374
+ 4) Contracts the question
375
+ 5) Possibly does a direct approach on the contracted question
376
+ 6) Ensembling the final answers
377
+ """
378
+
379
+ chatterer: Chatterer
380
+ max_depth: int = 2
381
+ max_retries: int = 2
382
+ steps_history: list[StepRecord] = field(default_factory=list[StepRecord])
383
+ prompter: AoTPrompter = field(default_factory=AoTPrompter)
384
+
385
+ # 4.1) Utility for calling the LLM with Pydantic parsing
386
async def _ainvoke_pydantic(
    self,
    messages: list[BaseMessage],
    model_cls: Type[T],
    fallback: str = "<None>",
) -> T:
    """
    Ask the LLM for a ``model_cls`` instance, retrying on validation failure.

    Up to ``max_retries`` attempts are made (always at least one, even when
    ``max_retries`` is configured as 0 or less). Only ``ValidationError`` is
    retried; any other exception propagates immediately.

    Args:
        messages: Prompt messages to send.
        model_cls: Response model class to request.
        fallback: Text used for every string field of the placeholder response.

    Returns:
        The parsed response, or a placeholder instance of ``model_cls`` filled
        with ``fallback`` text, zero scores and no sub-questions if every
        attempt failed validation.
    """
    attempts = max(1, self.max_retries)
    for attempt in range(1, attempts + 1):
        try:
            return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
        except ValidationError as e:
            logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")

    # Every attempt failed validation: build a placeholder matching the model's fields.
    if issubclass(model_cls, EnsembleResponse):
        return model_cls(thought=fallback, answer=fallback, confidence=0.0)  # type: ignore
    if issubclass(model_cls, ContractQuestionResponse):
        return model_cls(thought=fallback, question=fallback)  # type: ignore
    if issubclass(model_cls, LabelResponse):
        return model_cls(thought=fallback, sub_questions=[])  # type: ignore
    if issubclass(model_cls, CritiqueResponse):
        return model_cls(thought=fallback, self_assessment=0.0)  # type: ignore
    # DevilsAdvocateResponse and RecursiveDecomposeResponse share the same field set.
    return model_cls(thought=fallback, final_answer=fallback, sub_questions=[])  # type: ignore
416
+
417
+ # 4.2) Helper method for self-critique
418
async def _ainvoke_critique(
    self,
    messages: list[BaseMessage],
    thought: str,
    answer: str,
) -> CritiqueResponse:
    """
    Build the critique prompt for ``thought``/``answer`` and parse the reply
    as a CritiqueResponse.
    """
    critique_messages = self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer)
    return await self._ainvoke_pydantic(messages=critique_messages, model_cls=CritiqueResponse)
431
+
432
+ # 4.3) Helper method for devil's advocate approach
433
async def _ainvoke_devils_advocate(
    self,
    messages: list[BaseMessage],
    question: str,
    existing_answer: str,
) -> DevilsAdvocateResponse:
    """
    Build the devil's advocate prompt challenging ``existing_answer`` and parse
    the reply as a DevilsAdvocateResponse.
    """
    challenge_messages = self.prompter.devils_advocate_prompt(
        messages, question=question, existing_answer=existing_answer
    )
    return await self._ainvoke_pydantic(messages=challenge_messages, model_cls=DevilsAdvocateResponse)
446
+
447
+ # 4.4) The main function that recursively decomposes a question and calls sub-steps
448
async def _arecursive_decompose_question(
    self,
    messages: list[BaseMessage],
    question: str,
    depth: int,
    parent_decomp_step_idx: Optional[int] = None,
    parent_subq_idx: Optional[int] = None,
) -> RecursiveDecomposeResponse:
    """
    Decompose ``question``, then (while depth remains) resolve its sub-questions
    and ask again for a final answer that takes the sub-answers into account.

    A negative ``depth`` short-circuits with the UNKNOWN placeholder. Otherwise a
    DECOMPOSITION step is recorded in ``steps_history``; when sub-questions get
    resolved, that record's answer and sub-questions are updated afterwards.
    """
    if depth < 0:
        logger.info("Max depth reached, returning unknown.")
        return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])

    # First pass: decompose with no sub-answers available yet.
    first_pass_prompt = self.prompter.recursive_decompose_prompt(
        messages=messages, question=question, sub_answers=[]
    )
    decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        messages=first_pass_prompt,
        model_cls=RecursiveDecomposeResponse,
    )

    # Let the model refine ordering/dependencies of whatever sub-questions it proposed.
    if decompose_resp.sub_questions:
        labelled: LabelResponse = await self._ainvoke_pydantic(
            messages=self.prompter.label_prompt(messages, question, decompose_resp),
            model_cls=LabelResponse,
        )
        decompose_resp.sub_questions = labelled.sub_questions

    # Record this decomposition before descending so children can reference its index.
    current_decomp_step_idx = self._record_decomposition_step(
        question=question,
        final_answer=decompose_resp.final_answer,
        sub_questions=decompose_resp.sub_questions,
        parent_decomp_step_idx=parent_decomp_step_idx,
        parent_subq_idx=parent_subq_idx,
    )

    # Nothing more to do when out of depth or when there are no sub-questions.
    if depth <= 0 or not decompose_resp.sub_questions:
        return decompose_resp

    solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
        messages=messages,
        sub_questions=decompose_resp.sub_questions,
        depth=depth,
        parent_decomp_step_idx=current_decomp_step_idx,
    )

    # Second pass: same prompt, now carrying the resolved sub-answers.
    resolved_pairs = [(sq.question, sq.answer or UNKNOWN) for sq in solved_subs]
    refined_prompt = self.prompter.recursive_decompose_prompt(
        messages=messages,
        question=question,
        sub_answers=resolved_pairs,
    )
    refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        refined_prompt, RecursiveDecomposeResponse
    )
    decompose_resp.final_answer = refined_resp.final_answer
    decompose_resp.sub_questions = solved_subs

    # Keep the earlier history record in sync with the refined result.
    history_entry = self.steps_history[current_decomp_step_idx]
    history_entry.answer = refined_resp.final_answer
    history_entry.sub_questions = solved_subs

    return decompose_resp
515
+
516
def _record_decomposition_step(
    self,
    question: str,
    final_answer: str,
    sub_questions: list[SubQuestionNode],
    parent_decomp_step_idx: Optional[int],
    parent_subq_idx: Optional[int],
) -> int:
    """
    Append a DECOMPOSITION record to ``steps_history`` and return its index.
    """
    record = StepRecord(
        step_name=StepName.DECOMPOSITION,
        question=question,
        answer=final_answer,
        sub_questions=sub_questions,
        parent_decomp_step_idx=parent_decomp_step_idx,
        parent_subq_idx=parent_subq_idx,
    )
    # The new record lands at the current end of the list.
    new_index = len(self.steps_history)
    self.steps_history.append(record)
    return new_index
537
+
538
async def _aresolve_sub_questions(
    self,
    messages: list[BaseMessage],
    sub_questions: list[SubQuestionNode],
    depth: int,
    parent_decomp_step_idx: Optional[int],
) -> list[SubQuestionNode]:
    """
    Resolve every sub-question and return them in their original order.

    For each sub-question:
      1) Recursively decompose (main approach).
      2) Acquire a devil's advocate alternative.
      3) Ensemble the two and store the result as the sub-question's answer.
      4) Record the devil's advocate step and a critique of it.

    Tasks are created in topological order of ``depend`` but awaited together
    with ``asyncio.gather``, so they run concurrently and a sub-question does not
    wait for (or see) the answers of the ones it depends on. Sub-questions caught
    in a dependency cycle (including self-dependencies) are still resolved; they
    are scheduled after the acyclic ones.

    Args:
        messages: Conversation so far, passed to every LLM call.
        sub_questions: Sub-questions to resolve; each ``answer`` is updated in place.
        depth: Current depth budget; children are decomposed with ``depth - 1``.
        parent_decomp_step_idx: History index of the decomposition that produced them.

    Returns:
        The same sub-question objects, indexed as in ``sub_questions``.
    """
    n = len(sub_questions)
    in_degree = [0] * n
    dependents: list[list[int]] = [[] for _ in range(n)]
    for i, sq in enumerate(sub_questions):
        for dep in sq.depend:
            # Out-of-range dependency indices are ignored.
            if 0 <= dep < n:
                in_degree[i] += 1
                dependents[dep].append(i)

    # Kahn's algorithm. `topo_order` doubles as the FIFO queue (read via `cursor`),
    # which avoids O(n) pops from the front of a list.
    topo_order: list[int] = [i for i in range(n) if in_degree[i] == 0]
    cursor = 0
    while cursor < len(topo_order):
        node = topo_order[cursor]
        cursor += 1
        for nxt in dependents[node]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                topo_order.append(nxt)

    # Nodes on a dependency cycle never reach in-degree 0. Schedule them anyway so
    # every index gets an answer and the final lookup below cannot fail.
    if len(topo_order) < n:
        ordered = set(topo_order)
        cyclic = [i for i in range(n) if i not in ordered]
        logger.warning(f"Cyclic sub-question dependencies detected for indices {cyclic}; resolving them anyway.")
        topo_order.extend(cyclic)

    # Resolved sub-questions keyed by their original index.
    final_subs: dict[int, SubQuestionNode] = {}

    async def _resolve_one_subq(idx: int):
        sq = sub_questions[idx]
        # 1) Main approach
        main_resp = await self._arecursive_decompose_question(
            messages=messages,
            question=sq.question,
            depth=depth - 1,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=idx,
        )
        main_answer = main_resp.final_answer

        # 2) Devil's advocate approach
        devils_resp = await self._ainvoke_devils_advocate(
            messages=messages, question=sq.question, existing_answer=main_answer
        )

        # 3) Ensemble to combine the main answer with the devil's advocate alternative
        ensemble_sub = await self._ainvoke_pydantic(
            self.prompter.ensemble_prompt(
                messages=messages,
                possible_thought_and_answers=[
                    (main_resp.thought, main_answer),
                    (devils_resp.thought, devils_resp.final_answer),
                ],
            ),
            EnsembleResponse,
        )

        # Store the final sub-question answer
        sq.answer = ensemble_sub.answer
        final_subs[idx] = sq

        # 4) Record the devil's advocate step, then critique it
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE,
                question=sq.question,
                answer=devils_resp.final_answer,
                thought=devils_resp.thought,
                sub_questions=devils_resp.sub_questions,
            )
        )
        dev_adv_crit = await self._ainvoke_critique(
            messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
        )
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
                thought=dev_adv_crit.thought,
                score=dev_adv_crit.self_assessment,
            )
        )

    # Start tasks in topological order; they run concurrently under gather.
    tasks = [_resolve_one_subq(i) for i in topo_order]
    await asyncio.gather(*tasks, return_exceptions=False)

    return [final_subs[i] for i in range(n)]
638
+
639
+ # 4.5) The primary pipeline method
640
async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
    """
    Answer the question carried by the last message and return the answer text.

    ``self.steps_history`` is cleared first; the stages below then append
    their ``StepRecord`` entries to it in this order:

    1) Recursively decompose the question, starting at ``self.max_depth``.
    2) Critique the decomposition result.
    3) Ask for a devil's-advocate answer to the whole question, then critique it.
    4) Contract the question using the sub-answers found on the most recent
       top-level DECOMPOSITION record (``UNKNOWN`` replaces a missing answer;
       an empty list is used when no such record exists).
    5) Answer the contracted question directly, then critique that answer.
    6) Ensemble the three candidates: decomposition, devil's advocate and
       contracted direct answer.
    7) Record the ensemble's answer as the final answer and return it.
    """

    def log_step(**fields):
        # Single place where this method appends to the history.
        self.steps_history.append(StepRecord(**fields))

    async def critique_and_log(step_name, thought, answer):
        # Critique one (thought, answer) pair and store the verdict under step_name.
        verdict = await self._ainvoke_critique(messages=messages, thought=thought, answer=answer)
        log_step(step_name=step_name, thought=verdict.thought, score=verdict.self_assessment)

    self.steps_history.clear()
    original_question: str = messages[-1].text()

    # 1) Recursive decomposition of the original question
    decomp_resp = await self._arecursive_decompose_question(
        messages=messages, question=original_question, depth=self.max_depth
    )
    logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")

    # 2) Self-critique of that decomposition
    await critique_and_log(StepName.DECOMPOSITION_CRITIQUE, decomp_resp.thought, decomp_resp.final_answer)

    # 3) Devil's advocate against the whole main answer, followed by its critique
    opposing = await self._ainvoke_devils_advocate(
        messages=messages, question=original_question, existing_answer=decomp_resp.final_answer
    )
    log_step(
        step_name=StepName.DEVILS_ADVOCATE,
        question=original_question,
        answer=opposing.final_answer,
        thought=opposing.thought,
        sub_questions=opposing.sub_questions,
    )
    await critique_and_log(StepName.DEVILS_ADVOCATE_CRITIQUE, opposing.thought, opposing.final_answer)

    # 4) Contract the question from the latest top-level decomposition record
    top_level: Optional[StepRecord] = None
    for past_step in reversed(self.steps_history):
        if past_step.step_name == StepName.DECOMPOSITION and past_step.parent_decomp_step_idx is None:
            top_level = past_step
            break
    sub_answers = (
        [(sq.question, sq.answer or UNKNOWN) for sq in top_level.sub_questions]
        if top_level and top_level.sub_questions
        else []
    )

    contraction_prompt = self.prompter.contract_prompt(messages, sub_answers)
    contraction = await self._ainvoke_pydantic(messages=contraction_prompt, model_cls=ContractQuestionResponse)
    contracted_question = contraction.question
    log_step(step_name=StepName.CONTRACTED_QUESTION, question=contracted_question, thought=contraction.thought)

    # 5) Direct attempt at the contracted question
    direct_prompt = self.prompter.contract_direct_prompt(messages, contracted_question)
    contracted_direct = await self._ainvoke_pydantic(
        messages=direct_prompt,
        model_cls=RecursiveDecomposeResponse,
        fallback="No Contracted Direct Answer",
    )
    log_step(
        step_name=StepName.CONTRACTED_DIRECT_ANSWER,
        answer=contracted_direct.final_answer,
        thought=contracted_direct.thought,
    )
    logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")

    # 5.1) Critique of the contracted direct attempt
    await critique_and_log(StepName.CONTRACT_CRITIQUE, contracted_direct.thought, contracted_direct.final_answer)

    # 6) Ensemble over the three candidate (thought, answer) pairs
    candidates = [
        (decomp_resp.thought, decomp_resp.final_answer),
        (opposing.thought, opposing.final_answer),
        (contracted_direct.thought, contracted_direct.final_answer),
    ]
    ensemble_resp = await self._ainvoke_pydantic(
        self.prompter.ensemble_prompt(messages=messages, possible_thought_and_answers=candidates),
        EnsembleResponse,
    )
    best_approach_answer = ensemble_resp.answer
    approach_used = StepName.ENSEMBLE
    log_step(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used)
    logger.info(f"[Best Approach Decision] => {approach_used}")

    # 7) Final answer, scored with the ensemble's confidence
    log_step(step_name=StepName.FINAL_ANSWER, answer=best_approach_answer, score=ensemble_resp.confidence)
    logger.info(f"[Final Answer] => {best_approach_answer}")

    return best_approach_answer
775
+
776
def run_pipeline(self, messages: list[BaseMessage]) -> str:
    """
    Synchronous wrapper around ``arun_pipeline``.

    ``asyncio.run`` cannot be used while the calling thread already has a
    running event loop (for example inside a notebook cell or another
    coroutine). When that is the case, the pipeline is run on a one-off
    worker thread that owns its own event loop, and this call blocks until
    the worker finishes.

    Args:
        messages: Conversation to answer; passed unchanged to ``arun_pipeline``.

    Returns:
        Whatever ``arun_pipeline`` returns for ``messages``.

    Raises:
        Any exception raised by ``arun_pipeline`` is propagated to the caller.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if not loop_is_running:
        # Ordinary case: no loop in this thread, so asyncio.run can own one.
        return asyncio.run(self.arun_pipeline(messages))

    # Local import: only this fallback path needs the executor.
    from concurrent.futures import ThreadPoolExecutor

    # A fresh thread has no running loop, so asyncio.run is legal there.
    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, self.arun_pipeline(messages)).result()
779
+
780
+ # ---------------------------------------------------------------------------------
781
+ # 4.6) Build or export a reasoning graph
782
+ # ---------------------------------------------------------------------------------
783
+
784
def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
    """
    Build a ``neo4j_extension`` Graph from ``self.steps_history``.

    Nodes:
      * one node per recorded step, labelled with the step name;
      * one ``SubQuestion`` node per sub-question attached to a step.

    Relationships, added in this order:
      * SPLIT_INTO  step -> each of its sub-questions, immediately followed by
        DEPEND_ON sub-question -> sibling for every in-range index in ``depend``;
      * PRECEDES    step i -> step i + 1;
      * CRITIQUES   any step whose name contains "CRITIQUE" -> the step before it;
      * SELECTS     BEST_APPROACH_DECISION -> first step named by its ``used`` field;
      * RESULT_OF   first FINAL_ANSWER -> last BEST_APPROACH_DECISION.

    Every global id starts with ``global_id_prefix``.
    """
    from neo4j_extension import Graph, Node, Relationship

    graph = Graph()
    history = self.steps_history

    def subq_id(i, sq_idx):
        # Global id of sub-question ``sq_idx`` belonging to step ``i``.
        return f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"

    def connect(rel_type, start_node, end_node, globalId):
        # Add one property-less relationship to the graph.
        graph.add_relationship(
            Relationship(
                properties={},
                rel_type=rel_type,
                start_node=start_node,
                end_node=end_node,
                globalId=globalId,
            )
        )

    # Step A: one node per pipeline step (nested decomposition steps included).
    step_nodes: list[Node] = []
    for i, record in enumerate(history):
        node = Node(
            properties=record.as_properties(), labels={record.step_name}, globalId=f"{global_id_prefix}_step_{i}"
        )
        graph.add_node(node)
        step_nodes.append(node)

    # Step B: one node per sub-question. All of them are created before any
    # relationship so that a dependency may point at a later sibling.
    subq_nodes: dict[str, Node] = {}
    for i, record in enumerate(history):
        for sq_idx, sq in enumerate(record.sub_questions or ()):
            sq_id = subq_id(i, sq_idx)
            sq_node = Node(
                properties={"question": sq.question, "answer": sq.answer or ""},
                labels={"SubQuestion"},
                globalId=sq_id,
            )
            graph.add_node(sq_node)
            subq_nodes[sq_id] = sq_node

    # Step C: SPLIT_INTO from the owning step, then DEPEND_ON between siblings.
    for i, record in enumerate(history):
        siblings = record.sub_questions
        if not siblings:
            continue
        for sq_idx, sq in enumerate(siblings):
            sq_node = subq_nodes[subq_id(i, sq_idx)]
            connect(StepRelation.SPLIT_INTO, step_nodes[i], sq_node, f"{global_id_prefix}_split_{i}_{sq_idx}")
            for dep in sq.depend:
                # Dependencies are indices into the same step's sub-question list.
                if not 0 <= dep < len(siblings):
                    continue
                dep_node = subq_nodes.get(subq_id(i, dep))
                if dep_node is None:
                    continue
                connect(
                    StepRelation.DEPEND_ON,
                    sq_node,
                    dep_node,
                    f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
                )

    # Step D: linear PRECEDES chain through the steps in recorded order.
    for i in range(len(history) - 1):
        connect(
            StepRelation.PRECEDES,
            step_nodes[i],
            step_nodes[i + 1],
            f"{global_id_prefix}_precede_{i}_to_{i + 1}",
        )

    # Step E: a critique step points at whatever step was recorded just before it.
    for i, record in enumerate(history):
        if "CRITIQUE" not in record.step_name or i == 0:
            continue
        connect(StepRelation.CRITIQUES, step_nodes[i], step_nodes[i - 1], f"{global_id_prefix}_crit_{i}")

    # A BEST_APPROACH_DECISION step selects the first step carrying the name it used.
    best_decision_idx = None
    for i, record in enumerate(history):
        if not (record.step_name == StepName.BEST_APPROACH_DECISION and record.used):
            continue
        best_decision_idx = i
        used_step_idx = next(
            (j for j, candidate in enumerate(history) if candidate.step_name == record.used), None
        )
        if used_step_idx is not None:
            connect(
                StepRelation.SELECTS,
                step_nodes[i],
                step_nodes[used_step_idx],
                f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
            )

    # The first FINAL_ANSWER step is the result of the last decision step seen.
    final_answer_idx = None
    for i, record in enumerate(history):
        if record.step_name == StepName.FINAL_ANSWER:
            final_answer_idx = i
            break
    if final_answer_idx is not None and best_decision_idx is not None:
        connect(
            StepRelation.RESULT_OF,
            step_nodes[final_answer_idx],
            step_nodes[best_decision_idx],
            f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
        )

    return graph
921
+
922
+
923
+ # ---------------------------------------------------------------------------------
924
+ # 5) AoTStrategy class that uses the pipeline
925
+ # ---------------------------------------------------------------------------------
926
+
927
+
928
@dataclass
class AoTStrategy(BaseStrategy):
    """
    Strategy that answers by delegating to an ``AoTPipeline`` and can expose
    the pipeline's recorded steps as a reasoning graph.
    """

    # Pipeline that performs the actual work and keeps the steps history.
    pipeline: AoTPipeline

    def _as_messages(self, messages: LanguageModelInput):
        """Convert the caller's input to messages using the pipeline's chat client."""
        client = self.pipeline.chatterer.client
        return client._convert_input(messages).to_messages()  # type: ignore

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """Convert ``messages`` and await the pipeline's ``arun_pipeline`` on them."""
        converted = self._as_messages(messages)
        return await self.pipeline.arun_pipeline(converted)

    def invoke(self, messages: LanguageModelInput) -> str:
        """Convert ``messages`` and run the pipeline's blocking ``run_pipeline`` on them."""
        converted = self._as_messages(messages)
        return self.pipeline.run_pipeline(converted)

    def get_reasoning_graph(self):
        """Return the pipeline's reasoning graph, built with the "AoT" global-id prefix."""
        return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
950
+
951
+
952
+ # ---------------------------------------------------------------------------------
953
+ # Example usage (pseudo-code)
954
+ # ---------------------------------------------------------------------------------
955
if __name__ == "__main__":
    from neo4j_extension import Neo4jConnection  # or your actual DB connector

    # Wire the strategy from the inside out: LLM backend -> pipeline -> strategy.
    # Chatterer.openai() is pseudo-code; substitute your chosen LLM backend.
    strategy = AoTStrategy(pipeline=AoTPipeline(chatterer=Chatterer.openai(), max_depth=3))

    # Ask one question and show the answer.
    answer = strategy.invoke("Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9.")
    print("Final Answer:", answer)

    # Turn the recorded steps into a graph and report its size.
    graph = strategy.get_reasoning_graph()
    print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")

    # Optional: replace whatever is in Neo4j with this graph.
    with Neo4jConnection() as db:
        db.clear_all()
        db.upsert_graph(graph)
        print("Graph stored in Neo4j.")