chatterer 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterer/__init__.py +21 -23
- chatterer/language_model.py +590 -0
- chatterer/strategies/__init__.py +19 -0
- chatterer/strategies/atom_of_thoughts.py +594 -0
- chatterer/strategies/base.py +14 -0
- chatterer-0.1.3.dist-info/METADATA +150 -0
- chatterer-0.1.3.dist-info/RECORD +9 -0
- chatterer/llms/__init__.py +0 -20
- chatterer/llms/base.py +0 -42
- chatterer/llms/instructor.py +0 -127
- chatterer/llms/langchain.py +0 -49
- chatterer/llms/ollama.py +0 -69
- chatterer-0.1.2.dist-info/METADATA +0 -213
- chatterer-0.1.2.dist-info/RECORD +0 -10
- {chatterer-0.1.2.dist-info → chatterer-0.1.3.dist-info}/WHEEL +0 -0
- {chatterer-0.1.2.dist-info → chatterer-0.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,594 @@
|
|
1
|
+
# original source: https://github.com/qixucen/atom
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import logging
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from dataclasses import dataclass, field
|
9
|
+
from enum import StrEnum
|
10
|
+
from typing import LiteralString, Optional, Type, TypeVar
|
11
|
+
|
12
|
+
from pydantic import BaseModel, Field, ValidationError
|
13
|
+
|
14
|
+
from ..language_model import Chatterer, LanguageModelInput
|
15
|
+
from .base import BaseStrategy
|
16
|
+
|
17
|
+
# Module-level logger used by the pipeline and strategy classes for debug/warning output.
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
# ------------------------- 0) Enums and Basic Models -------------------------
|
20
|
+
|
21
|
+
|
22
|
+
class Domain(StrEnum):
|
23
|
+
"""Defines the domain of a question for specialized handling."""
|
24
|
+
|
25
|
+
GENERAL = "general"
|
26
|
+
MATH = "math"
|
27
|
+
CODING = "coding"
|
28
|
+
PHILOSOPHY = "philosophy"
|
29
|
+
MULTIHOP = "multihop"
|
30
|
+
|
31
|
+
|
32
|
+
class SubQuestionNode(BaseModel):
    """
    A single sub-question node in a decomposition tree.

    ``depend`` holds 0-based indices into the sibling sub-question list; the pipeline
    ignores indices outside that list.
    """

    question: str = Field(description="A sub-question string that arises from decomposition.")
    # Default to None: in pydantic v2 `Optional[str]` without a default is still a
    # *required* field, so LLM output that omits "answer" for an unresolved node
    # would fail validation instead of meaning "not answered yet".
    answer: Optional[str] = Field(default=None, description="Answer for this sub-question, if resolved.")
    depend: list[int] = Field(default_factory=list, description="Indices of sub-questions that this node depends on.")
|
38
|
+
|
39
|
+
|
40
|
+
class RecursiveDecomposeResponse(BaseModel):
    """
    Output of one decomposition step.

    Holds the model's reasoning, its current best answer to the question being
    decomposed, and the root-level sub-questions it proposed (empty if none).
    """

    thought: str = Field(description="Reasoning about decomposition.")
    final_answer: str = Field(description="Best answer to the main question.")
    sub_questions: list[SubQuestionNode] = Field(
        description="Root-level sub-questions.",
        default_factory=list,
    )
|
46
|
+
|
47
|
+
|
48
|
+
class DirectResponse(BaseModel):
    """
    Single-shot answer to a question, without decomposition.

    Produced by the direct prompt for both the original and the contracted question.
    """

    thought: str = Field(
        description="Short reasoning.",
    )
    answer: str = Field(
        description="Direct final answer.",
    )
|
53
|
+
|
54
|
+
|
55
|
+
class ContractQuestionResponse(BaseModel):
    """
    Output of the contraction step.

    ``question`` is a self-contained rewrite of the original question that folds in
    the known sub-answers; ``thought`` explains the rewrite.
    """

    thought: str = Field(
        description="Reasoning on how the question was compressed.",
    )
    question: str = Field(
        description="New, simplified, self-contained question.",
    )
|
60
|
+
|
61
|
+
|
62
|
+
class EnsembleResponse(BaseModel):
    """The ensemble process result: chosen answer, rationale and a confidence in [0, 1]."""

    thought: str = Field(description="Explanation for choosing the final answer.")
    answer: str = Field(description="Best final answer after ensemble.")
    confidence: float = Field(description="Confidence score in [0, 1].")

    def model_post_init(self, __context) -> None:
        """Clamp ``confidence`` into [0, 1] after validation; NaN is treated as 0.0."""
        # `not (x >= 0.0)` is True for negative values *and* for NaN. A plain
        # max(0, min(1, x)) clamp would turn NaN into 1.0 (full confidence).
        if not (self.confidence >= 0.0):
            self.confidence = 0.0
        elif self.confidence > 1.0:
            self.confidence = 1.0
|
72
|
+
|
73
|
+
|
74
|
+
class LabelResponse(BaseModel):
    """
    Output of the label step, which re-checks sub-questions and their dependencies.

    Only the refined sub-question list is carried here; some task variants also keep
    a final answer, but the pipeline only consumes the sub-questions.
    """

    thought: str = Field(description="Explanation or reasoning about labeling.")
    sub_questions: list[SubQuestionNode] = Field(
        description="Refined list of sub-questions with corrected dependencies.",
        default_factory=list,
    )
|
82
|
+
|
83
|
+
|
84
|
+
# --------------- 1) Prompter Classes with multi-hop context ---------------
|
85
|
+
|
86
|
+
|
87
|
+
class BaseAoTPrompter(ABC):
    """
    Interface for building every prompt the Atom-of-Thought pipeline sends.

    Each method returns the complete user-message text for one pipeline stage.
    ``context``, when given, is optional supporting material for the question.
    """

    @abstractmethod
    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        """Prompt asking the model to split ``question`` into sub-questions, integrating ``sub_answers`` if given."""

    @abstractmethod
    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        """Prompt asking for a direct answer to ``question`` with brief reasoning."""

    @abstractmethod
    def label_prompt(
        self, question: str, decompose_response: RecursiveDecomposeResponse, context: Optional[str] = None
    ) -> str:
        """Prompt used to 're-check' the sub-questions in ``decompose_response`` for correctness and dependencies."""

    @abstractmethod
    def contract_prompt(self, question: str, sub_answers: str, context: Optional[str] = None) -> str:
        """Prompt asking the model to compress ``question`` plus ``sub_answers`` into one self-contained question."""

    @abstractmethod
    def ensemble_prompt(
        self,
        original_question: str,
        direct_answer: str,
        decompose_answer: str,
        contracted_direct_answer: str,
        context: Optional[str] = None,
    ) -> str:
        """Prompt asking the model to pick or synthesize the best of the three candidate answers."""
|
120
|
+
|
121
|
+
|
122
|
+
class GeneralAoTPrompter(BaseAoTPrompter):
    """
    Generic prompter for non-specialized or 'general' queries.

    Every prompt asks the model for a JSON object. The JSON skeletons shown to the
    model use double quotes so the examples are themselves valid JSON. ``depend``
    lists are described as 0-based indices, which is how the pipeline resolves them.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        """
        Build the decomposition prompt.

        Args:
            question: Question to decompose.
            sub_answers: Already-known sub-answers (one "sub-question: answer" per line), if any.
            context: Optional supporting material, appended under a CONTEXT heading.
        """
        sub_ans_str = f"\nSub-question answers:\n{sub_answers}" if sub_answers else ""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are a highly analytical assistant skilled in breaking down complex problems.\n"
            "Decompose the question into sub-questions recursively.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            " {\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [{"question": "...", "answer": null, "depend": []}, ...]\n'
            " }\n"
            "2. 'thought': Provide detailed reasoning.\n"
            "3. 'final_answer': Integrate sub-answers if any.\n"
            "4. 'sub_questions': Key sub-questions with potential dependencies; "
            "'depend' lists the 0-based indices (within 'sub_questions') of the sub-questions that must be answered first.\n\n"
            f"QUESTION:\n{question}{sub_ans_str}{context_str}"
        )

    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        """Build the prompt asking for a direct answer plus brief reasoning."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are a concise and insightful assistant.\n"
            "Provide a direct answer with a short reasoning.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            ' {"thought": "...", "answer": "..."}\n'
            "2. 'thought': Offer a brief reasoning.\n"
            "3. 'answer': Deliver a precise solution.\n\n"
            f"QUESTION:\n{question}{context_str}"
        )

    def label_prompt(
        self, question: str, decompose_response: RecursiveDecomposeResponse, context: Optional[str] = None
    ) -> str:
        """Build the prompt that re-checks the sub-questions (and their dependencies) of ``decompose_response``."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You have a set of sub-questions from a decomposition process.\n"
            "We want to correct or refine the dependencies between sub-questions.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            " {\n"
            ' "thought": "...",\n'
            ' "sub_questions": [\n'
            ' {"question":"...", "answer":"...", "depend":[...]},\n'
            " ...\n"
            " ]\n"
            " }\n"
            "2. 'thought': Provide reasoning about any changes.\n"
            "3. 'sub_questions': Possibly updated sub-questions with correct 'depend' lists "
            "(0-based indices into 'sub_questions').\n\n"
            f"ORIGINAL QUESTION:\n{question}\n"
            f"CURRENT DECOMPOSITION:\n{decompose_response.model_dump_json(indent=2)}"
            f"{context_str}"
        )

    def contract_prompt(self, question: str, sub_answers: str, context: Optional[str] = None) -> str:
        """Build the prompt that folds ``sub_answers`` into a single self-contained question."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are tasked with compressing or simplifying a complex question into a single self-contained one.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            ' {"thought": "...", "question": "..."}\n'
            "2. 'thought': Explain your simplification.\n"
            "3. 'question': The streamlined question.\n\n"
            f"ORIGINAL QUESTION:\n{question}\n"
            f"SUB-ANSWERS:\n{sub_answers}"
            f"{context_str}"
        )

    def ensemble_prompt(
        self,
        original_question: str,
        direct_answer: str,
        decompose_answer: str,
        contracted_direct_answer: str,
        context: Optional[str] = None,
    ) -> str:
        """Build the prompt that chooses or synthesizes a final answer from the three candidates."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are an expert at synthesizing multiple candidate answers.\n"
            "Consider the following candidates:\n"
            f"1) Direct: {direct_answer}\n"
            f"2) Decomposition-based: {decompose_answer}\n"
            f"3) Contracted Direct: {contracted_direct_answer}\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            ' {"thought": "...", "answer": "...", "confidence": <float in [0,1]>}\n'
            "2. 'thought': Summarize how you decided.\n"
            "3. 'answer': Final best answer.\n"
            "4. 'confidence': Confidence score in [0,1].\n\n"
            f"ORIGINAL QUESTION:\n{original_question}"
            f"{context_str}"
        )
|
222
|
+
|
223
|
+
|
224
|
+
class MathAoTPrompter(GeneralAoTPrompter):
    """
    Prompter for math questions.

    Appends a rigor reminder to the decomposition and direct-answer prompts. The
    label, contract and ensemble prompts are inherited unchanged and could be
    overridden the same way if needed.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        general_prompt = super().recursive_decompose_prompt(question, sub_answers, context)
        math_hint = "\nFocus on mathematical rigor and step-by-step derivations.\n"
        return general_prompt + math_hint

    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        general_prompt = super().direct_prompt(question, context)
        math_hint = "\nEnsure mathematical correctness in your reasoning.\n"
        return general_prompt + math_hint
|
240
|
+
|
241
|
+
|
242
|
+
class CodingAoTPrompter(GeneralAoTPrompter):
    """
    Prompter for coding and algorithmic questions.

    Only the decomposition prompt is specialized; every other prompt is inherited.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        general_prompt = super().recursive_decompose_prompt(question, sub_answers, context)
        coding_hint = "\nBreak down into programming concepts or implementation steps.\n"
        return general_prompt + coding_hint
|
252
|
+
|
253
|
+
|
254
|
+
class PhilosophyAoTPrompter(GeneralAoTPrompter):
    """
    Prompter for philosophical discussions.

    Only the decomposition prompt is specialized; every other prompt is inherited.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        general_prompt = super().recursive_decompose_prompt(question, sub_answers, context)
        philosophy_hint = "\nConsider key philosophical theories and arguments.\n"
        return general_prompt + philosophy_hint
|
264
|
+
|
265
|
+
|
266
|
+
class MultiHopAoTPrompter(GeneralAoTPrompter):
    """
    Prompter for multi-hop Q&A.

    Adds instructions to ground each reasoning step in the supplied context, for both
    the decomposition and the direct-answer prompts.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        general_prompt = super().recursive_decompose_prompt(question, sub_answers, context)
        multihop_guidance = (
            "\nTreat this as a multi-hop question. Use the provided context carefully.\n"
            "Extract partial evidence from the context for each sub-question.\n"
        )
        return general_prompt + multihop_guidance

    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        general_prompt = super().direct_prompt(question, context)
        multihop_guidance = "\nFor multi-hop, ensure each piece of reasoning uses only the relevant parts of the context.\n"
        return general_prompt + multihop_guidance
|
283
|
+
|
284
|
+
|
285
|
+
# ----------------- 2) The AoTPipeline class with label + score + multi-hop -----------------
|
286
|
+
|
287
|
+
# Type variable for the pydantic response model requested in AoTPipeline._ainvoke_pydantic.
T = TypeVar("T", bound=BaseModel)
|
288
|
+
|
289
|
+
|
290
|
+
@dataclass
class AoTPipeline:
    """
    Implements an Atom-of-Thought pipeline with:
      1) Domain detection
      2) Direct solution
      3) Recursive Decomposition
      4) Label step to refine sub-questions
      5) Contract question
      6) Contracted direct solution
      7) Ensemble
      8) (Optional) Score final answer if ground-truth is available

    Attributes:
        chatterer: Client whose ``agenerate_pydantic`` performs every LLM call.
        max_depth: Recursion budget for decomposition; sub-questions are themselves
            decomposed until the depth reaches 0.
        max_retries: Attempts per structured LLM call before a placeholder response is
            returned (values below 1 are treated as 1).
        prompter_map: Prompter used for each detected domain; a domain missing from the
            map falls back to ``GeneralAoTPrompter``.
    """

    chatterer: Chatterer
    max_depth: int = 2
    max_retries: int = 2

    prompter_map: dict[Domain, BaseAoTPrompter] = field(
        default_factory=lambda: {
            Domain.GENERAL: GeneralAoTPrompter(),
            Domain.MATH: MathAoTPrompter(),
            Domain.CODING: CodingAoTPrompter(),
            Domain.PHILOSOPHY: PhilosophyAoTPrompter(),
            Domain.MULTIHOP: MultiHopAoTPrompter(),
        }
    )

    async def _ainvoke_pydantic(self, prompt: str, model_cls: Type[T], default_answer: str = "Unable to process") -> T:
        """
        Ask the LLM for a ``model_cls`` instance, retrying on validation errors.

        Args:
            prompt: User-message content sent to the model.
            model_cls: Pydantic model the response must validate against.
            default_answer: Text placed in the placeholder's answer-like field
                (``answer``, ``final_answer`` or ``question``) if every attempt fails.

        Returns:
            The validated response, or a placeholder built by ``_fallback_response``
            once all attempts have failed validation. Errors other than
            ``ValidationError`` propagate to the caller.
        """
        # Always make at least one attempt; with max_retries <= 0 the loop used to be
        # skipped entirely and the method raised RuntimeError.
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                result = await self.chatterer.agenerate_pydantic(
                    response_model=model_cls, messages=[{"role": "user", "content": prompt}]
                )
            except ValidationError as e:
                logger.warning(f"Validation error on attempt {attempt}/{attempts}: {str(e)}")
                continue
            logger.debug(f"[Validated attempt {attempt}] => {result.model_dump_json(indent=2)}")
            return result
        return self._fallback_response(model_cls, default_answer)

    @staticmethod
    def _fallback_response(model_cls: Type[T], default_answer: str) -> T:
        """Build the placeholder ``model_cls`` instance returned when every parse attempt failed."""
        if issubclass(model_cls, EnsembleResponse):
            return model_cls(thought="Failed ensemble parse", answer=default_answer, confidence=0.0)
        if issubclass(model_cls, ContractQuestionResponse):
            return model_cls(thought="Failed contract parse", question=default_answer)
        if issubclass(model_cls, LabelResponse):
            return model_cls(thought="Failed label parse", sub_questions=[])
        if issubclass(model_cls, RecursiveDecomposeResponse):
            # This model requires `final_answer` and has no `answer` field, so the
            # generic branch below would itself raise ValidationError.
            return model_cls(thought="Failed decompose parse", final_answer=default_answer, sub_questions=[])
        # DirectResponse (and any other model with `thought` + `answer`).
        return model_cls(thought="Failed parse", answer=default_answer)

    async def _adetect_domain(self, question: str, context: Optional[str] = None) -> Domain:
        """
        Ask the LLM which ``Domain`` best fits the question.

        Returns ``Domain.GENERAL`` when the classifier output fails validation.
        """

        class InferredDomain(BaseModel):
            domain: Domain

        ctx_str = f"\nCONTEXT:\n{context}" if context else ""
        domain_prompt = (
            "You are an expert domain classifier. "
            "Possible domains: [general, math, coding, philosophy, multihop].\n\n"
            'Return valid JSON: {"domain": "..."}.\n'
            f"QUESTION:\n{question}{ctx_str}"
        )
        try:
            result: InferredDomain = await self.chatterer.agenerate_pydantic(
                response_model=InferredDomain, messages=[{"role": "user", "content": domain_prompt}]
            )
            return result.domain
        except ValidationError:
            logger.warning("Failed domain detection, defaulting to general.")
            return Domain.GENERAL

    async def _arecursive_decompose_question(
        self, question: str, depth: int, prompter: BaseAoTPrompter, context: Optional[str] = None
    ) -> RecursiveDecomposeResponse:
        """
        Decompose ``question`` into sub-questions and refine them with the label step.

        When ``depth > 0`` the sub-questions are answered recursively (at ``depth - 1``)
        and the decomposition is re-run with those answers to update ``final_answer``.
        """
        if depth < 0:
            return RecursiveDecomposeResponse(thought="Max depth reached", final_answer="Unknown", sub_questions=[])

        indent: LiteralString = " " * (self.max_depth - depth)
        logger.debug(f"{indent}Decomposing at depth {self.max_depth - depth}: {question}")

        # Step 1: Base decomposition
        prompt = prompter.recursive_decompose_prompt(question, context=context)
        decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(prompt, RecursiveDecomposeResponse)

        # Step 2: Label step to refine sub-questions if any
        if decompose_resp.sub_questions:
            label_prompt: str = prompter.label_prompt(question, decompose_resp, context=context)
            label_resp: LabelResponse = await self._ainvoke_pydantic(label_prompt, LabelResponse)
            # Adopt the refined list only if labeling produced one: a failed label parse
            # yields an empty list, which would otherwise discard the decomposition.
            if label_resp.sub_questions:
                decompose_resp.sub_questions = label_resp.sub_questions

        # Step 3: If depth > 0, try to resolve sub-questions recursively
        # so we can potentially update final_answer
        if depth > 0 and decompose_resp.sub_questions:
            resolved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
                decompose_resp.sub_questions, depth, prompter, context
            )
            # Step 4: Re-invoke decomposition with known sub-answers
            sub_answers_str: str = "\n".join(f"{sq.question}: {sq.answer}" for sq in resolved_subs if sq.answer)
            if sub_answers_str:
                refine_prompt: str = prompter.recursive_decompose_prompt(question, sub_answers_str, context=context)
                refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
                    refine_prompt, RecursiveDecomposeResponse
                )
                # Use the refined final answer, keep resolved sub-questions
                decompose_resp.final_answer = refined_resp.final_answer
                decompose_resp.sub_questions = resolved_subs

        return decompose_resp

    async def _aresolve_sub_questions(
        self, sub_questions: list[SubQuestionNode], depth: int, prompter: BaseAoTPrompter, context: Optional[str] = None
    ) -> list[SubQuestionNode]:
        """
        Answer sub-questions in topological order of their ``depend`` lists.

        Sub-questions are processed level by level: every node whose dependencies are
        already answered is resolved concurrently, and each receives its prerequisites'
        answers as extra context. Nodes on (or behind) a dependency cycle are resolved
        in a final pass with whatever prerequisite answers exist. Each node's ``answer``
        is set in place from a recursive decomposition at ``depth - 1``.

        Returns:
            The resolved sub-questions, in their original order.
        """
        n: int = len(sub_questions)
        resolved: set[int] = set()

        # Build the dependency graph; out-of-range and self references are ignored.
        in_degree: list[int] = [0] * n
        dependents: list[list[int]] = [[] for _ in range(n)]
        for i, sq in enumerate(sub_questions):
            for dep in sq.depend:
                if 0 <= dep < n and dep != i:
                    in_degree[i] += 1
                    dependents[dep].append(i)

        async def resolve_single_subq(idx: int) -> None:
            sq: SubQuestionNode = sub_questions[idx]
            sub_context = self._with_dependency_answers(context, sub_questions, idx)
            sub_decomp: RecursiveDecomposeResponse = await self._arecursive_decompose_question(
                sq.question, depth - 1, prompter, sub_context
            )
            sq.answer = sub_decomp.final_answer
            resolved.add(idx)

        # Kahn's algorithm one level at a time: all dependencies of a level are already
        # answered, so the level itself can run concurrently.
        level: list[int] = [i for i in range(n) if in_degree[i] == 0]
        while level:
            await asyncio.gather(*(resolve_single_subq(i) for i in level))
            next_level: list[int] = []
            for node in level:
                for nxt in dependents[node]:
                    in_degree[nxt] -= 1
                    if in_degree[nxt] == 0:
                        next_level.append(nxt)
            level = next_level

        # Anything left is on or behind a dependency cycle; still attempt to answer it.
        leftovers: list[int] = [i for i in range(n) if i not in resolved]
        if leftovers:
            logger.warning(f"Dependency cycle among sub-questions {leftovers}; resolving them without full ordering.")
            await asyncio.gather(*(resolve_single_subq(i) for i in leftovers))

        return [sub_questions[i] for i in range(n) if i in resolved]

    @staticmethod
    def _with_dependency_answers(
        context: Optional[str], sub_questions: list[SubQuestionNode], idx: int
    ) -> Optional[str]:
        """Return ``context`` extended with the answered prerequisites of ``sub_questions[idx]``."""
        n = len(sub_questions)
        prerequisite_lines = [
            f"{sub_questions[dep].question}: {sub_questions[dep].answer}"
            for dep in dict.fromkeys(sub_questions[idx].depend)  # de-duplicate, keep order
            if 0 <= dep < n and dep != idx and sub_questions[dep].answer
        ]
        if not prerequisite_lines:
            return context
        prerequisites = "Answers to prerequisite sub-questions:\n" + "\n".join(prerequisite_lines)
        return f"{context}\n\n{prerequisites}" if context else prerequisites

    def _calculate_score(self, answer: str, ground_truth: Optional[str], domain: Domain) -> float:
        """
        Example scoring function. Real usage depends on having ground-truth.
        If ground_truth is None, returns -1.0 as 'no score possible'.

        This function is used for `post-hoc` scoring of the final answer.
        """
        if ground_truth is None:
            return -1.0

        # Very simplistic example:
        # MATH: attempt numeric equality
        if domain == Domain.MATH:
            try:
                ans_val = float(answer.strip())
                gt_val = float(ground_truth.strip())
                return 1.0 if abs(ans_val - gt_val) < 1e-9 else 0.0
            except ValueError:
                return 0.0

        # For anything else, do a naive exact-match check ignoring case
        return 1.0 if answer.strip().lower() == ground_truth.strip().lower() else 0.0

    async def arun_pipeline(
        self, question: str, context: Optional[str] = None, ground_truth: Optional[str] = None
    ) -> str:
        """
        Run the full AoT pipeline and return the ensembled final answer.

        Args:
            question: Question to answer.
            context: Optional supporting material forwarded to every prompt.
            ground_truth: If given, the final answer is scored with ``_calculate_score``
                and the score is logged at INFO level.
        """
        # 1) Domain detection
        domain = await self._adetect_domain(question, context)
        prompter = self.prompter_map.get(domain)
        if prompter is None:
            prompter = GeneralAoTPrompter()
        logger.debug(f"Detected domain: {domain}")

        # 2) Direct approach
        direct_prompt = prompter.direct_prompt(question, context)
        direct_resp = await self._ainvoke_pydantic(direct_prompt, DirectResponse)
        direct_answer = direct_resp.answer
        logger.debug(f"Direct answer => {direct_answer}")

        # 3) Recursive Decomposition + label
        decomp_resp = await self._arecursive_decompose_question(question, self.max_depth, prompter, context)
        decompose_answer = decomp_resp.final_answer
        logger.debug(f"Decomposition answer => {decompose_answer}")

        # 4) Contract question. If contraction fails (or comes back blank) keep the
        # original question instead of asking the model about a placeholder string.
        sub_answers_str = "\n".join(f"{sq.question}: {sq.answer}" for sq in decomp_resp.sub_questions if sq.answer)
        contract_prompt = prompter.contract_prompt(question, sub_answers_str, context)
        contract_resp = await self._ainvoke_pydantic(contract_prompt, ContractQuestionResponse, default_answer=question)
        contracted_question = contract_resp.question.strip() or question
        logger.debug(f"Contracted question => {contracted_question}")

        # 5) Direct approach on contracted question
        contracted_direct_prompt = prompter.direct_prompt(contracted_question, context)
        contracted_direct_resp = await self._ainvoke_pydantic(contracted_direct_prompt, DirectResponse)
        contracted_direct_answer = contracted_direct_resp.answer
        logger.debug(f"Contracted direct answer => {contracted_direct_answer}")

        # 6) Ensemble. If its output cannot be parsed, fall back to the direct answer
        # rather than discarding all earlier work.
        ensemble_prompt = prompter.ensemble_prompt(
            original_question=question,
            direct_answer=direct_answer,
            decompose_answer=decompose_answer,
            contracted_direct_answer=contracted_direct_answer,
            context=context,
        )
        ensemble_resp = await self._ainvoke_pydantic(ensemble_prompt, EnsembleResponse, default_answer=direct_answer)
        final_answer = ensemble_resp.answer
        logger.debug(f"Ensemble final answer => {final_answer} (confidence={ensemble_resp.confidence})")

        # 7) Optional scoring
        final_score = self._calculate_score(final_answer, ground_truth, domain)
        if final_score >= 0.0:
            logger.info(f"Final Score: {final_score:.3f} (domain={domain})")

        return final_answer
|
532
|
+
|
533
|
+
|
534
|
+
# ------------------ 3) AoTStrategy that uses the pipeline ------------------
|
535
|
+
|
536
|
+
|
537
|
+
@dataclass(kw_only=True)
class AoTStrategy(BaseStrategy):
    """
    Strategy that answers by running an ``AoTPipeline``.

    The incoming messages are flattened into one prompt string by the pipeline's chat
    client, and that string becomes the pipeline question (no separate context).
    """

    pipeline: AoTPipeline

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """Run the Atom-of-Thought pipeline on ``messages`` and return its final answer."""
        logger.debug(f"Invoking with messages: {messages}")

        # NOTE(review): relies on the client's private `_convert_input` helper.
        prompt_value = self.pipeline.chatterer.client._convert_input(messages)
        question = prompt_value.to_string()
        logger.debug(f"Extracted question: {question}")
        return await self.pipeline.arun_pipeline(question)

    def invoke(self, messages: LanguageModelInput) -> str:
        """Synchronous wrapper around ``ainvoke``."""
        # asyncio.run starts a fresh event loop, so this cannot be called from code
        # that is already running inside one.
        coroutine = self.ainvoke(messages)
        return asyncio.run(coroutine)
|
551
|
+
|
552
|
+
|
553
|
+
# ------------------ 4) Example usage (main) ------------------
|
554
|
+
if __name__ == "__main__":
    # Demo: run the AoT strategy on one question with colored console logging
    # and a plain-text DEBUG log file.
    from warnings import filterwarnings

    import colorama

    filterwarnings("ignore", category=UserWarning)

    colorama.init(autoreset=True)

    class ColoredFormatter(logging.Formatter):
        """Formatter that wraps each formatted record in an ANSI color chosen by its level."""

        COLORS = {
            "DEBUG": colorama.Fore.CYAN,
            "INFO": colorama.Fore.GREEN,
            "WARNING": colorama.Fore.YELLOW,
            "ERROR": colorama.Fore.RED,
            "CRITICAL": colorama.Fore.RED + colorama.Style.BRIGHT,
        }

        def format(self, record):
            levelname = record.levelname
            message = super().format(record)
            return f"{self.COLORS.get(levelname, colorama.Fore.WHITE)}{message}{colorama.Style.RESET_ALL}"

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(ColoredFormatter(log_format))
    logger.addHandler(handler)

    # The file gets the same layout without ANSI color codes; without an explicit
    # formatter it would record only the bare message (no time, logger name or level).
    file_handler = logging.FileHandler("atom_of_thoughts.log", encoding="utf-8", mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(file_handler)
    logger.propagate = False

    pipeline = AoTPipeline(chatterer=Chatterer.openai(), max_depth=2)
    strategy = AoTStrategy(pipeline=pipeline)

    question = "What would Newton discover if hit by an apple falling from 100 meters?"
    answer = strategy.invoke(question)
    print(answer)
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
|
3
|
+
from ..language_model import LanguageModelInput
|
4
|
+
|
5
|
+
|
6
|
+
class BaseStrategy(ABC):
    """Abstract interface for a reasoning strategy that turns model input into a final answer string."""

    @abstractmethod
    def invoke(self, messages: LanguageModelInput) -> str:
        """
        Run the strategy synchronously and return its answer.

        Args:
            messages: Input accepted as ``LanguageModelInput``, for instance a list of chat
                messages such as [{"role": "user", "content": "What is the meaning of life?"}].

        Returns:
            The strategy's final answer as text.
        """
|