chatterer 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ from .atom_of_thoughts import (
2
+ AoTPipeline,
3
+ AoTStrategy,
4
+ BaseAoTPrompter,
5
+ CodingAoTPrompter,
6
+ GeneralAoTPrompter,
7
+ PhilosophyAoTPrompter,
8
+ )
9
+ from .base import BaseStrategy
10
+
11
# Public names re-exported by this package; this is what `from <package> import *` exposes.
__all__ = [
    "BaseStrategy",
    "AoTPipeline",
    "BaseAoTPrompter",
    "AoTStrategy",
    "GeneralAoTPrompter",
    "CodingAoTPrompter",
    "PhilosophyAoTPrompter",
]
@@ -0,0 +1,594 @@
1
+ # original source: https://github.com/qixucen/atom
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass, field
9
+ from enum import StrEnum
10
+ from typing import LiteralString, Optional, Type, TypeVar
11
+
12
+ from pydantic import BaseModel, Field, ValidationError
13
+
14
+ from ..language_model import Chatterer, LanguageModelInput
15
+ from .base import BaseStrategy
16
+
17
# Module-level logger; the `__main__` demo at the bottom of this module attaches its own handlers to it.
logger = logging.getLogger(__name__)
18
+
19
+ # ------------------------- 0) Enums and Basic Models -------------------------
20
+
21
+
22
+ class Domain(StrEnum):
23
+ """Defines the domain of a question for specialized handling."""
24
+
25
+ GENERAL = "general"
26
+ MATH = "math"
27
+ CODING = "coding"
28
+ PHILOSOPHY = "philosophy"
29
+ MULTIHOP = "multihop"
30
+
31
+
32
class SubQuestionNode(BaseModel):
    """
    A single sub-question node in a decomposition tree.

    NOTE(review): the field descriptions presumably end up in the schema handed to
    the LLM by `agenerate_pydantic` — confirm before renaming or reordering fields.
    """

    # The sub-question text.
    question: str = Field(description="A sub-question string that arises from decomposition.")
    # No default is given, so the key must be present (its value may be null).
    # The pipeline fills this in place once the sub-question has been resolved.
    answer: Optional[str] = Field(description="Answer for this sub-question, if resolved.")
    # Indices of other entries in the same sub-question list; the resolver treats
    # them as 0-based and ignores out-of-range values.
    depend: list[int] = Field(description="Indices of sub-questions that this node depends on.")
38
+
39
+
40
class RecursiveDecomposeResponse(BaseModel):
    """
    The result of a recursive decomposition step.

    The pipeline may overwrite `sub_questions` with a refined/resolved list and
    `final_answer` with a re-derived answer after sub-questions are answered.
    """

    # Free-form reasoning about how the question was decomposed.
    thought: str = Field(description="Reasoning about decomposition.")
    # Current best answer to the question being decomposed.
    final_answer: str = Field(description="Best answer to the main question.")
    # Top-level sub-questions produced by this decomposition step.
    sub_questions: list[SubQuestionNode] = Field(description="Root-level sub-questions.")
46
+
47
+
48
class DirectResponse(BaseModel):
    """A direct (non-decomposed) response to a question."""

    # Brief reasoning behind the answer.
    thought: str = Field(description="Short reasoning.")
    # The answer text used as a candidate in the ensemble step.
    answer: str = Field(description="Direct final answer.")
53
+
54
+
55
class ContractQuestionResponse(BaseModel):
    """The result of contracting (simplifying) a question using known sub-answers."""

    # Explanation of how the question was compressed.
    thought: str = Field(description="Reasoning on how the question was compressed.")
    # The simplified question, which is then answered directly by the pipeline.
    question: str = Field(description="New, simplified, self-contained question.")
60
+
61
+
62
class EnsembleResponse(BaseModel):
    """The ensemble step's result: the chosen final answer plus a confidence score."""

    # Explanation of why this answer was chosen among the candidates.
    thought: str = Field(description="Explanation for choosing the final answer.")
    # The selected/synthesized final answer.
    answer: str = Field(description="Best final answer after ensemble.")
    # Self-reported confidence, forced into [0, 1] by model_post_init below.
    confidence: float = Field(description="Confidence score in [0, 1].")

    def model_post_init(self, __context) -> None:
        """Clamp ``confidence`` into [0, 1]; a NaN confidence is treated as 0.0."""
        import math

        # NaN compares false against everything, so `max(0.0, min(1.0, nan))`
        # would silently turn it into 1.0 (full confidence). Map it to 0.0 instead.
        if math.isnan(self.confidence):
            self.confidence = 0.0
        else:
            self.confidence = max(0.0, min(1.0, self.confidence))
72
+
73
+
74
class LabelResponse(BaseModel):
    """A response used to refine sub-question dependencies and structure."""

    # Explanation of any changes made to the sub-questions.
    thought: str = Field(description="Explanation or reasoning about labeling.")
    # The refined sub-questions; the pipeline uses them to replace the
    # decomposition's original sub-question list.
    sub_questions: list[SubQuestionNode] = Field(
        description="Refined list of sub-questions with corrected dependencies."
    )
    # Some tasks also keep the final answer, but we focus on sub-questions.
82
+
83
+
84
+ # --------------- 1) Prompter Classes with multi-hop context ---------------
85
+
86
+
87
class BaseAoTPrompter(ABC):
    """
    Interface for the prompt builders used by the Atom-of-Thought pipeline.

    Every method returns the complete prompt text for one pipeline stage.
    ``context`` is optional supporting text an implementation may append.
    """

    @abstractmethod
    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        """Build the prompt that splits ``question`` into sub-questions, optionally given known ``sub_answers``."""

    @abstractmethod
    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        """Build the prompt asking for a direct answer to ``question``."""

    @abstractmethod
    def label_prompt(
        self, question: str, decompose_response: RecursiveDecomposeResponse, context: Optional[str] = None
    ) -> str:
        """Build the prompt that re-checks a decomposition's sub-questions and their dependencies."""

    @abstractmethod
    def contract_prompt(self, question: str, sub_answers: str, context: Optional[str] = None) -> str:
        """Build the prompt that folds ``sub_answers`` into a single self-contained question."""

    @abstractmethod
    def ensemble_prompt(
        self,
        original_question: str,
        direct_answer: str,
        decompose_answer: str,
        contracted_direct_answer: str,
        context: Optional[str] = None,
    ) -> str:
        """Build the prompt that picks the best of the direct, decomposition-based and contracted answers."""
120
+
121
+
122
class GeneralAoTPrompter(BaseAoTPrompter):
    """
    Generic prompter for non-specialized or 'general' queries.

    Every prompt asks for a JSON object, and every JSON example shown to the
    model uses double quotes so that it is itself valid JSON.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        """
        Build the decomposition prompt.

        Args:
            question: Question to decompose.
            sub_answers: Optional "sub-question: answer" lines already known;
                appended so the model can integrate them into `final_answer`.
            context: Optional supporting text appended under CONTEXT.
        """
        sub_ans_str = f"\nSub-question answers:\n{sub_answers}" if sub_answers else ""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are a highly analytical assistant skilled in breaking down complex problems.\n"
            "Decompose the question into sub-questions recursively.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            " {\n"
            ' "thought": "...",\n'
            ' "final_answer": "...",\n'
            ' "sub_questions": [{"question": "...", "answer": null, "depend": []}, ...]\n'
            " }\n"
            "2. 'thought': Provide detailed reasoning.\n"
            "3. 'final_answer': Integrate sub-answers if any.\n"
            # The resolver reads `depend` as 0-based positions in `sub_questions`.
            "4. 'sub_questions': Key sub-questions with potential dependencies "
            "('depend' lists 0-based indices into 'sub_questions').\n\n"
            f"QUESTION:\n{question}{sub_ans_str}{context_str}"
        )

    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        """Build the prompt asking for a direct answer with short reasoning."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are a concise and insightful assistant.\n"
            "Provide a direct answer with a short reasoning.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            ' {"thought": "...", "answer": "..."}\n'
            "2. 'thought': Offer a brief reasoning.\n"
            "3. 'answer': Deliver a precise solution.\n\n"
            f"QUESTION:\n{question}{context_str}"
        )

    def label_prompt(
        self, question: str, decompose_response: RecursiveDecomposeResponse, context: Optional[str] = None
    ) -> str:
        """Build the prompt that refines the sub-questions (and their `depend` lists) of ``decompose_response``."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You have a set of sub-questions from a decomposition process.\n"
            "We want to correct or refine the dependencies between sub-questions.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            " {\n"
            ' "thought": "...",\n'
            ' "sub_questions": [\n'
            ' {"question":"...", "answer":"...", "depend":[...]},\n'
            " ...\n"
            " ]\n"
            " }\n"
            "2. 'thought': Provide reasoning about any changes.\n"
            "3. 'sub_questions': Possibly updated sub-questions with correct 'depend' lists "
            "(0-based indices into 'sub_questions').\n\n"
            f"ORIGINAL QUESTION:\n{question}\n"
            f"CURRENT DECOMPOSITION:\n{decompose_response.model_dump_json(indent=2)}"
            f"{context_str}"
        )

    def contract_prompt(self, question: str, sub_answers: str, context: Optional[str] = None) -> str:
        """Build the prompt that compresses ``question`` plus ``sub_answers`` into one self-contained question."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are tasked with compressing or simplifying a complex question into a single self-contained one.\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            ' {"thought": "...", "question": "..."}\n'
            "2. 'thought': Explain your simplification.\n"
            "3. 'question': The streamlined question.\n\n"
            f"ORIGINAL QUESTION:\n{question}\n"
            f"SUB-ANSWERS:\n{sub_answers}"
            f"{context_str}"
        )

    def ensemble_prompt(
        self,
        original_question: str,
        direct_answer: str,
        decompose_answer: str,
        contracted_direct_answer: str,
        context: Optional[str] = None,
    ) -> str:
        """Build the prompt that chooses the best of three candidate answers and reports a confidence in [0, 1]."""
        context_str = f"\nCONTEXT:\n{context}" if context else ""
        return (
            "You are an expert at synthesizing multiple candidate answers.\n"
            "Consider the following candidates:\n"
            f"1) Direct: {direct_answer}\n"
            f"2) Decomposition-based: {decompose_answer}\n"
            f"3) Contracted Direct: {contracted_direct_answer}\n\n"
            "REQUIREMENTS:\n"
            "1. Return valid JSON:\n"
            ' {"thought": "...", "answer": "...", "confidence": <float in [0,1]>}\n'
            "2. 'thought': Summarize how you decided.\n"
            "3. 'answer': Final best answer.\n"
            "4. 'confidence': Confidence score in [0,1].\n\n"
            f"ORIGINAL QUESTION:\n{original_question}"
            f"{context_str}"
        )
222
+
223
+
224
class MathAoTPrompter(GeneralAoTPrompter):
    """
    Prompter for math questions.

    Reuses the general prompts and appends a math-specific hint to the
    decomposition and direct prompts.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        hint = "\nFocus on mathematical rigor and step-by-step derivations.\n"
        return super().recursive_decompose_prompt(question, sub_answers, context) + hint

    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        hint = "\nEnsure mathematical correctness in your reasoning.\n"
        return super().direct_prompt(question, context) + hint

    # label_prompt, contract_prompt and ensemble_prompt are inherited unchanged;
    # override them here if math-specific wording is ever needed.
240
+
241
+
242
class CodingAoTPrompter(GeneralAoTPrompter):
    """Prompter for coding/algorithmic queries: the decomposition prompt gets an implementation-oriented hint."""

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        hint = "\nBreak down into programming concepts or implementation steps.\n"
        return super().recursive_decompose_prompt(question, sub_answers, context) + hint
252
+
253
+
254
class PhilosophyAoTPrompter(GeneralAoTPrompter):
    """Prompter for philosophical discussions: the decomposition prompt is nudged toward theories and arguments."""

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        hint = "\nConsider key philosophical theories and arguments.\n"
        return super().recursive_decompose_prompt(question, sub_answers, context) + hint
264
+
265
+
266
class MultiHopAoTPrompter(GeneralAoTPrompter):
    """
    Prompter for multi-hop Q&A.

    Appends instructions to lean on the supplied context to both the
    decomposition and the direct prompts.
    """

    def recursive_decompose_prompt(
        self, question: str, sub_answers: Optional[str] = None, context: Optional[str] = None
    ) -> str:
        suffix = (
            "\nTreat this as a multi-hop question. Use the provided context carefully.\n"
            "Extract partial evidence from the context for each sub-question.\n"
        )
        return super().recursive_decompose_prompt(question, sub_answers, context) + suffix

    def direct_prompt(self, question: str, context: Optional[str] = None) -> str:
        suffix = "\nFor multi-hop, ensure each piece of reasoning uses only the relevant parts of the context.\n"
        return super().direct_prompt(question, context) + suffix
283
+
284
+
285
+ # ----------------- 2) The AoTPipeline class with label + score + multi-hop -----------------
286
+
287
# Type variable for `AoTPipeline._ainvoke_pydantic`: the returned object has the
# same pydantic model type that the caller asked for.
T = TypeVar("T", bound=BaseModel)
288
+
289
+
290
@dataclass
class AoTPipeline:
    """
    Implements an Atom-of-Thought pipeline with:
      1) Domain detection
      2) Direct solution
      3) Recursive decomposition
      4) Label step to refine sub-questions
      5) Contract question
      6) Contracted direct solution
      7) Ensemble
      8) (Optional) Score final answer if ground-truth is available

    Attributes:
        chatterer: LLM wrapper used for every structured call (`agenerate_pydantic`).
        max_depth: Maximum recursion depth for sub-question decomposition.
        max_retries: Attempts per structured call before a placeholder response is
            returned; at least one attempt is always made.
        prompter_map: Maps each detected domain to the prompter that builds its prompts.
    """

    chatterer: Chatterer
    max_depth: int = 2
    max_retries: int = 2

    prompter_map: dict[Domain, BaseAoTPrompter] = field(
        default_factory=lambda: {
            Domain.GENERAL: GeneralAoTPrompter(),
            Domain.MATH: MathAoTPrompter(),
            Domain.CODING: CodingAoTPrompter(),
            Domain.PHILOSOPHY: PhilosophyAoTPrompter(),
            Domain.MULTIHOP: MultiHopAoTPrompter(),
        }
    )

    async def _ainvoke_pydantic(self, prompt: str, model_cls: Type[T], default_answer: str = "Unable to process") -> T:
        """
        Ask the LLM for a ``model_cls`` instance, retrying on validation errors.

        Args:
            prompt: Text sent as a single user message.
            model_cls: Pydantic model the response must validate against.
            default_answer: Answer text put into the placeholder object when every
                attempt fails validation.

        Returns:
            The validated model, or a placeholder ``model_cls`` instance (see
            :meth:`_fallback_response`) if all attempts raised ``ValidationError``.
        """
        attempts = max(1, self.max_retries)  # max_retries <= 0 would otherwise make no attempt at all
        for attempt in range(1, attempts + 1):
            try:
                result = await self.chatterer.agenerate_pydantic(
                    response_model=model_cls, messages=[{"role": "user", "content": prompt}]
                )
            except ValidationError as e:
                logger.warning("Validation error on attempt %d/%d: %s", attempt, attempts, e)
                continue
            if logger.isEnabledFor(logging.DEBUG):  # skip JSON serialization when debug logging is off
                logger.debug("[Validated attempt %d] => %s", attempt, result.model_dump_json(indent=2))
            return result
        return self._fallback_response(model_cls, default_answer)

    @staticmethod
    def _fallback_response(model_cls: Type[T], default_answer: str) -> T:
        """Build a placeholder ``model_cls`` instance for when every parse attempt failed."""
        if issubclass(model_cls, EnsembleResponse):
            return model_cls(thought="Failed ensemble parse", answer=default_answer, confidence=0.0)
        if issubclass(model_cls, ContractQuestionResponse):
            return model_cls(thought="Failed contract parse", question="Unknown")
        if issubclass(model_cls, LabelResponse):
            return model_cls(thought="Failed label parse", sub_questions=[])
        if issubclass(model_cls, RecursiveDecomposeResponse):
            # This model has `final_answer` + `sub_questions`, not `answer`; the generic
            # fallback below would itself raise ValidationError for it.
            return model_cls(thought="Failed parse", final_answer=default_answer, sub_questions=[])
        # DirectResponse (or any other model shaped as thought + answer).
        return model_cls(thought="Failed parse", answer=default_answer)

    async def _adetect_domain(self, question: str, context: Optional[str] = None) -> Domain:
        """
        Ask the LLM which :class:`Domain` best fits ``question``.

        Returns ``Domain.GENERAL`` when the classifier response fails validation.
        """

        class InferredDomain(BaseModel):
            domain: Domain

        ctx_str = f"\nCONTEXT:\n{context}" if context else ""
        domain_prompt = (
            "You are an expert domain classifier. "
            "Possible domains: [general, math, coding, philosophy, multihop].\n\n"
            'Return valid JSON: {"domain": "..."}.\n'
            f"QUESTION:\n{question}{ctx_str}"
        )
        try:
            result: InferredDomain = await self.chatterer.agenerate_pydantic(
                response_model=InferredDomain, messages=[{"role": "user", "content": domain_prompt}]
            )
        except ValidationError:
            logger.warning("Failed domain detection, defaulting to general.")
            return Domain.GENERAL
        return result.domain

    async def _arecursive_decompose_question(
        self, question: str, depth: int, prompter: BaseAoTPrompter, context: Optional[str] = None
    ) -> RecursiveDecomposeResponse:
        """
        Decompose ``question`` into sub-questions, refine them with a label step, and,
        while ``depth > 0``, resolve them recursively to re-derive the final answer.

        Args:
            question: Question to decompose.
            depth: Remaining recursion budget; below 0 a "Max depth reached" stub is returned.
            prompter: Prompt builder for the detected domain.
            context: Optional supporting text passed into every prompt.
        """
        if depth < 0:
            return RecursiveDecomposeResponse(thought="Max depth reached", final_answer="Unknown", sub_questions=[])

        indent: LiteralString = " " * (self.max_depth - depth)
        logger.debug("%sDecomposing at depth %d: %s", indent, self.max_depth - depth, question)

        # Step 1: Base decomposition
        prompt = prompter.recursive_decompose_prompt(question, context=context)
        decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(prompt, RecursiveDecomposeResponse)

        # Step 2: Label step to refine sub-questions if any
        if decompose_resp.sub_questions:
            label_prompt: str = prompter.label_prompt(question, decompose_resp, context=context)
            label_resp: LabelResponse = await self._ainvoke_pydantic(label_prompt, LabelResponse)
            # A failed label parse yields an empty list; keep the original
            # decomposition instead of discarding every sub-question.
            if label_resp.sub_questions:
                decompose_resp.sub_questions = label_resp.sub_questions

        # Step 3: If depth > 0, try to resolve sub-questions recursively
        # so we can potentially update final_answer
        if depth > 0 and decompose_resp.sub_questions:
            resolved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
                decompose_resp.sub_questions, depth, prompter, context
            )
            # Step 4: Re-invoke decomposition with known sub-answers
            sub_answers_str: str = "\n".join(f"{sq.question}: {sq.answer}" for sq in resolved_subs if sq.answer)
            if sub_answers_str:
                refine_prompt: str = prompter.recursive_decompose_prompt(question, sub_answers_str, context=context)
                refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
                    refine_prompt, RecursiveDecomposeResponse
                )
                # Use the refined final answer, keep resolved sub-questions
                decompose_resp.final_answer = refined_resp.final_answer
                decompose_resp.sub_questions = resolved_subs

        return decompose_resp

    async def _aresolve_sub_questions(
        self, sub_questions: list[SubQuestionNode], depth: int, prompter: BaseAoTPrompter, context: Optional[str] = None
    ) -> list[SubQuestionNode]:
        """
        Resolve sub-questions level by level, in topological order of their dependencies.

        All sub-questions whose dependencies are already answered are resolved
        concurrently. The answers of a node's dependencies are appended to the
        context it is decomposed with. Each node is answered by recursive
        decomposition at ``depth - 1``, and its ``answer`` field is set in place.

        Out-of-range dependency indices are ignored. Nodes caught in a dependency
        cycle are never scheduled, are left out of the result, and trigger a warning.

        Returns:
            The resolved sub-questions, in their original order.
        """
        n = len(sub_questions)
        resolved: dict[int, SubQuestionNode] = {}

        # Dependency graph with edges dep -> dependent; invalid indices are skipped.
        in_degree = [0] * n
        dependents: list[list[int]] = [[] for _ in range(n)]
        for i, sq in enumerate(sub_questions):
            for dep in sq.depend:
                if 0 <= dep < n:
                    in_degree[i] += 1
                    dependents[dep].append(i)

        async def resolve_single_subq(idx: int) -> None:
            sq = sub_questions[idx]
            sub_context = self._context_with_dependency_answers(context, sq, resolved)
            sub_decomp = await self._arecursive_decompose_question(sq.question, depth - 1, prompter, sub_context)
            sq.answer = sub_decomp.final_answer
            resolved[idx] = sq

        # Kahn's algorithm, one level at a time: a node only becomes ready after all
        # of its dependencies (from earlier levels) have been answered.
        level = [i for i in range(n) if in_degree[i] == 0]
        while level:
            await asyncio.gather(*(resolve_single_subq(i) for i in level))
            next_level: list[int] = []
            for node in level:
                for nxt in dependents[node]:
                    in_degree[nxt] -= 1
                    if in_degree[nxt] == 0:
                        next_level.append(nxt)
            level = next_level

        if len(resolved) < n:
            logger.warning("Skipped %d sub-question(s) involved in a dependency cycle.", n - len(resolved))

        return [resolved[i] for i in range(n) if i in resolved]

    @staticmethod
    def _context_with_dependency_answers(
        context: Optional[str], node: SubQuestionNode, resolved: dict[int, SubQuestionNode]
    ) -> Optional[str]:
        """Return ``context`` extended with the answers of ``node``'s already-resolved dependencies."""
        dep_lines = [
            f"{resolved[dep].question}: {resolved[dep].answer}"
            for dep in dict.fromkeys(node.depend)  # de-duplicate, keep order
            if dep in resolved and resolved[dep].answer
        ]
        if not dep_lines:
            return context
        dep_block = "KNOWN SUB-ANSWERS:\n" + "\n".join(dep_lines)
        return f"{context}\n{dep_block}" if context else dep_block

    def _calculate_score(self, answer: str, ground_truth: Optional[str], domain: Domain) -> float:
        """
        Example scoring function for post-hoc evaluation of the final answer.

        Returns -1.0 when ``ground_truth`` is None ("no score possible"). For
        ``Domain.MATH`` both strings are parsed as floats and compared with an
        absolute tolerance of 1e-9 (0.0 if either fails to parse). For every other
        domain, the score is a case-insensitive exact match after stripping whitespace.
        """
        if ground_truth is None:
            return -1.0

        if domain == Domain.MATH:
            try:
                ans_val = float(answer.strip())
                gt_val = float(ground_truth.strip())
            except ValueError:
                return 0.0
            return 1.0 if abs(ans_val - gt_val) < 1e-9 else 0.0

        return 1.0 if answer.strip().lower() == ground_truth.strip().lower() else 0.0

    async def arun_pipeline(
        self, question: str, context: Optional[str] = None, ground_truth: Optional[str] = None
    ) -> str:
        """
        Run the full AoT pipeline on ``question``.

        Args:
            question: The question to answer.
            context: Optional supporting text passed into every prompt.
            ground_truth: Optional reference answer; when given, the final answer is
                scored with :meth:`_calculate_score` and the score is logged.

        Returns:
            The ensembled final answer.
        """
        # 1) Domain detection
        domain = await self._adetect_domain(question, context)
        prompter = self.prompter_map.get(domain, GeneralAoTPrompter())
        logger.debug("Detected domain: %s", domain)

        # 2) Direct approach
        direct_prompt = prompter.direct_prompt(question, context)
        direct_resp = await self._ainvoke_pydantic(direct_prompt, DirectResponse)
        direct_answer = direct_resp.answer
        logger.debug("Direct answer => %s", direct_answer)

        # 3) Recursive Decomposition + label
        decomp_resp = await self._arecursive_decompose_question(question, self.max_depth, prompter, context)
        decompose_answer = decomp_resp.final_answer
        logger.debug("Decomposition answer => %s", decompose_answer)

        # 4) Contract question
        sub_answers_str = "\n".join(f"{sq.question}: {sq.answer}" for sq in decomp_resp.sub_questions if sq.answer)
        contract_prompt = prompter.contract_prompt(question, sub_answers_str, context)
        contract_resp = await self._ainvoke_pydantic(contract_prompt, ContractQuestionResponse)
        contracted_question = contract_resp.question
        logger.debug("Contracted question => %s", contracted_question)

        # 5) Direct approach on contracted question
        contracted_direct_prompt = prompter.direct_prompt(contracted_question, context)
        contracted_direct_resp = await self._ainvoke_pydantic(contracted_direct_prompt, DirectResponse)
        contracted_direct_answer = contracted_direct_resp.answer
        logger.debug("Contracted direct answer => %s", contracted_direct_answer)

        # 6) Ensemble
        ensemble_prompt = prompter.ensemble_prompt(
            original_question=question,
            direct_answer=direct_answer,
            decompose_answer=decompose_answer,
            contracted_direct_answer=contracted_direct_answer,
            context=context,
        )
        ensemble_resp = await self._ainvoke_pydantic(ensemble_prompt, EnsembleResponse)
        final_answer = ensemble_resp.answer
        logger.debug("Ensemble final answer => %s (confidence=%s)", final_answer, ensemble_resp.confidence)

        # 7) Optional scoring
        final_score = self._calculate_score(final_answer, ground_truth, domain)
        if final_score >= 0.0:
            logger.info("Final Score: %.3f (domain=%s)", final_score, domain)

        return final_answer
532
+
533
+
534
+ # ------------------ 3) AoTStrategy that uses the pipeline ------------------
535
+
536
+
537
@dataclass(kw_only=True)
class AoTStrategy(BaseStrategy):
    """
    Strategy that answers a prompt by running it through an :class:`AoTPipeline`.

    Attributes:
        pipeline: The configured Atom-of-Thought pipeline to run.
    """

    pipeline: AoTPipeline

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """
        Flatten ``messages`` into a single question string and run the pipeline on it.

        Returns:
            The pipeline's final (ensembled) answer.
        """
        logger.debug("Invoking with messages: %s", messages)

        # NOTE(review): relies on the client's private `_convert_input`, presumably a
        # LangChain chat model — confirm it survives dependency upgrades.
        prompt_value = self.pipeline.chatterer.client._convert_input(messages)
        question = prompt_value.to_string()
        logger.debug("Extracted question: %s", question)
        return await self.pipeline.arun_pipeline(question)

    def invoke(self, messages: LanguageModelInput) -> str:
        """
        Synchronous wrapper around :meth:`ainvoke`.

        Uses ``asyncio.run``, so it cannot be called from inside a running event
        loop; await :meth:`ainvoke` there instead.
        """
        return asyncio.run(self.ainvoke(messages))
551
+
552
+
553
+ # ------------------ 4) Example usage (main) ------------------
554
if __name__ == "__main__":
    # Demo: run the AoT strategy on a sample question with verbose logging.
    from warnings import filterwarnings

    import colorama

    filterwarnings("ignore", category=UserWarning)

    colorama.init(autoreset=True)

    class ColoredFormatter(logging.Formatter):
        """Formatter that wraps each formatted record in an ANSI color chosen by its level."""

        COLORS = {
            "DEBUG": colorama.Fore.CYAN,
            "INFO": colorama.Fore.GREEN,
            "WARNING": colorama.Fore.YELLOW,
            "ERROR": colorama.Fore.RED,
            "CRITICAL": colorama.Fore.RED + colorama.Style.BRIGHT,
        }

        def format(self, record):
            levelname = record.levelname
            message = super().format(record)
            return f"{self.COLORS.get(levelname, colorama.Fore.WHITE)}{message}{colorama.Style.RESET_ALL}"

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    logger.setLevel(logging.DEBUG)

    # Console: colored output.
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(ColoredFormatter(log_format))
    logger.addHandler(handler)

    # File: same layout without ANSI codes. Without an explicit formatter the file
    # would only receive the bare message (no timestamp, logger name or level).
    file_handler = logging.FileHandler("atom_of_thoughts.log", encoding="utf-8", mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(file_handler)
    logger.propagate = False

    pipeline = AoTPipeline(chatterer=Chatterer.openai(), max_depth=2)
    strategy = AoTStrategy(pipeline=pipeline)

    question = "What would Newton discover if hit by an apple falling from 100 meters?"
    answer = strategy.invoke(question)
    print(answer)
@@ -0,0 +1,14 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from ..language_model import LanguageModelInput
4
+
5
+
6
class BaseStrategy(ABC):
    """Abstract interface for a reasoning strategy that turns language-model input into a final answer string."""

    @abstractmethod
    def invoke(self, messages: LanguageModelInput) -> str:
        """
        Invoke the strategy with the given messages.

        Args:
            messages: Input for the language model, e.g. a list of chat messages such as
                [{"role": "user", "content": "What is the meaning of life?"}],
                or a plain question string.

        Returns:
            The strategy's final answer.
        """