chatterer 0.1.25__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chatterer/__init__.py CHANGED
@@ -12,12 +12,6 @@ from .messages import (
12
12
  SystemMessage,
13
13
  UsageMetadata,
14
14
  )
15
- from .strategies import (
16
- AoTPipeline,
17
- AoTPrompter,
18
- AoTStrategy,
19
- BaseStrategy,
20
- )
21
15
  from .tools import (
22
16
  CodeSnippets,
23
17
  MarkdownLink,
@@ -53,11 +47,7 @@ from .utils import (
53
47
  load_dotenv()
54
48
 
55
49
  __all__ = [
56
- "BaseStrategy",
57
50
  "Chatterer",
58
- "AoTStrategy",
59
- "AoTPipeline",
60
- "AoTPrompter",
61
51
  "html_to_markdown",
62
52
  "anything_to_markdown",
63
53
  "pdf_to_text",
@@ -27,7 +27,7 @@ from .messages import AIMessage, BaseMessage, HumanMessage, UsageMetadata
27
27
  from .utils.code_agent import CodeExecutionResult, FunctionSignature, augment_prompt_for_toolcall
28
28
 
29
29
  if TYPE_CHECKING:
30
- from instructor import Partial
30
+ from instructor import Partial # pyright: ignore[reportMissingTypeStubs]
31
31
  from langchain_experimental.tools.python.tool import PythonAstREPLTool
32
32
 
33
33
  PydanticModelT = TypeVar("PydanticModelT", bound=BaseModel)
@@ -339,7 +339,7 @@ class Chatterer(BaseModel):
339
339
  **kwargs: Any,
340
340
  ) -> Iterator[PydanticModelT]:
341
341
  try:
342
- import instructor
342
+ import instructor # pyright: ignore[reportMissingTypeStubs]
343
343
  except ImportError:
344
344
  raise ImportError("Please install `instructor` with `pip install instructor` to use this feature.")
345
345
 
@@ -360,7 +360,7 @@ class Chatterer(BaseModel):
360
360
  **kwargs: Any,
361
361
  ) -> AsyncIterator[PydanticModelT]:
362
362
  try:
363
- import instructor
363
+ import instructor # pyright: ignore[reportMissingTypeStubs]
364
364
  except ImportError:
365
365
  raise ImportError("Please install `instructor` with `pip install instructor` to use this feature.")
366
366
 
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
1
  import re
4
2
  from base64 import b64encode
5
3
  from io import BytesIO
@@ -18,7 +16,6 @@ from typing import (
18
16
  TypeAlias,
19
17
  TypedDict,
20
18
  TypeGuard,
21
- cast,
22
19
  get_args,
23
20
  )
24
21
  from urllib.parse import urlparse
@@ -29,11 +26,16 @@ from PIL.Image import Resampling
29
26
  from PIL.Image import open as image_open
30
27
  from pydantic import BaseModel
31
28
 
29
+ from .imghdr import what
30
+
32
31
  if TYPE_CHECKING:
33
32
  from openai.types.chat.chat_completion_content_part_image_param import ChatCompletionContentPartImageParam
34
33
 
35
34
  logger = getLogger(__name__)
36
- ImageType: TypeAlias = Literal["jpeg", "jpg", "png", "gif", "webp", "bmp"]
35
+ ImageFormat: TypeAlias = Literal["jpeg", "png", "gif", "webp", "bmp"]
36
+ ExtendedImageFormat: TypeAlias = ImageFormat | Literal["jpg", "JPG"] | Literal["JPEG", "PNG", "GIF", "WEBP", "BMP"]
37
+
38
+ ALLOWED_IMAGE_FORMATS: tuple[ImageFormat, ...] = get_args(ImageFormat)
37
39
 
38
40
 
39
41
  class ImageProcessingConfig(TypedDict):
@@ -46,7 +48,7 @@ class ImageProcessingConfig(TypedDict):
46
48
  - resize_target_for_min_side: (int) 리스케일시, '가장 작은 변'을 이 값으로 줄임(비율 유지는 Lanczos).
47
49
  """
48
50
 
49
- formats: Sequence[ImageType]
51
+ formats: Sequence[ImageFormat]
50
52
  max_size_mb: NotRequired[float]
51
53
  min_largest_side: NotRequired[int]
52
54
  resize_if_min_side_exceeds: NotRequired[int]
@@ -59,16 +61,15 @@ def get_default_image_processing_config() -> ImageProcessingConfig:
59
61
  "min_largest_side": 200,
60
62
  "resize_if_min_side_exceeds": 2000,
61
63
  "resize_target_for_min_side": 1000,
62
- "formats": ["png", "jpeg", "jpg", "gif", "bmp", "webp"],
64
+ "formats": ["png", "jpeg", "gif", "bmp", "webp"],
63
65
  }
64
66
 
65
67
 
66
- # image_url: str, headers: dict[str, str]) -> Optional[bytes]:
67
68
  class Base64Image(BaseModel):
68
- ext: ImageType
69
+ ext: ImageFormat
69
70
  data: str
70
71
 
71
- IMAGE_TYPES: ClassVar[tuple[str, ...]] = tuple(map(str, get_args(ImageType)))
72
+ IMAGE_TYPES: ClassVar[tuple[str, ...]] = ALLOWED_IMAGE_FORMATS
72
73
  IMAGE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(
73
74
  r"data:image/(" + "|".join(IMAGE_TYPES) + r");base64,([A-Za-z0-9+/]+={0,2})"
74
75
  )
@@ -76,20 +77,66 @@ class Base64Image(BaseModel):
76
77
  def __hash__(self) -> int:
77
78
  return hash((self.ext, self.data))
78
79
 
79
- def model_post_init(self, __context: object) -> None:
80
- if self.ext == "jpg":
81
- self.ext = "jpeg"
80
+ @classmethod
81
+ def new(
82
+ cls,
83
+ url_or_path_or_bytes: str | bytes,
84
+ *,
85
+ headers: dict[str, str] = {},
86
+ config: ImageProcessingConfig = get_default_image_processing_config(),
87
+ img_bytes_fetcher: Optional[Callable[[str, dict[str, str]], bytes]] = None,
88
+ ) -> Self:
89
+ if isinstance(url_or_path_or_bytes, bytes):
90
+ ext = what(url_or_path_or_bytes)
91
+ if ext is None:
92
+ raise ValueError(f"Invalid image format: {url_or_path_or_bytes[:8]} ...")
93
+ if not cls._verify_ext(ext, config["formats"]):
94
+ raise ValueError(f"Invalid image format: {ext} not in {config['formats']}")
95
+ return cls.from_bytes(url_or_path_or_bytes, ext=ext)
96
+ elif maybe_base64 := cls.from_string(url_or_path_or_bytes):
97
+ return maybe_base64
98
+ elif maybe_url_or_path := cls.from_url_or_path(
99
+ url_or_path_or_bytes, headers=headers, config=config, img_bytes_fetcher=img_bytes_fetcher
100
+ ):
101
+ return maybe_url_or_path
102
+ else:
103
+ raise ValueError(f"Invalid image format: {url_or_path_or_bytes}")
104
+
105
+ @classmethod
106
+ async def anew(
107
+ cls,
108
+ url_or_path_or_bytes: str | bytes,
109
+ *,
110
+ headers: dict[str, str] = {},
111
+ config: ImageProcessingConfig = get_default_image_processing_config(),
112
+ img_bytes_fetcher: Optional[Callable[[str, dict[str, str]], Awaitable[bytes]]] = None,
113
+ ) -> Self:
114
+ if isinstance(url_or_path_or_bytes, bytes):
115
+ ext = what(url_or_path_or_bytes)
116
+ if ext is None:
117
+ raise ValueError(f"Invalid image format: {url_or_path_or_bytes[:8]} ...")
118
+ if not cls._verify_ext(ext, config["formats"]):
119
+ raise ValueError(f"Invalid image format: {ext} not in {config['formats']}")
120
+ return cls.from_bytes(url_or_path_or_bytes, ext=ext)
121
+ elif maybe_base64 := cls.from_string(url_or_path_or_bytes):
122
+ return maybe_base64
123
+ elif maybe_url_or_path := await cls.afrom_url_or_path(
124
+ url_or_path_or_bytes, headers=headers, config=config, img_bytes_fetcher=img_bytes_fetcher
125
+ ):
126
+ return maybe_url_or_path
127
+ else:
128
+ raise ValueError(f"Invalid image format: {url_or_path_or_bytes}")
82
129
 
83
130
  @classmethod
84
131
  def from_string(cls, data: str) -> Optional[Self]:
85
132
  match = cls.IMAGE_PATTERN.fullmatch(data)
86
133
  if not match:
87
134
  return None
88
- return cls(ext=cast(ImageType, match.group(1)), data=match.group(2))
135
+ return cls(ext=_to_image_format(match.group(1)), data=match.group(2))
89
136
 
90
137
  @classmethod
91
- def from_bytes(cls, data: bytes, ext: ImageType) -> Self:
92
- return cls(ext=ext, data=b64encode(data).decode("utf-8"))
138
+ def from_bytes(cls, data: bytes, ext: ExtendedImageFormat) -> Self:
139
+ return cls(ext=_to_image_format(ext), data=b64encode(data).decode("utf-8"))
93
140
 
94
141
  @classmethod
95
142
  def from_url_or_path(
@@ -154,7 +201,7 @@ class Base64Image(BaseModel):
154
201
  return {"type": "image_url", "image_url": {"url": self.data_uri}}
155
202
 
156
203
  @staticmethod
157
- def _verify_ext(ext: str, allowed_types: Sequence[ImageType]) -> TypeGuard[ImageType]:
204
+ def _verify_ext(ext: str, allowed_types: Sequence[ImageFormat]) -> TypeGuard[ImageFormat]:
158
205
  return ext in allowed_types
159
206
 
160
207
  @classmethod
@@ -226,7 +273,7 @@ class Base64Image(BaseModel):
226
273
  # 포맷 제한
227
274
  # PIL이 인식한 포맷이 대문자(JPEG)일 수 있으므로 소문자로
228
275
  pil_format: str = (im.format or "").lower()
229
- allowed_formats: Sequence[ImageType] = config.get("formats", [])
276
+ allowed_formats: Sequence[ImageFormat] = config.get("formats", [])
230
277
  if not cls._verify_ext(pil_format, allowed_formats):
231
278
  logger.error(f"Invalid format: {pil_format} not in {allowed_formats}")
232
279
  return None
@@ -265,12 +312,22 @@ class Base64Image(BaseModel):
265
312
  return cls(ext=ext, data=b64encode(path.read_bytes()).decode("ascii"))
266
313
 
267
314
 
315
+ def _to_image_format(ext: str) -> ImageFormat:
316
+ lowered = ext.lower()
317
+ if lowered in ALLOWED_IMAGE_FORMATS:
318
+ return lowered
319
+ elif lowered == "jpg":
320
+ return "jpeg" # jpg -> jpeg
321
+ else:
322
+ raise ValueError(f"Invalid image format: {ext}")
323
+
324
+
268
325
  def is_remote_url(path: str) -> bool:
269
326
  parsed = urlparse(path)
270
327
  return bool(parsed.scheme and parsed.netloc)
271
328
 
272
329
 
273
- def detect_image_type(image_data: bytes) -> Optional[ImageType]:
330
+ def detect_image_type(image_data: bytes) -> Optional[ImageFormat]:
274
331
  """
275
332
  Detect the image format based on the image binary signature (header).
276
333
  Only JPEG, PNG, GIF, WEBP, and BMP are handled as examples.
chatterer/utils/imghdr.py CHANGED
@@ -27,14 +27,11 @@ def decode_prefix(b64_data: str, prefix_bytes: int = 32) -> bytes:
27
27
  return base64.b64decode(b64_data)
28
28
 
29
29
 
30
- def what(b64_data: str) -> Optional[ImageType]:
31
- """
32
- base64 인코딩된 문자열에 포함된 이미지의 타입을 반환한다.
33
-
34
- :param b64_data: 이미지 데이터를 담은 base64 문자열.
35
- :return: 이미지 포맷 문자열 (예: "jpeg", "png", "gif", 등) 또는 인식되지 않으면 None.
36
- """
37
- h: bytes = decode_prefix(b64_data, prefix_bytes=32)
30
+ def what(b64_or_bytes: str | bytes, prefix_bytes: int = 32) -> Optional[ImageType]:
31
+ if isinstance(b64_or_bytes, str):
32
+ h: bytes = decode_prefix(b64_or_bytes, prefix_bytes=prefix_bytes)
33
+ else:
34
+ h = b64_or_bytes
38
35
 
39
36
  for tf in tests:
40
37
  res = tf(h)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: chatterer
3
- Version: 0.1.25
3
+ Version: 0.1.26
4
4
  Summary: The highest-level interface for various LLM APIs.
5
5
  Requires-Python: >=3.12
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
- chatterer/__init__.py,sha256=hpbs8EXfz0OyOA1h9o2ZBR_556pyqcJfzJlQEf7Tl7E,2221
1
+ chatterer/__init__.py,sha256=D2kh79a3yC7LLTP_4kI_JiKqOFlti_FEkajAJHfMSek,2047
2
2
  chatterer/interactive.py,sha256=bw4iZSPv57x0WPmasnObLM6f_tIgAJD-KbiXjN7NvYw,16702
3
- chatterer/language_model.py,sha256=I7oAsD_qhQVxTdVWxEX9_Yt6py6sj8wddVueHED2E0U,20179
3
+ chatterer/language_model.py,sha256=y3nhKiQKTQiFX6M0DWL75OelJwkv3xjSigzjf1icve4,20308
4
4
  chatterer/messages.py,sha256=SIvG9hMHaPG4AFadeRbj5yPnMq2J06fHA4D8jrkz4kQ,458
5
5
  chatterer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  chatterer/common_types/__init__.py,sha256=Ais0kqzQJj4sED24xdv5azIxHkkj0vXWb4LC41dFkBQ,355
@@ -16,9 +16,6 @@ chatterer/examples/snippet.py,sha256=1RrqSr4ZZqTb_jZQjEed0G87oYGdsMSINu5qzrxH8Ps
16
16
  chatterer/examples/transcribe.py,sha256=cZoIn8PymlQJcoGKI3PYDlkrYC0juQ6HW5c1GlPk3KY,6650
17
17
  chatterer/examples/upstage.py,sha256=UYehJC13askeJgZS1-aAJniv3KMGt8tQjFBkhfGSoBQ,3130
18
18
  chatterer/examples/web2md.py,sha256=oOSyjdYn4nuv7uMBuNKmRl3PCA7J0rszR4bF9ZUjJRg,3127
19
- chatterer/strategies/__init__.py,sha256=oroDpp5ppWralCER5wnf8fgX77dqLAPF0ogmRtRzQfU,208
20
- chatterer/strategies/atom_of_thoughts.py,sha256=coXh8ODw7_ig6qcxhdJ5TAiyk70PJbrhElN0J5JD12w,40203
21
- chatterer/strategies/base.py,sha256=rqOzTo6S2eQ3A_F9aBXCmVoLM1eT6F2VzZ2Dof330Tk,413
22
19
  chatterer/tools/__init__.py,sha256=cOFo-Aj2xXK_7IvWYRdg6uNomaT9xuSa_mwk2Y0l_AM,1428
23
20
  chatterer/tools/caption_markdown_images.py,sha256=PfvHvr7x0XRLKlujALvOfEB3pQmjzlYFbcJqw2NHgZs,15008
24
21
  chatterer/tools/convert_pdf_to_markdown.py,sha256=hD4JloVdeQ4ZdAmSK1rQO2q-_rXhj3Zo7Hr3sKcGaoI,28264
@@ -34,12 +31,12 @@ chatterer/tools/citation_chunking/prompt.py,sha256=so-8uFQ5b2Zq2V5Brfxd76bEnKYkH
34
31
  chatterer/tools/citation_chunking/reference.py,sha256=m47XYaB5uFff_x_k7US9hNr-SpZjKnl-GuzsGaQzcZo,893
35
32
  chatterer/tools/citation_chunking/utils.py,sha256=Xytm9lMrS783Po1qWAdEJ8q7Q3l2UMzwHd9EkYTRiwk,6210
36
33
  chatterer/utils/__init__.py,sha256=of2NeLOjsAI79TgA4bL7UggCnAc7xT9eu3eeUBt9K8k,326
37
- chatterer/utils/base64_image.py,sha256=bS25MvOrD5i_ofamxyAy7L2dHjsqEgAhwwIRkT36Qq0,11088
34
+ chatterer/utils/base64_image.py,sha256=vmkA2LROTUK2SSCjWTT0sBumo7dpNZiRZzgDnAap_D4,13671
38
35
  chatterer/utils/bytesio.py,sha256=QabdJCZsabPaiYVfJcdXzdiHjuhqDlz1vJuLJ60P7TY,2559
39
36
  chatterer/utils/code_agent.py,sha256=z3GYWZbiKByeMQKXKpUo5JVbkt2hkKwWULyzdkgak1c,10258
40
- chatterer/utils/imghdr.py,sha256=aZ1_AsRzyTsbV7uoeAZMVaC-hj73kvnFHMdHtFKskdE,3694
41
- chatterer-0.1.25.dist-info/METADATA,sha256=ADRPMFWQAwtfXt4tZEChB_yOXq1qJFPhUr2vdyqm3Lk,11273
42
- chatterer-0.1.25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
43
- chatterer-0.1.25.dist-info/entry_points.txt,sha256=IzGKhTnZ7G5V23SRmulmSsyt9HcaFH4lU4r3wR1zMsc,63
44
- chatterer-0.1.25.dist-info/top_level.txt,sha256=7nSQKP0bHxPRc7HyzdbKsJdkvPgYD0214o6slRizv9s,10
45
- chatterer-0.1.25.dist-info/RECORD,,
37
+ chatterer/utils/imghdr.py,sha256=xIqodIDEHwSvKJ7OEym9aibPkc0AX6kuasoI9tzMIBk,3542
38
+ chatterer-0.1.26.dist-info/METADATA,sha256=8V-6-yM6wp2QyHBtTAj5LVp8ygGeqH9O8mnPH4k0QvQ,11273
39
+ chatterer-0.1.26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
+ chatterer-0.1.26.dist-info/entry_points.txt,sha256=IzGKhTnZ7G5V23SRmulmSsyt9HcaFH4lU4r3wR1zMsc,63
41
+ chatterer-0.1.26.dist-info/top_level.txt,sha256=7nSQKP0bHxPRc7HyzdbKsJdkvPgYD0214o6slRizv9s,10
42
+ chatterer-0.1.26.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- from .atom_of_thoughts import (
2
- AoTPipeline,
3
- AoTStrategy,
4
- AoTPrompter,
5
- )
6
- from .base import BaseStrategy
7
-
8
- __all__ = [
9
- "BaseStrategy",
10
- "AoTPipeline",
11
- "AoTPrompter",
12
- "AoTStrategy",
13
- ]
@@ -1,975 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import logging
5
- from dataclasses import dataclass, field
6
- from enum import StrEnum
7
- from typing import Optional, Type, TypeVar
8
-
9
- from pydantic import BaseModel, Field, ValidationError
10
-
11
- from ..language_model import Chatterer, LanguageModelInput
12
- from ..messages import AIMessage, BaseMessage, HumanMessage
13
- from .base import BaseStrategy
14
-
15
- # ---------------------------------------------------------------------------------
16
- # 0) Enums and Basic Models
17
- # ---------------------------------------------------------------------------------
18
-
19
- QA_TEMPLATE = "Q: {question}\nA: {answer}"
20
- MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."
21
- UNKNOWN = "Unknown"
22
-
23
-
24
- class SubQuestionNode(BaseModel):
25
- """A single sub-question node in a decomposition tree."""
26
-
27
- question: str = Field(description="A sub-question string that arises from decomposition.")
28
- answer: Optional[str] = Field(description="Answer for this sub-question, if resolved.")
29
- depend: list[int] = Field(description="Indices of sub-questions that this node depends on.")
30
-
31
-
32
- class RecursiveDecomposeResponse(BaseModel):
33
- """The result of a recursive decomposition step."""
34
-
35
- thought: str = Field(description="Reasoning about decomposition.")
36
- final_answer: str = Field(description="Best answer to the main question.")
37
- sub_questions: list[SubQuestionNode] = Field(description="Root-level sub-questions.")
38
-
39
-
40
- class ContractQuestionResponse(BaseModel):
41
- """The result of contracting (simplifying) a question."""
42
-
43
- thought: str = Field(description="Reasoning on how the question was compressed.")
44
- question: str = Field(description="New, simplified, self-contained question.")
45
-
46
-
47
- class EnsembleResponse(BaseModel):
48
- """The ensemble process result."""
49
-
50
- thought: str = Field(description="Explanation for choosing the final answer.")
51
- answer: str = Field(description="Best final answer after ensemble.")
52
- confidence: float = Field(description="Confidence score in [0, 1].")
53
-
54
- def model_post_init(self, __context: object) -> None:
55
- self.confidence = max(0.0, min(1.0, self.confidence))
56
-
57
-
58
- class LabelResponse(BaseModel):
59
- """Used to refine and reorder the sub-questions with corrected dependencies."""
60
-
61
- thought: str = Field(description="Explanation or reasoning about labeling.")
62
- sub_questions: list[SubQuestionNode] = Field(
63
- description="Refined list of sub-questions with corrected dependencies."
64
- )
65
-
66
-
67
- class CritiqueResponse(BaseModel):
68
- """A response used for LLM to self-critique or question its own correctness."""
69
-
70
- thought: str = Field(description="Critical reflection on correctness.")
71
- self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")
72
-
73
-
74
- # ---------------------------------------------------------------------------------
75
- # [NEW] Additional classes to incorporate a separate sub-question devil's advocate
76
- # ---------------------------------------------------------------------------------
77
-
78
-
79
- class DevilsAdvocateResponse(BaseModel):
80
- """
81
- A response for a 'devil's advocate' pass.
82
- We consider an alternative viewpoint or contradictory answer.
83
- """
84
-
85
- thought: str = Field(description="Reasoning behind the contradictory viewpoint.")
86
- final_answer: str = Field(description="Alternative or conflicting answer to challenge the main one.")
87
- sub_questions: list[SubQuestionNode] = Field(
88
- description="Any additional sub-questions from the contrarian perspective."
89
- )
90
-
91
-
92
- # ---------------------------------------------------------------------------------
93
- # 1) Prompter Classes with Multi-Hop + Devil's Advocate
94
- # ---------------------------------------------------------------------------------
95
-
96
-
97
- class AoTPrompter:
98
- """Generic base prompter that defines the required prompt methods."""
99
-
100
- def recursive_decompose_prompt(
101
- self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
102
- ) -> list[BaseMessage]:
103
- """
104
- Prompt for main decomposition.
105
- Encourages step-by-step reasoning and listing sub-questions as JSON.
106
- """
107
- decompose_instructions = (
108
- "First, restate the main question.\n"
109
- "Decide if sub-questions are needed. If so, list them.\n"
110
- "In the 'thought' field, show your chain-of-thought.\n"
111
- "Return valid JSON:\n"
112
- "{\n"
113
- ' "thought": "...",\n'
114
- ' "final_answer": "...",\n'
115
- ' "sub_questions": [\n'
116
- ' {"question": "...", "answer": null, "depend": []},\n'
117
- " ...\n"
118
- " ]\n"
119
- "}\n"
120
- )
121
-
122
- content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
123
- return messages + [
124
- HumanMessage(content=f"Main question:\n{question}"),
125
- AIMessage(content=content_sub_answers),
126
- AIMessage(content=decompose_instructions),
127
- ]
128
-
129
- def label_prompt(
130
- self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
131
- ) -> list[BaseMessage]:
132
- """
133
- Prompt for refining the sub-questions and dependencies.
134
- """
135
- label_instructions = (
136
- "Review each sub-question to ensure correctness and proper ordering.\n"
137
- "Return valid JSON in the form:\n"
138
- "{\n"
139
- ' "thought": "...",\n'
140
- ' "sub_questions": [\n'
141
- ' {"question": "...", "answer": "...", "depend": [...]},\n'
142
- " ...\n"
143
- " ]\n"
144
- "}\n"
145
- )
146
- return messages + [
147
- AIMessage(content=f"Question: {question}"),
148
- AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
149
- AIMessage(content=label_instructions),
150
- ]
151
-
152
- def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
153
- """
154
- Prompt for merging sub-answers into one self-contained question.
155
- """
156
- contract_instructions = (
157
- "Please merge sub-answers into a single short question that is fully self-contained.\n"
158
- "In 'thought', show how you unify the information.\n"
159
- "Then produce JSON:\n"
160
- "{\n"
161
- ' "thought": "...",\n'
162
- ' "question": "a short but self-contained question"\n'
163
- "}\n"
164
- )
165
- sub_q_content = "\n".join(f"Q: {q}\nA: {a}" for q, a in sub_answers)
166
- return messages + [
167
- AIMessage(content="We have these sub-questions and answers:"),
168
- AIMessage(content=sub_q_content),
169
- AIMessage(content=contract_instructions),
170
- ]
171
-
172
- def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
173
- """
174
- Prompt for directly answering the contracted question thoroughly.
175
- """
176
- direct_instructions = (
177
- "Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
178
- "Return JSON:\n"
179
- "{\n"
180
- ' "thought": "...",\n'
181
- ' "final_answer": "..."\n'
182
- "}\n"
183
- )
184
- return messages + [
185
- HumanMessage(content=f"Simplified question: {contracted_question}"),
186
- AIMessage(content=direct_instructions),
187
- ]
188
-
189
- def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
190
- """
191
- Prompt for self-critique.
192
- """
193
- critique_instructions = (
194
- "Critique your own approach. Identify possible errors or leaps.\n"
195
- "Return JSON:\n"
196
- "{\n"
197
- ' "thought": "...",\n'
198
- ' "self_assessment": <float in [0,1]>\n'
199
- "}\n"
200
- )
201
- return messages + [
202
- AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
203
- AIMessage(content=f"Your previous ANSWER:\n{answer}"),
204
- AIMessage(content=critique_instructions),
205
- ]
206
-
207
- def ensemble_prompt(
208
- self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
209
- ) -> list[BaseMessage]:
210
- """
211
- Show multiple candidate solutions and pick the best final answer with confidence.
212
- """
213
- instructions = (
214
- "You have multiple candidate solutions. Compare carefully and pick the best.\n"
215
- "Return JSON:\n"
216
- "{\n"
217
- ' "thought": "why you chose this final answer",\n'
218
- ' "answer": "the best consolidated answer",\n'
219
- ' "confidence": 0.0 ~ 1.0\n'
220
- "}\n"
221
- )
222
- reasonings: list[BaseMessage] = []
223
- for idx, (thought, ans) in enumerate(possible_thought_and_answers):
224
- reasonings.append(AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---"))
225
- return messages + reasonings + [AIMessage(content=instructions)]
226
-
227
- def devils_advocate_prompt(
228
- self, messages: list[BaseMessage], question: str, existing_answer: str
229
- ) -> list[BaseMessage]:
230
- """
231
- Prompt for a devil's advocate approach to contradict or provide an alternative viewpoint.
232
- """
233
- instructions = (
234
- "Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
235
- "Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
236
- "Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
237
- "But here, let's keep it in a new dedicated structure:\n"
238
- "{\n"
239
- ' "thought": "...",\n'
240
- ' "final_answer": "...",\n'
241
- ' "sub_questions": [\n'
242
- ' {"question": "...", "answer": null, "depend": []},\n'
243
- " ...\n"
244
- " ]\n"
245
- "}\n"
246
- )
247
- return messages + [
248
- AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
249
- AIMessage(content=instructions),
250
- ]
251
-
252
-
253
- # ---------------------------------------------------------------------------------
254
- # 2) Strict Typed Steps for Pipeline
255
- # ---------------------------------------------------------------------------------
256
-
257
-
258
- class StepName(StrEnum):
259
- """Enum for step names in the pipeline."""
260
-
261
- DOMAIN_DETECTION = "DomainDetection"
262
- DECOMPOSITION = "Decomposition"
263
- DECOMPOSITION_CRITIQUE = "DecompositionCritique"
264
- CONTRACTED_QUESTION = "ContractedQuestion"
265
- CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
266
- CONTRACT_CRITIQUE = "ContractCritique"
267
- BEST_APPROACH_DECISION = "BestApproachDecision"
268
- ENSEMBLE = "Ensemble"
269
- FINAL_ANSWER = "FinalAnswer"
270
-
271
- DEVILS_ADVOCATE = "DevilsAdvocate"
272
- DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
273
-
274
-
275
- class StepRelation(StrEnum):
276
- """Enum for relationship types in the reasoning graph."""
277
-
278
- CRITIQUES = "CRITIQUES"
279
- SELECTS = "SELECTS"
280
- RESULT_OF = "RESULT_OF"
281
- SPLIT_INTO = "SPLIT_INTO"
282
- DEPEND_ON = "DEPEND_ON"
283
- PRECEDES = "PRECEDES"
284
- DECOMPOSED_BY = "DECOMPOSED_BY"
285
-
286
-
287
- class StepRecord(BaseModel):
288
- """A typed record for each pipeline step."""
289
-
290
- step_name: StepName
291
- domain: Optional[str] = None
292
- score: Optional[float] = None
293
- used: Optional[StepName] = None
294
- sub_questions: Optional[list[SubQuestionNode]] = None
295
- parent_decomp_step_idx: Optional[int] = None
296
- parent_subq_idx: Optional[int] = None
297
- question: Optional[str] = None
298
- thought: Optional[str] = None
299
- answer: Optional[str] = None
300
-
301
- def as_properties(self) -> dict[str, str | float | int | None]:
302
- """Converts the StepRecord to a dictionary of properties."""
303
- result: dict[str, str | float | int | None] = {}
304
- if self.score is not None:
305
- result["score"] = self.score
306
- if self.domain:
307
- result["domain"] = self.domain
308
- if self.question:
309
- result["question"] = self.question
310
- if self.thought:
311
- result["thought"] = self.thought
312
- if self.answer:
313
- result["answer"] = self.answer
314
- return result
315
-
316
-
317
- # ---------------------------------------------------------------------------------
318
- # 3) Logging Setup
319
- # ---------------------------------------------------------------------------------
320
-
321
-
322
- class SimpleColorFormatter(logging.Formatter):
323
- """Simple color-coded logging formatter for console output using ANSI escape codes."""
324
-
325
- BLUE = "\033[94m"
326
- GREEN = "\033[92m"
327
- YELLOW = "\033[93m"
328
- RED = "\033[91m"
329
- RESET = "\033[0m"
330
- LEVEL_COLORS = {
331
- logging.DEBUG: BLUE,
332
- logging.INFO: GREEN,
333
- logging.WARNING: YELLOW,
334
- logging.ERROR: RED,
335
- logging.CRITICAL: RED,
336
- }
337
-
338
- def format(self, record: logging.LogRecord) -> str:
339
- log_color = self.LEVEL_COLORS.get(record.levelno, self.RESET)
340
- message = super().format(record)
341
- return f"{log_color}{message}{self.RESET}"
342
-
343
-
344
- logger = logging.getLogger("AoT")
345
- logger.setLevel(logging.INFO)
346
- handler = logging.StreamHandler()
347
- handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))
348
- logger.handlers = [handler]
349
- logger.propagate = False
350
-
351
-
352
- # ---------------------------------------------------------------------------------
353
- # 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
354
- # ---------------------------------------------------------------------------------
355
-
356
- T = TypeVar(
357
- "T",
358
- bound=EnsembleResponse
359
- | ContractQuestionResponse
360
- | LabelResponse
361
- | CritiqueResponse
362
- | RecursiveDecomposeResponse
363
- | DevilsAdvocateResponse,
364
- )
365
-
366
-
367
- @dataclass
368
- class AoTPipeline:
369
- """
370
- The pipeline orchestrates:
371
- 1) Recursive decomposition
372
- 2) For each sub-question, it tries a main approach + a devil's advocate approach
373
- 3) Merges sub-answers using an ensemble
374
- 4) Contracts the question
375
- 5) Possibly does a direct approach on the contracted question
376
- 6) Ensembling the final answers
377
- """
378
-
379
- chatterer: Chatterer
380
- max_depth: int = 2
381
- max_retries: int = 2
382
- steps_history: list[StepRecord] = field(default_factory=list[StepRecord])
383
- prompter: AoTPrompter = field(default_factory=AoTPrompter)
384
-
385
- # 4.1) Utility for calling the LLM with Pydantic parsing
386
- async def _ainvoke_pydantic(
387
- self,
388
- messages: list[BaseMessage],
389
- model_cls: Type[T],
390
- fallback: str = "<None>",
391
- ) -> T:
392
- """
393
- Attempts up to max_retries to parse the model_cls from LLM output as JSON.
394
- """
395
- for attempt in range(1, self.max_retries + 1):
396
- try:
397
- return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
398
- except ValidationError as e:
399
- logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")
400
- if attempt == self.max_retries:
401
- # Return a fallback version
402
- if issubclass(model_cls, EnsembleResponse):
403
- return model_cls(thought=fallback, answer=fallback, confidence=0.0) # type: ignore
404
- elif issubclass(model_cls, ContractQuestionResponse):
405
- return model_cls(thought=fallback, question=fallback) # type: ignore
406
- elif issubclass(model_cls, LabelResponse):
407
- return model_cls(thought=fallback, sub_questions=[]) # type: ignore
408
- elif issubclass(model_cls, CritiqueResponse):
409
- return model_cls(thought=fallback, self_assessment=0.0) # type: ignore
410
- elif issubclass(model_cls, DevilsAdvocateResponse):
411
- return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
412
- else:
413
- return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
414
- # theoretically unreachable
415
- raise RuntimeError("Unexpected error in _ainvoke_pydantic")
416
-
417
- # 4.2) Helper method for self-critique
418
- async def _ainvoke_critique(
419
- self,
420
- messages: list[BaseMessage],
421
- thought: str,
422
- answer: str,
423
- ) -> CritiqueResponse:
424
- """
425
- Instructs the LLM to critique the given thought & answer, returning CritiqueResponse.
426
- """
427
- return await self._ainvoke_pydantic(
428
- messages=self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer),
429
- model_cls=CritiqueResponse,
430
- )
431
-
432
- # 4.3) Helper method for devil's advocate approach
433
- async def _ainvoke_devils_advocate(
434
- self,
435
- messages: list[BaseMessage],
436
- question: str,
437
- existing_answer: str,
438
- ) -> DevilsAdvocateResponse:
439
- """
440
- Instructs the LLM to challenge an existing answer with a devil's advocate approach.
441
- """
442
- return await self._ainvoke_pydantic(
443
- messages=self.prompter.devils_advocate_prompt(messages, question=question, existing_answer=existing_answer),
444
- model_cls=DevilsAdvocateResponse,
445
- )
446
-
447
- # 4.4) The main function that recursively decomposes a question and calls sub-steps
448
async def _arecursive_decompose_question(
    self,
    messages: list[BaseMessage],
    question: str,
    depth: int,
    parent_decomp_step_idx: Optional[int] = None,
    parent_subq_idx: Optional[int] = None,
) -> RecursiveDecomposeResponse:
    """
    Decompose ``question`` into sub-questions, recursing while depth remains.

    Flow:
      * ``depth < 0``: stop immediately with a MAX_DEPTH_REACHED / UNKNOWN response.
      * Ask the LLM for a decomposition. If it produced sub-questions, ask it to
        label them (dependencies / ordering) and adopt the labelled list.
      * Record the decomposition in ``steps_history``.
      * ``depth > 0`` with sub-questions: resolve every sub-question (recursive
        decomposition + devil's advocate + ensemble), then re-run the decomposition
        prompt with the sub-answers to get a refined final answer. Write that answer
        and the resolved sub-questions back into both the response and its history
        record.

    Returns the (possibly refined) ``RecursiveDecomposeResponse``.
    """
    if depth < 0:
        logger.info("Max depth reached, returning unknown.")
        return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])

    # Initial decomposition, without any sub-answers yet.
    initial_prompt = self.prompter.recursive_decompose_prompt(messages=messages, question=question, sub_answers=[])
    decomposition: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        messages=initial_prompt,
        model_cls=RecursiveDecomposeResponse,
    )

    # Let the LLM annotate dependencies / ordering between the sub-questions.
    if decomposition.sub_questions:
        labelled: LabelResponse = await self._ainvoke_pydantic(
            messages=self.prompter.label_prompt(messages, question, decomposition),
            model_cls=LabelResponse,
        )
        decomposition.sub_questions = labelled.sub_questions

    record_idx = self._record_decomposition_step(
        question=question,
        final_answer=decomposition.final_answer,
        sub_questions=decomposition.sub_questions,
        parent_decomp_step_idx=parent_decomp_step_idx,
        parent_subq_idx=parent_subq_idx,
    )

    # Leaf of the recursion: no depth budget left, or nothing to break down.
    if depth <= 0 or not decomposition.sub_questions:
        return decomposition

    resolved: list[SubQuestionNode] = await self._aresolve_sub_questions(
        messages=messages,
        sub_questions=decomposition.sub_questions,
        depth=depth,
        parent_decomp_step_idx=record_idx,
    )

    # Second pass: answer the question again, now informed by the sub-answers.
    refined: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
        messages=self.prompter.recursive_decompose_prompt(
            messages=messages,
            question=question,
            sub_answers=[(node.question, node.answer or UNKNOWN) for node in resolved],
        ),
        model_cls=RecursiveDecomposeResponse,
    )
    decomposition.final_answer = refined.final_answer
    decomposition.sub_questions = resolved

    history_entry = self.steps_history[record_idx]
    history_entry.answer = refined.final_answer
    history_entry.sub_questions = resolved

    return decomposition
515
-
516
def _record_decomposition_step(
    self,
    question: str,
    final_answer: str,
    sub_questions: list[SubQuestionNode],
    parent_decomp_step_idx: Optional[int],
    parent_subq_idx: Optional[int],
) -> int:
    """
    Append a DECOMPOSITION ``StepRecord`` to ``steps_history``.

    Returns the index at which the record was stored, so callers can update it later.
    """
    new_index = len(self.steps_history)
    self.steps_history.append(
        StepRecord(
            step_name=StepName.DECOMPOSITION,
            question=question,
            answer=final_answer,
            sub_questions=sub_questions,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=parent_subq_idx,
        )
    )
    return new_index
537
-
538
async def _aresolve_sub_questions(
    self,
    messages: list[BaseMessage],
    sub_questions: list[SubQuestionNode],
    depth: int,
    parent_decomp_step_idx: Optional[int],
) -> list[SubQuestionNode]:
    """
    Resolve sub-questions while honouring their declared dependencies.

    Sub-questions are processed in topological "waves" (Kahn's algorithm). Every
    sub-question whose dependencies are all resolved is solved concurrently, and the
    next wave is scheduled only after the current one has finished.

    For each sub-question:
      1) Recursively decompose it (main approach) with ``depth - 1``.
      2) Ask for a devil's advocate alternative to the main answer.
      3) Ensemble the two answers; the ensemble answer is stored in ``sq.answer``.
      4) Record the devil's advocate step and a critique of it in ``steps_history``.

    Dependency indices outside ``[0, len(sub_questions))`` are ignored. Sub-questions
    caught in a dependency cycle (including self-dependencies) can never reach
    in-degree zero. They are resolved concurrently after all acyclic ones, with a
    warning, rather than being dropped.

    Returns the sub-question nodes (mutated in place with their answers) in their
    original order.
    """
    n = len(sub_questions)
    in_degree = [0] * n
    dependents: list[list[int]] = [[] for _ in range(n)]
    for i, sq in enumerate(sub_questions):
        for dep in sq.depend:
            if 0 <= dep < n:
                in_degree[i] += 1
                dependents[dep].append(i)

    # Resolved sub-questions, keyed by their original index.
    final_subs: dict[int, SubQuestionNode] = {}

    async def _resolve_one_subq(idx: int) -> None:
        sq = sub_questions[idx]
        # 1) Main approach
        main_resp = await self._arecursive_decompose_question(
            messages=messages,
            question=sq.question,
            depth=depth - 1,
            parent_decomp_step_idx=parent_decomp_step_idx,
            parent_subq_idx=idx,
        )
        main_answer = main_resp.final_answer

        # 2) Devil's Advocate approach
        devils_resp = await self._ainvoke_devils_advocate(
            messages=messages, question=sq.question, existing_answer=main_answer
        )

        # 3) Ensemble to combine main_answer + devils_alternative
        ensemble_sub = await self._ainvoke_pydantic(
            self.prompter.ensemble_prompt(
                messages=messages,
                possible_thought_and_answers=[
                    (main_resp.thought, main_answer),
                    (devils_resp.thought, devils_resp.final_answer),
                ],
            ),
            EnsembleResponse,
        )

        # Store final subq answer
        sq.answer = ensemble_sub.answer
        final_subs[idx] = sq

        # Record pipeline steps for devil's advocate
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE,
                question=sq.question,
                answer=devils_resp.final_answer,
                thought=devils_resp.thought,
                sub_questions=devils_resp.sub_questions,
            )
        )
        # Critique the devil's advocate result
        dev_adv_crit = await self._ainvoke_critique(
            messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
        )
        self.steps_history.append(
            StepRecord(
                step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
                thought=dev_adv_crit.thought,
                score=dev_adv_crit.self_assessment,
            )
        )

    # Kahn's algorithm, one wave at a time. Each wave is awaited before its
    # dependents are released, so no sub-question starts before the ones it
    # depends on have been answered.
    wave = [i for i in range(n) if in_degree[i] == 0]
    while wave:
        await asyncio.gather(*(_resolve_one_subq(i) for i in wave))
        next_wave: list[int] = []
        for node in wave:
            for nxt in dependents[node]:
                in_degree[nxt] -= 1
                if in_degree[nxt] == 0:
                    next_wave.append(nxt)
        wave = next_wave

    # Nodes in (or downstream of) a dependency cycle never reach in-degree zero.
    # Resolve them last instead of failing with KeyError below.
    unresolved = [i for i in range(n) if i not in final_subs]
    if unresolved:
        logger.warning(f"Cyclic sub-question dependencies detected; resolving {unresolved} without ordering.")
        await asyncio.gather(*(_resolve_one_subq(i) for i in unresolved))

    return [final_subs[i] for i in range(n)]
638
-
639
- # 4.5) The primary pipeline method
640
async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
    """
    Run the full Atom-of-Thoughts pipeline on ``messages`` and return the final answer.

    The question is the text of the last message. ``steps_history`` is cleared first,
    and every stage below appends its record(s) to it:
      1) Recursive decomposition of the question (``self.max_depth`` levels).
      2) Critique of that decomposition.
      3) Devil's advocate answer to the whole question, plus a critique of it.
      4) Contraction: the top-level sub-answers are folded into one question.
      5) Direct answer to the contracted question, plus a critique of it.
      6) Ensemble over the decomposition, devil's advocate and contracted answers.
      7) The ensemble answer is recorded as the final answer and returned.
    """
    self.steps_history.clear()

    async def _critique_and_record(step_name: StepName, thought: str, answer: str) -> None:
        # Critique a (thought, answer) pair and log the verdict as its own step.
        verdict = await self._ainvoke_critique(messages=messages, thought=thought, answer=answer)
        self.steps_history.append(
            StepRecord(step_name=step_name, thought=verdict.thought, score=verdict.self_assessment)
        )

    # NOTE(review): an empty ``messages`` list raises IndexError here.
    original_question: str = messages[-1].text()

    # Stage 1-2: recursive decomposition and its critique
    decomp_resp = await self._arecursive_decompose_question(
        messages=messages,
        question=original_question,
        depth=self.max_depth,
    )
    logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")
    await _critique_and_record(StepName.DECOMPOSITION_CRITIQUE, decomp_resp.thought, decomp_resp.final_answer)

    # Stage 3: devil's advocate on the whole answer, then its critique
    devils_on_main = await self._ainvoke_devils_advocate(
        messages=messages, question=original_question, existing_answer=decomp_resp.final_answer
    )
    self.steps_history.append(
        StepRecord(
            step_name=StepName.DEVILS_ADVOCATE,
            question=original_question,
            answer=devils_on_main.final_answer,
            thought=devils_on_main.thought,
            sub_questions=devils_on_main.sub_questions,
        )
    )
    await _critique_and_record(
        StepName.DEVILS_ADVOCATE_CRITIQUE, devils_on_main.thought, devils_on_main.final_answer
    )

    # Stage 4: contract the top-level sub-answers into a single question
    top_level_record: Optional[StepRecord] = next(
        (
            record
            for record in reversed(self.steps_history)
            if record.step_name == StepName.DECOMPOSITION and record.parent_decomp_step_idx is None
        ),
        None,
    )
    sub_answers = (
        [(node.question, node.answer or UNKNOWN) for node in top_level_record.sub_questions]
        if top_level_record and top_level_record.sub_questions
        else []
    )
    contract_resp = await self._ainvoke_pydantic(
        messages=self.prompter.contract_prompt(messages, sub_answers),
        model_cls=ContractQuestionResponse,
    )
    contracted_question = contract_resp.question
    self.steps_history.append(
        StepRecord(
            step_name=StepName.CONTRACTED_QUESTION,
            question=contracted_question,
            thought=contract_resp.thought,
        )
    )

    # Stage 5: answer the contracted question directly, then critique that answer
    contracted_direct = await self._ainvoke_pydantic(
        messages=self.prompter.contract_direct_prompt(messages, contracted_question),
        model_cls=RecursiveDecomposeResponse,
        fallback="No Contracted Direct Answer",
    )
    self.steps_history.append(
        StepRecord(
            step_name=StepName.CONTRACTED_DIRECT_ANSWER,
            answer=contracted_direct.final_answer,
            thought=contracted_direct.thought,
        )
    )
    logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")
    await _critique_and_record(
        StepName.CONTRACT_CRITIQUE, contracted_direct.thought, contracted_direct.final_answer
    )

    # Stage 6: ensemble the three candidate answers
    candidates = [
        (decomp_resp.thought, decomp_resp.final_answer),
        (devils_on_main.thought, devils_on_main.final_answer),
        (contracted_direct.thought, contracted_direct.final_answer),
    ]
    ensemble_resp = await self._ainvoke_pydantic(
        self.prompter.ensemble_prompt(messages=messages, possible_thought_and_answers=candidates),
        EnsembleResponse,
    )
    best_approach_answer = ensemble_resp.answer
    approach_used = StepName.ENSEMBLE
    self.steps_history.append(StepRecord(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used))
    logger.info(f"[Best Approach Decision] => {approach_used}")

    # Stage 7: record and return the final answer
    self.steps_history.append(
        StepRecord(step_name=StepName.FINAL_ANSWER, answer=best_approach_answer, score=ensemble_resp.confidence)
    )
    logger.info(f"[Final Answer] => {best_approach_answer}")
    return best_approach_answer
775
-
776
def run_pipeline(self, messages: list[BaseMessage]) -> str:
    """Blocking entry point: drive ``arun_pipeline`` to completion with ``asyncio.run``."""
    pipeline_coro = self.arun_pipeline(messages)
    return asyncio.run(pipeline_coro)
779
-
780
- # ---------------------------------------------------------------------------------
781
- # 4.6) Build or export a reasoning graph
782
- # ---------------------------------------------------------------------------------
783
-
784
- # def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
785
- # """
786
- # Constructs a Graph object (from hypothetical `neo4j_extension`)
787
- # capturing the pipeline steps, including devil's advocate steps.
788
- # """
789
- # from neo4j_extension import Graph, Node, Relationship
790
-
791
- # g = Graph()
792
- # step_nodes: dict[int, Node] = {}
793
- # subq_nodes: dict[str, Node] = {}
794
-
795
- # # Step A: Create nodes for each pipeline step
796
- # for i, record in enumerate(self.steps_history):
797
- # # We'll skip nested Decomposition steps only if we want to flatten them.
798
- # # But let's keep them for clarity.
799
- # step_node = Node(
800
- # properties=record.as_properties(), labels={record.step_name}, globalId=f"{global_id_prefix}_step_{i}"
801
- # )
802
- # g.add_node(step_node)
803
- # step_nodes[i] = step_node
804
-
805
- # # Step B: Collect sub-questions from each DECOMPOSITION or DEVILS_ADVOCATE
806
- # all_sub_questions: dict[str, tuple[int, int, SubQuestionNode]] = {}
807
- # for i, record in enumerate(self.steps_history):
808
- # if record.sub_questions:
809
- # for sq_idx, sq in enumerate(record.sub_questions):
810
- # sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
811
- # all_sub_questions[sq_id] = (i, sq_idx, sq)
812
-
813
- # for sq_id, (i, sq_idx, sq) in all_sub_questions.items():
814
- # n_subq = Node(
815
- # properties={
816
- # "question": sq.question,
817
- # "answer": sq.answer or "",
818
- # },
819
- # labels={"SubQuestion"},
820
- # globalId=sq_id,
821
- # )
822
- # g.add_node(n_subq)
823
- # subq_nodes[sq_id] = n_subq
824
-
825
- # # Step C: Add relationships. We do a simple approach:
826
- # # - If StepRecord is DECOMPOSITION or DEVILS_ADVOCATE with sub_questions, link them via SPLIT_INTO.
827
- # for i, record in enumerate(self.steps_history):
828
- # if record.sub_questions:
829
- # start_node = step_nodes[i]
830
- # for sq_idx, sq in enumerate(record.sub_questions):
831
- # sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
832
- # end_node = subq_nodes[sq_id]
833
- # rel = Relationship(
834
- # properties={},
835
- # rel_type=StepRelation.SPLIT_INTO,
836
- # start_node=start_node,
837
- # end_node=end_node,
838
- # globalId=f"{global_id_prefix}_split_{i}_{sq_idx}",
839
- # )
840
- # g.add_relationship(rel)
841
- # # Also add sub-question dependencies
842
- # for dep in sq.depend:
843
- # # The same record i -> sub-question subq
844
- # if 0 <= dep < len(record.sub_questions):
845
- # dep_id = f"{global_id_prefix}_decomp_{i}_sub_{dep}"
846
- # if dep_id in subq_nodes:
847
- # dep_node = subq_nodes[dep_id]
848
- # rel_dep = Relationship(
849
- # properties={},
850
- # rel_type=StepRelation.DEPEND_ON,
851
- # start_node=end_node,
852
- # end_node=dep_node,
853
- # globalId=f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
854
- # )
855
- # g.add_relationship(rel_dep)
856
-
857
- # # Step D: We add PRECEDES relationships in a linear chain for the pipeline steps
858
- # for i in range(len(self.steps_history) - 1):
859
- # start_node = step_nodes[i]
860
- # end_node = step_nodes[i + 1]
861
- # rel = Relationship(
862
- # properties={},
863
- # rel_type=StepRelation.PRECEDES,
864
- # start_node=start_node,
865
- # end_node=end_node,
866
- # globalId=f"{global_id_prefix}_precede_{i}_to_{i + 1}",
867
- # )
868
- # g.add_relationship(rel)
869
-
870
- # # Step E: CRITIQUES, SELECTS, RESULT_OF can be similarly added:
871
- # # We'll do a simple pass:
872
- # # If step_name ends with CRITIQUE => it critiques the step before it
873
- # for i, record in enumerate(self.steps_history):
874
- # if "CRITIQUE" in record.step_name:
875
- # # Let it point to the preceding step
876
- # if i > 0:
877
- # start_node = step_nodes[i]
878
- # end_node = step_nodes[i - 1]
879
- # rel = Relationship(
880
- # properties={},
881
- # rel_type=StepRelation.CRITIQUES,
882
- # start_node=start_node,
883
- # end_node=end_node,
884
- # globalId=f"{global_id_prefix}_crit_{i}",
885
- # )
886
- # g.add_relationship(rel)
887
-
888
- # # If there's a BEST_APPROACH_DECISION step, link it to the step it uses
889
- # best_decision_idx = None
890
- # used_step_idx = None
891
- # for i, record in enumerate(self.steps_history):
892
- # if record.step_name == StepName.BEST_APPROACH_DECISION and record.used:
893
- # best_decision_idx = i
894
- # # find the step with that name
895
- # used_step_idx = next((j for j in step_nodes if self.steps_history[j].step_name == record.used), None)
896
- # if used_step_idx is not None:
897
- # rel = Relationship(
898
- # properties={},
899
- # rel_type=StepRelation.SELECTS,
900
- # start_node=step_nodes[i],
901
- # end_node=step_nodes[used_step_idx],
902
- # globalId=f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
903
- # )
904
- # g.add_relationship(rel)
905
-
906
- # # And link the final answer to the best approach
907
- # final_answer_idx = next(
908
- # (i for i, r in enumerate(self.steps_history) if r.step_name == StepName.FINAL_ANSWER), None
909
- # )
910
- # if final_answer_idx is not None and best_decision_idx is not None:
911
- # rel = Relationship(
912
- # properties={},
913
- # rel_type=StepRelation.RESULT_OF,
914
- # start_node=step_nodes[final_answer_idx],
915
- # end_node=step_nodes[best_decision_idx],
916
- # globalId=f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
917
- # )
918
- # g.add_relationship(rel)
919
-
920
- # return g
921
-
922
-
923
- # ---------------------------------------------------------------------------------
924
- # 5) AoTStrategy class that uses the pipeline
925
- # ---------------------------------------------------------------------------------
926
-
927
-
928
@dataclass
class AoTStrategy(BaseStrategy):
    """
    Strategy that answers by running an ``AoTPipeline``.

    The pipeline performs recursive decomposition, devil's advocate passes,
    contraction and a final ensemble.
    """

    pipeline: AoTPipeline

    def _to_message_list(self, messages: LanguageModelInput) -> list[BaseMessage]:
        # Relies on the client's private ``_convert_input`` to normalise arbitrary
        # language-model input into a prompt value, then expands it into messages.
        prompt_value = self.pipeline.chatterer.client._convert_input(messages)  # type: ignore
        return prompt_value.to_messages()  # type: ignore

    async def ainvoke(self, messages: LanguageModelInput) -> str:
        """Normalise ``messages`` and await the pipeline's asynchronous run."""
        return await self.pipeline.arun_pipeline(self._to_message_list(messages))

    def invoke(self, messages: LanguageModelInput) -> str:
        """Normalise ``messages`` and run the pipeline synchronously via ``run_pipeline``."""
        return self.pipeline.run_pipeline(self._to_message_list(messages))
946
-
947
- # def get_reasoning_graph(self):
948
- # """Return the AoT reasoning graph from the pipeline’s steps history."""
949
- # return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
950
-
951
-
952
- # ---------------------------------------------------------------------------------
953
- # Example usage (pseudo-code)
954
- # ---------------------------------------------------------------------------------
955
- # if __name__ == "__main__":
956
- # from neo4j_extension import Neo4jConnection # or your actual DB connector
957
-
958
- # # You would create a Chatterer with your chosen LLM backend (OpenAI, etc.)
959
- # chatterer = Chatterer.openai() # pseudo-code
960
- # pipeline = AoTPipeline(chatterer=chatterer, max_depth=3)
961
- # strategy = AoTStrategy(pipeline=pipeline)
962
-
963
- # question = "Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9."
964
- # answer = strategy.invoke(question)
965
- # print("Final Answer:", answer)
966
-
967
- # # Build the reasoning graph
968
- # graph = strategy.get_reasoning_graph()
969
- # print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")
970
-
971
- # # Optionally store in Neo4j
972
- # with Neo4jConnection() as conn:
973
- # conn.clear_all()
974
- # conn.upsert_graph(graph)
975
- # print("Graph stored in Neo4j.")
@@ -1,14 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- from ..language_model import LanguageModelInput
4
-
5
-
6
class BaseStrategy(ABC):
    """Abstract interface for strategies that turn language-model input into an answer string."""

    @abstractmethod
    def invoke(self, messages: LanguageModelInput) -> str:
        """
        Run the strategy on ``messages`` and return its answer as text.

        messages: input accepted by the underlying language model,
            e.g. [{"role": "user", "content": "What is the meaning of life?"}]
        """