chatterer 0.1.25__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatterer/__init__.py +0 -10
- chatterer/language_model.py +3 -3
- chatterer/utils/base64_image.py +75 -18
- chatterer/utils/imghdr.py +5 -8
- {chatterer-0.1.25.dist-info → chatterer-0.1.26.dist-info}/METADATA +1 -1
- {chatterer-0.1.25.dist-info → chatterer-0.1.26.dist-info}/RECORD +9 -12
- chatterer/strategies/__init__.py +0 -13
- chatterer/strategies/atom_of_thoughts.py +0 -975
- chatterer/strategies/base.py +0 -14
- {chatterer-0.1.25.dist-info → chatterer-0.1.26.dist-info}/WHEEL +0 -0
- {chatterer-0.1.25.dist-info → chatterer-0.1.26.dist-info}/entry_points.txt +0 -0
- {chatterer-0.1.25.dist-info → chatterer-0.1.26.dist-info}/top_level.txt +0 -0
chatterer/__init__.py
CHANGED
@@ -12,12 +12,6 @@ from .messages import (
|
|
12
12
|
SystemMessage,
|
13
13
|
UsageMetadata,
|
14
14
|
)
|
15
|
-
from .strategies import (
|
16
|
-
AoTPipeline,
|
17
|
-
AoTPrompter,
|
18
|
-
AoTStrategy,
|
19
|
-
BaseStrategy,
|
20
|
-
)
|
21
15
|
from .tools import (
|
22
16
|
CodeSnippets,
|
23
17
|
MarkdownLink,
|
@@ -53,11 +47,7 @@ from .utils import (
|
|
53
47
|
load_dotenv()
|
54
48
|
|
55
49
|
__all__ = [
|
56
|
-
"BaseStrategy",
|
57
50
|
"Chatterer",
|
58
|
-
"AoTStrategy",
|
59
|
-
"AoTPipeline",
|
60
|
-
"AoTPrompter",
|
61
51
|
"html_to_markdown",
|
62
52
|
"anything_to_markdown",
|
63
53
|
"pdf_to_text",
|
chatterer/language_model.py
CHANGED
@@ -27,7 +27,7 @@ from .messages import AIMessage, BaseMessage, HumanMessage, UsageMetadata
|
|
27
27
|
from .utils.code_agent import CodeExecutionResult, FunctionSignature, augment_prompt_for_toolcall
|
28
28
|
|
29
29
|
if TYPE_CHECKING:
|
30
|
-
from instructor import Partial
|
30
|
+
from instructor import Partial # pyright: ignore[reportMissingTypeStubs]
|
31
31
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool
|
32
32
|
|
33
33
|
PydanticModelT = TypeVar("PydanticModelT", bound=BaseModel)
|
@@ -339,7 +339,7 @@ class Chatterer(BaseModel):
|
|
339
339
|
**kwargs: Any,
|
340
340
|
) -> Iterator[PydanticModelT]:
|
341
341
|
try:
|
342
|
-
import instructor
|
342
|
+
import instructor # pyright: ignore[reportMissingTypeStubs]
|
343
343
|
except ImportError:
|
344
344
|
raise ImportError("Please install `instructor` with `pip install instructor` to use this feature.")
|
345
345
|
|
@@ -360,7 +360,7 @@ class Chatterer(BaseModel):
|
|
360
360
|
**kwargs: Any,
|
361
361
|
) -> AsyncIterator[PydanticModelT]:
|
362
362
|
try:
|
363
|
-
import instructor
|
363
|
+
import instructor # pyright: ignore[reportMissingTypeStubs]
|
364
364
|
except ImportError:
|
365
365
|
raise ImportError("Please install `instructor` with `pip install instructor` to use this feature.")
|
366
366
|
|
chatterer/utils/base64_image.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
1
|
import re
|
4
2
|
from base64 import b64encode
|
5
3
|
from io import BytesIO
|
@@ -18,7 +16,6 @@ from typing import (
|
|
18
16
|
TypeAlias,
|
19
17
|
TypedDict,
|
20
18
|
TypeGuard,
|
21
|
-
cast,
|
22
19
|
get_args,
|
23
20
|
)
|
24
21
|
from urllib.parse import urlparse
|
@@ -29,11 +26,16 @@ from PIL.Image import Resampling
|
|
29
26
|
from PIL.Image import open as image_open
|
30
27
|
from pydantic import BaseModel
|
31
28
|
|
29
|
+
from .imghdr import what
|
30
|
+
|
32
31
|
if TYPE_CHECKING:
|
33
32
|
from openai.types.chat.chat_completion_content_part_image_param import ChatCompletionContentPartImageParam
|
34
33
|
|
35
34
|
logger = getLogger(__name__)
|
36
|
-
|
35
|
+
ImageFormat: TypeAlias = Literal["jpeg", "png", "gif", "webp", "bmp"]
|
36
|
+
ExtendedImageFormat: TypeAlias = ImageFormat | Literal["jpg", "JPG"] | Literal["JPEG", "PNG", "GIF", "WEBP", "BMP"]
|
37
|
+
|
38
|
+
ALLOWED_IMAGE_FORMATS: tuple[ImageFormat, ...] = get_args(ImageFormat)
|
37
39
|
|
38
40
|
|
39
41
|
class ImageProcessingConfig(TypedDict):
|
@@ -46,7 +48,7 @@ class ImageProcessingConfig(TypedDict):
|
|
46
48
|
- resize_target_for_min_side: (int) 리스케일시, '가장 작은 변'을 이 값으로 줄임(비율 유지는 Lanczos).
|
47
49
|
"""
|
48
50
|
|
49
|
-
formats: Sequence[
|
51
|
+
formats: Sequence[ImageFormat]
|
50
52
|
max_size_mb: NotRequired[float]
|
51
53
|
min_largest_side: NotRequired[int]
|
52
54
|
resize_if_min_side_exceeds: NotRequired[int]
|
@@ -59,16 +61,15 @@ def get_default_image_processing_config() -> ImageProcessingConfig:
|
|
59
61
|
"min_largest_side": 200,
|
60
62
|
"resize_if_min_side_exceeds": 2000,
|
61
63
|
"resize_target_for_min_side": 1000,
|
62
|
-
"formats": ["png", "jpeg", "
|
64
|
+
"formats": ["png", "jpeg", "gif", "bmp", "webp"],
|
63
65
|
}
|
64
66
|
|
65
67
|
|
66
|
-
# image_url: str, headers: dict[str, str]) -> Optional[bytes]:
|
67
68
|
class Base64Image(BaseModel):
|
68
|
-
ext:
|
69
|
+
ext: ImageFormat
|
69
70
|
data: str
|
70
71
|
|
71
|
-
IMAGE_TYPES: ClassVar[tuple[str, ...]] =
|
72
|
+
IMAGE_TYPES: ClassVar[tuple[str, ...]] = ALLOWED_IMAGE_FORMATS
|
72
73
|
IMAGE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(
|
73
74
|
r"data:image/(" + "|".join(IMAGE_TYPES) + r");base64,([A-Za-z0-9+/]+={0,2})"
|
74
75
|
)
|
@@ -76,20 +77,66 @@ class Base64Image(BaseModel):
|
|
76
77
|
def __hash__(self) -> int:
|
77
78
|
return hash((self.ext, self.data))
|
78
79
|
|
79
|
-
|
80
|
-
|
81
|
-
|
80
|
+
@classmethod
|
81
|
+
def new(
|
82
|
+
cls,
|
83
|
+
url_or_path_or_bytes: str | bytes,
|
84
|
+
*,
|
85
|
+
headers: dict[str, str] = {},
|
86
|
+
config: ImageProcessingConfig = get_default_image_processing_config(),
|
87
|
+
img_bytes_fetcher: Optional[Callable[[str, dict[str, str]], bytes]] = None,
|
88
|
+
) -> Self:
|
89
|
+
if isinstance(url_or_path_or_bytes, bytes):
|
90
|
+
ext = what(url_or_path_or_bytes)
|
91
|
+
if ext is None:
|
92
|
+
raise ValueError(f"Invalid image format: {url_or_path_or_bytes[:8]} ...")
|
93
|
+
if not cls._verify_ext(ext, config["formats"]):
|
94
|
+
raise ValueError(f"Invalid image format: {ext} not in {config['formats']}")
|
95
|
+
return cls.from_bytes(url_or_path_or_bytes, ext=ext)
|
96
|
+
elif maybe_base64 := cls.from_string(url_or_path_or_bytes):
|
97
|
+
return maybe_base64
|
98
|
+
elif maybe_url_or_path := cls.from_url_or_path(
|
99
|
+
url_or_path_or_bytes, headers=headers, config=config, img_bytes_fetcher=img_bytes_fetcher
|
100
|
+
):
|
101
|
+
return maybe_url_or_path
|
102
|
+
else:
|
103
|
+
raise ValueError(f"Invalid image format: {url_or_path_or_bytes}")
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
async def anew(
|
107
|
+
cls,
|
108
|
+
url_or_path_or_bytes: str | bytes,
|
109
|
+
*,
|
110
|
+
headers: dict[str, str] = {},
|
111
|
+
config: ImageProcessingConfig = get_default_image_processing_config(),
|
112
|
+
img_bytes_fetcher: Optional[Callable[[str, dict[str, str]], Awaitable[bytes]]] = None,
|
113
|
+
) -> Self:
|
114
|
+
if isinstance(url_or_path_or_bytes, bytes):
|
115
|
+
ext = what(url_or_path_or_bytes)
|
116
|
+
if ext is None:
|
117
|
+
raise ValueError(f"Invalid image format: {url_or_path_or_bytes[:8]} ...")
|
118
|
+
if not cls._verify_ext(ext, config["formats"]):
|
119
|
+
raise ValueError(f"Invalid image format: {ext} not in {config['formats']}")
|
120
|
+
return cls.from_bytes(url_or_path_or_bytes, ext=ext)
|
121
|
+
elif maybe_base64 := cls.from_string(url_or_path_or_bytes):
|
122
|
+
return maybe_base64
|
123
|
+
elif maybe_url_or_path := await cls.afrom_url_or_path(
|
124
|
+
url_or_path_or_bytes, headers=headers, config=config, img_bytes_fetcher=img_bytes_fetcher
|
125
|
+
):
|
126
|
+
return maybe_url_or_path
|
127
|
+
else:
|
128
|
+
raise ValueError(f"Invalid image format: {url_or_path_or_bytes}")
|
82
129
|
|
83
130
|
@classmethod
|
84
131
|
def from_string(cls, data: str) -> Optional[Self]:
|
85
132
|
match = cls.IMAGE_PATTERN.fullmatch(data)
|
86
133
|
if not match:
|
87
134
|
return None
|
88
|
-
return cls(ext=
|
135
|
+
return cls(ext=_to_image_format(match.group(1)), data=match.group(2))
|
89
136
|
|
90
137
|
@classmethod
|
91
|
-
def from_bytes(cls, data: bytes, ext:
|
92
|
-
return cls(ext=ext, data=b64encode(data).decode("utf-8"))
|
138
|
+
def from_bytes(cls, data: bytes, ext: ExtendedImageFormat) -> Self:
|
139
|
+
return cls(ext=_to_image_format(ext), data=b64encode(data).decode("utf-8"))
|
93
140
|
|
94
141
|
@classmethod
|
95
142
|
def from_url_or_path(
|
@@ -154,7 +201,7 @@ class Base64Image(BaseModel):
|
|
154
201
|
return {"type": "image_url", "image_url": {"url": self.data_uri}}
|
155
202
|
|
156
203
|
@staticmethod
|
157
|
-
def _verify_ext(ext: str, allowed_types: Sequence[
|
204
|
+
def _verify_ext(ext: str, allowed_types: Sequence[ImageFormat]) -> TypeGuard[ImageFormat]:
|
158
205
|
return ext in allowed_types
|
159
206
|
|
160
207
|
@classmethod
|
@@ -226,7 +273,7 @@ class Base64Image(BaseModel):
|
|
226
273
|
# 포맷 제한
|
227
274
|
# PIL이 인식한 포맷이 대문자(JPEG)일 수 있으므로 소문자로
|
228
275
|
pil_format: str = (im.format or "").lower()
|
229
|
-
allowed_formats: Sequence[
|
276
|
+
allowed_formats: Sequence[ImageFormat] = config.get("formats", [])
|
230
277
|
if not cls._verify_ext(pil_format, allowed_formats):
|
231
278
|
logger.error(f"Invalid format: {pil_format} not in {allowed_formats}")
|
232
279
|
return None
|
@@ -265,12 +312,22 @@ class Base64Image(BaseModel):
|
|
265
312
|
return cls(ext=ext, data=b64encode(path.read_bytes()).decode("ascii"))
|
266
313
|
|
267
314
|
|
315
|
+
def _to_image_format(ext: str) -> ImageFormat:
|
316
|
+
lowered = ext.lower()
|
317
|
+
if lowered in ALLOWED_IMAGE_FORMATS:
|
318
|
+
return lowered
|
319
|
+
elif lowered == "jpg":
|
320
|
+
return "jpeg" # jpg -> jpeg
|
321
|
+
else:
|
322
|
+
raise ValueError(f"Invalid image format: {ext}")
|
323
|
+
|
324
|
+
|
268
325
|
def is_remote_url(path: str) -> bool:
|
269
326
|
parsed = urlparse(path)
|
270
327
|
return bool(parsed.scheme and parsed.netloc)
|
271
328
|
|
272
329
|
|
273
|
-
def detect_image_type(image_data: bytes) -> Optional[
|
330
|
+
def detect_image_type(image_data: bytes) -> Optional[ImageFormat]:
|
274
331
|
"""
|
275
332
|
Detect the image format based on the image binary signature (header).
|
276
333
|
Only JPEG, PNG, GIF, WEBP, and BMP are handled as examples.
|
chatterer/utils/imghdr.py
CHANGED
@@ -27,14 +27,11 @@ def decode_prefix(b64_data: str, prefix_bytes: int = 32) -> bytes:
|
|
27
27
|
return base64.b64decode(b64_data)
|
28
28
|
|
29
29
|
|
30
|
-
def what(
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
:return: 이미지 포맷 문자열 (예: "jpeg", "png", "gif", 등) 또는 인식되지 않으면 None.
|
36
|
-
"""
|
37
|
-
h: bytes = decode_prefix(b64_data, prefix_bytes=32)
|
30
|
+
def what(b64_or_bytes: str | bytes, prefix_bytes: int = 32) -> Optional[ImageType]:
|
31
|
+
if isinstance(b64_or_bytes, str):
|
32
|
+
h: bytes = decode_prefix(b64_or_bytes, prefix_bytes=prefix_bytes)
|
33
|
+
else:
|
34
|
+
h = b64_or_bytes
|
38
35
|
|
39
36
|
for tf in tests:
|
40
37
|
res = tf(h)
|
@@ -1,6 +1,6 @@
|
|
1
|
-
chatterer/__init__.py,sha256=
|
1
|
+
chatterer/__init__.py,sha256=D2kh79a3yC7LLTP_4kI_JiKqOFlti_FEkajAJHfMSek,2047
|
2
2
|
chatterer/interactive.py,sha256=bw4iZSPv57x0WPmasnObLM6f_tIgAJD-KbiXjN7NvYw,16702
|
3
|
-
chatterer/language_model.py,sha256=
|
3
|
+
chatterer/language_model.py,sha256=y3nhKiQKTQiFX6M0DWL75OelJwkv3xjSigzjf1icve4,20308
|
4
4
|
chatterer/messages.py,sha256=SIvG9hMHaPG4AFadeRbj5yPnMq2J06fHA4D8jrkz4kQ,458
|
5
5
|
chatterer/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
chatterer/common_types/__init__.py,sha256=Ais0kqzQJj4sED24xdv5azIxHkkj0vXWb4LC41dFkBQ,355
|
@@ -16,9 +16,6 @@ chatterer/examples/snippet.py,sha256=1RrqSr4ZZqTb_jZQjEed0G87oYGdsMSINu5qzrxH8Ps
|
|
16
16
|
chatterer/examples/transcribe.py,sha256=cZoIn8PymlQJcoGKI3PYDlkrYC0juQ6HW5c1GlPk3KY,6650
|
17
17
|
chatterer/examples/upstage.py,sha256=UYehJC13askeJgZS1-aAJniv3KMGt8tQjFBkhfGSoBQ,3130
|
18
18
|
chatterer/examples/web2md.py,sha256=oOSyjdYn4nuv7uMBuNKmRl3PCA7J0rszR4bF9ZUjJRg,3127
|
19
|
-
chatterer/strategies/__init__.py,sha256=oroDpp5ppWralCER5wnf8fgX77dqLAPF0ogmRtRzQfU,208
|
20
|
-
chatterer/strategies/atom_of_thoughts.py,sha256=coXh8ODw7_ig6qcxhdJ5TAiyk70PJbrhElN0J5JD12w,40203
|
21
|
-
chatterer/strategies/base.py,sha256=rqOzTo6S2eQ3A_F9aBXCmVoLM1eT6F2VzZ2Dof330Tk,413
|
22
19
|
chatterer/tools/__init__.py,sha256=cOFo-Aj2xXK_7IvWYRdg6uNomaT9xuSa_mwk2Y0l_AM,1428
|
23
20
|
chatterer/tools/caption_markdown_images.py,sha256=PfvHvr7x0XRLKlujALvOfEB3pQmjzlYFbcJqw2NHgZs,15008
|
24
21
|
chatterer/tools/convert_pdf_to_markdown.py,sha256=hD4JloVdeQ4ZdAmSK1rQO2q-_rXhj3Zo7Hr3sKcGaoI,28264
|
@@ -34,12 +31,12 @@ chatterer/tools/citation_chunking/prompt.py,sha256=so-8uFQ5b2Zq2V5Brfxd76bEnKYkH
|
|
34
31
|
chatterer/tools/citation_chunking/reference.py,sha256=m47XYaB5uFff_x_k7US9hNr-SpZjKnl-GuzsGaQzcZo,893
|
35
32
|
chatterer/tools/citation_chunking/utils.py,sha256=Xytm9lMrS783Po1qWAdEJ8q7Q3l2UMzwHd9EkYTRiwk,6210
|
36
33
|
chatterer/utils/__init__.py,sha256=of2NeLOjsAI79TgA4bL7UggCnAc7xT9eu3eeUBt9K8k,326
|
37
|
-
chatterer/utils/base64_image.py,sha256=
|
34
|
+
chatterer/utils/base64_image.py,sha256=vmkA2LROTUK2SSCjWTT0sBumo7dpNZiRZzgDnAap_D4,13671
|
38
35
|
chatterer/utils/bytesio.py,sha256=QabdJCZsabPaiYVfJcdXzdiHjuhqDlz1vJuLJ60P7TY,2559
|
39
36
|
chatterer/utils/code_agent.py,sha256=z3GYWZbiKByeMQKXKpUo5JVbkt2hkKwWULyzdkgak1c,10258
|
40
|
-
chatterer/utils/imghdr.py,sha256=
|
41
|
-
chatterer-0.1.
|
42
|
-
chatterer-0.1.
|
43
|
-
chatterer-0.1.
|
44
|
-
chatterer-0.1.
|
45
|
-
chatterer-0.1.
|
37
|
+
chatterer/utils/imghdr.py,sha256=xIqodIDEHwSvKJ7OEym9aibPkc0AX6kuasoI9tzMIBk,3542
|
38
|
+
chatterer-0.1.26.dist-info/METADATA,sha256=8V-6-yM6wp2QyHBtTAj5LVp8ygGeqH9O8mnPH4k0QvQ,11273
|
39
|
+
chatterer-0.1.26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
40
|
+
chatterer-0.1.26.dist-info/entry_points.txt,sha256=IzGKhTnZ7G5V23SRmulmSsyt9HcaFH4lU4r3wR1zMsc,63
|
41
|
+
chatterer-0.1.26.dist-info/top_level.txt,sha256=7nSQKP0bHxPRc7HyzdbKsJdkvPgYD0214o6slRizv9s,10
|
42
|
+
chatterer-0.1.26.dist-info/RECORD,,
|
chatterer/strategies/__init__.py
DELETED
@@ -1,975 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import asyncio
|
4
|
-
import logging
|
5
|
-
from dataclasses import dataclass, field
|
6
|
-
from enum import StrEnum
|
7
|
-
from typing import Optional, Type, TypeVar
|
8
|
-
|
9
|
-
from pydantic import BaseModel, Field, ValidationError
|
10
|
-
|
11
|
-
from ..language_model import Chatterer, LanguageModelInput
|
12
|
-
from ..messages import AIMessage, BaseMessage, HumanMessage
|
13
|
-
from .base import BaseStrategy
|
14
|
-
|
15
|
-
# ---------------------------------------------------------------------------------
|
16
|
-
# 0) Enums and Basic Models
|
17
|
-
# ---------------------------------------------------------------------------------
|
18
|
-
|
19
|
-
QA_TEMPLATE = "Q: {question}\nA: {answer}"
|
20
|
-
MAX_DEPTH_REACHED = "Max depth reached in recursive decomposition."
|
21
|
-
UNKNOWN = "Unknown"
|
22
|
-
|
23
|
-
|
24
|
-
class SubQuestionNode(BaseModel):
|
25
|
-
"""A single sub-question node in a decomposition tree."""
|
26
|
-
|
27
|
-
question: str = Field(description="A sub-question string that arises from decomposition.")
|
28
|
-
answer: Optional[str] = Field(description="Answer for this sub-question, if resolved.")
|
29
|
-
depend: list[int] = Field(description="Indices of sub-questions that this node depends on.")
|
30
|
-
|
31
|
-
|
32
|
-
class RecursiveDecomposeResponse(BaseModel):
|
33
|
-
"""The result of a recursive decomposition step."""
|
34
|
-
|
35
|
-
thought: str = Field(description="Reasoning about decomposition.")
|
36
|
-
final_answer: str = Field(description="Best answer to the main question.")
|
37
|
-
sub_questions: list[SubQuestionNode] = Field(description="Root-level sub-questions.")
|
38
|
-
|
39
|
-
|
40
|
-
class ContractQuestionResponse(BaseModel):
|
41
|
-
"""The result of contracting (simplifying) a question."""
|
42
|
-
|
43
|
-
thought: str = Field(description="Reasoning on how the question was compressed.")
|
44
|
-
question: str = Field(description="New, simplified, self-contained question.")
|
45
|
-
|
46
|
-
|
47
|
-
class EnsembleResponse(BaseModel):
|
48
|
-
"""The ensemble process result."""
|
49
|
-
|
50
|
-
thought: str = Field(description="Explanation for choosing the final answer.")
|
51
|
-
answer: str = Field(description="Best final answer after ensemble.")
|
52
|
-
confidence: float = Field(description="Confidence score in [0, 1].")
|
53
|
-
|
54
|
-
def model_post_init(self, __context: object) -> None:
|
55
|
-
self.confidence = max(0.0, min(1.0, self.confidence))
|
56
|
-
|
57
|
-
|
58
|
-
class LabelResponse(BaseModel):
|
59
|
-
"""Used to refine and reorder the sub-questions with corrected dependencies."""
|
60
|
-
|
61
|
-
thought: str = Field(description="Explanation or reasoning about labeling.")
|
62
|
-
sub_questions: list[SubQuestionNode] = Field(
|
63
|
-
description="Refined list of sub-questions with corrected dependencies."
|
64
|
-
)
|
65
|
-
|
66
|
-
|
67
|
-
class CritiqueResponse(BaseModel):
|
68
|
-
"""A response used for LLM to self-critique or question its own correctness."""
|
69
|
-
|
70
|
-
thought: str = Field(description="Critical reflection on correctness.")
|
71
|
-
self_assessment: float = Field(description="Self-assessed confidence in the approach/answer. A float in [0,1].")
|
72
|
-
|
73
|
-
|
74
|
-
# ---------------------------------------------------------------------------------
|
75
|
-
# [NEW] Additional classes to incorporate a separate sub-question devil's advocate
|
76
|
-
# ---------------------------------------------------------------------------------
|
77
|
-
|
78
|
-
|
79
|
-
class DevilsAdvocateResponse(BaseModel):
|
80
|
-
"""
|
81
|
-
A response for a 'devil's advocate' pass.
|
82
|
-
We consider an alternative viewpoint or contradictory answer.
|
83
|
-
"""
|
84
|
-
|
85
|
-
thought: str = Field(description="Reasoning behind the contradictory viewpoint.")
|
86
|
-
final_answer: str = Field(description="Alternative or conflicting answer to challenge the main one.")
|
87
|
-
sub_questions: list[SubQuestionNode] = Field(
|
88
|
-
description="Any additional sub-questions from the contrarian perspective."
|
89
|
-
)
|
90
|
-
|
91
|
-
|
92
|
-
# ---------------------------------------------------------------------------------
|
93
|
-
# 1) Prompter Classes with Multi-Hop + Devil's Advocate
|
94
|
-
# ---------------------------------------------------------------------------------
|
95
|
-
|
96
|
-
|
97
|
-
class AoTPrompter:
|
98
|
-
"""Generic base prompter that defines the required prompt methods."""
|
99
|
-
|
100
|
-
def recursive_decompose_prompt(
|
101
|
-
self, messages: list[BaseMessage], question: str, sub_answers: list[tuple[str, str]]
|
102
|
-
) -> list[BaseMessage]:
|
103
|
-
"""
|
104
|
-
Prompt for main decomposition.
|
105
|
-
Encourages step-by-step reasoning and listing sub-questions as JSON.
|
106
|
-
"""
|
107
|
-
decompose_instructions = (
|
108
|
-
"First, restate the main question.\n"
|
109
|
-
"Decide if sub-questions are needed. If so, list them.\n"
|
110
|
-
"In the 'thought' field, show your chain-of-thought.\n"
|
111
|
-
"Return valid JSON:\n"
|
112
|
-
"{\n"
|
113
|
-
' "thought": "...",\n'
|
114
|
-
' "final_answer": "...",\n'
|
115
|
-
' "sub_questions": [\n'
|
116
|
-
' {"question": "...", "answer": null, "depend": []},\n'
|
117
|
-
" ...\n"
|
118
|
-
" ]\n"
|
119
|
-
"}\n"
|
120
|
-
)
|
121
|
-
|
122
|
-
content_sub_answers = "\n".join(f"Sub-answer so far: Q={q}, A={a}" for q, a in sub_answers)
|
123
|
-
return messages + [
|
124
|
-
HumanMessage(content=f"Main question:\n{question}"),
|
125
|
-
AIMessage(content=content_sub_answers),
|
126
|
-
AIMessage(content=decompose_instructions),
|
127
|
-
]
|
128
|
-
|
129
|
-
def label_prompt(
|
130
|
-
self, messages: list[BaseMessage], question: str, decompose_response: RecursiveDecomposeResponse
|
131
|
-
) -> list[BaseMessage]:
|
132
|
-
"""
|
133
|
-
Prompt for refining the sub-questions and dependencies.
|
134
|
-
"""
|
135
|
-
label_instructions = (
|
136
|
-
"Review each sub-question to ensure correctness and proper ordering.\n"
|
137
|
-
"Return valid JSON in the form:\n"
|
138
|
-
"{\n"
|
139
|
-
' "thought": "...",\n'
|
140
|
-
' "sub_questions": [\n'
|
141
|
-
' {"question": "...", "answer": "...", "depend": [...]},\n'
|
142
|
-
" ...\n"
|
143
|
-
" ]\n"
|
144
|
-
"}\n"
|
145
|
-
)
|
146
|
-
return messages + [
|
147
|
-
AIMessage(content=f"Question: {question}"),
|
148
|
-
AIMessage(content=f"Current sub-questions:\n{decompose_response.sub_questions}"),
|
149
|
-
AIMessage(content=label_instructions),
|
150
|
-
]
|
151
|
-
|
152
|
-
def contract_prompt(self, messages: list[BaseMessage], sub_answers: list[tuple[str, str]]) -> list[BaseMessage]:
|
153
|
-
"""
|
154
|
-
Prompt for merging sub-answers into one self-contained question.
|
155
|
-
"""
|
156
|
-
contract_instructions = (
|
157
|
-
"Please merge sub-answers into a single short question that is fully self-contained.\n"
|
158
|
-
"In 'thought', show how you unify the information.\n"
|
159
|
-
"Then produce JSON:\n"
|
160
|
-
"{\n"
|
161
|
-
' "thought": "...",\n'
|
162
|
-
' "question": "a short but self-contained question"\n'
|
163
|
-
"}\n"
|
164
|
-
)
|
165
|
-
sub_q_content = "\n".join(f"Q: {q}\nA: {a}" for q, a in sub_answers)
|
166
|
-
return messages + [
|
167
|
-
AIMessage(content="We have these sub-questions and answers:"),
|
168
|
-
AIMessage(content=sub_q_content),
|
169
|
-
AIMessage(content=contract_instructions),
|
170
|
-
]
|
171
|
-
|
172
|
-
def contract_direct_prompt(self, messages: list[BaseMessage], contracted_question: str) -> list[BaseMessage]:
|
173
|
-
"""
|
174
|
-
Prompt for directly answering the contracted question thoroughly.
|
175
|
-
"""
|
176
|
-
direct_instructions = (
|
177
|
-
"Answer the simplified question thoroughly. Show your chain-of-thought in 'thought'.\n"
|
178
|
-
"Return JSON:\n"
|
179
|
-
"{\n"
|
180
|
-
' "thought": "...",\n'
|
181
|
-
' "final_answer": "..."\n'
|
182
|
-
"}\n"
|
183
|
-
)
|
184
|
-
return messages + [
|
185
|
-
HumanMessage(content=f"Simplified question: {contracted_question}"),
|
186
|
-
AIMessage(content=direct_instructions),
|
187
|
-
]
|
188
|
-
|
189
|
-
def critique_prompt(self, messages: list[BaseMessage], thought: str, answer: str) -> list[BaseMessage]:
|
190
|
-
"""
|
191
|
-
Prompt for self-critique.
|
192
|
-
"""
|
193
|
-
critique_instructions = (
|
194
|
-
"Critique your own approach. Identify possible errors or leaps.\n"
|
195
|
-
"Return JSON:\n"
|
196
|
-
"{\n"
|
197
|
-
' "thought": "...",\n'
|
198
|
-
' "self_assessment": <float in [0,1]>\n'
|
199
|
-
"}\n"
|
200
|
-
)
|
201
|
-
return messages + [
|
202
|
-
AIMessage(content=f"Your previous THOUGHT:\n{thought}"),
|
203
|
-
AIMessage(content=f"Your previous ANSWER:\n{answer}"),
|
204
|
-
AIMessage(content=critique_instructions),
|
205
|
-
]
|
206
|
-
|
207
|
-
def ensemble_prompt(
|
208
|
-
self, messages: list[BaseMessage], possible_thought_and_answers: list[tuple[str, str]]
|
209
|
-
) -> list[BaseMessage]:
|
210
|
-
"""
|
211
|
-
Show multiple candidate solutions and pick the best final answer with confidence.
|
212
|
-
"""
|
213
|
-
instructions = (
|
214
|
-
"You have multiple candidate solutions. Compare carefully and pick the best.\n"
|
215
|
-
"Return JSON:\n"
|
216
|
-
"{\n"
|
217
|
-
' "thought": "why you chose this final answer",\n'
|
218
|
-
' "answer": "the best consolidated answer",\n'
|
219
|
-
' "confidence": 0.0 ~ 1.0\n'
|
220
|
-
"}\n"
|
221
|
-
)
|
222
|
-
reasonings: list[BaseMessage] = []
|
223
|
-
for idx, (thought, ans) in enumerate(possible_thought_and_answers):
|
224
|
-
reasonings.append(AIMessage(content=f"[Candidate {idx}] Thought:\n{thought}\nAnswer:\n{ans}\n---"))
|
225
|
-
return messages + reasonings + [AIMessage(content=instructions)]
|
226
|
-
|
227
|
-
def devils_advocate_prompt(
|
228
|
-
self, messages: list[BaseMessage], question: str, existing_answer: str
|
229
|
-
) -> list[BaseMessage]:
|
230
|
-
"""
|
231
|
-
Prompt for a devil's advocate approach to contradict or provide an alternative viewpoint.
|
232
|
-
"""
|
233
|
-
instructions = (
|
234
|
-
"Act as a devil's advocate. Suppose the existing answer is incomplete or incorrect.\n"
|
235
|
-
"Challenge it, find alternative ways or details. Provide a new 'final_answer' (even if contradictory).\n"
|
236
|
-
"Return JSON in the same shape as RecursiveDecomposeResponse OR a dedicated structure.\n"
|
237
|
-
"But here, let's keep it in a new dedicated structure:\n"
|
238
|
-
"{\n"
|
239
|
-
' "thought": "...",\n'
|
240
|
-
' "final_answer": "...",\n'
|
241
|
-
' "sub_questions": [\n'
|
242
|
-
' {"question": "...", "answer": null, "depend": []},\n'
|
243
|
-
" ...\n"
|
244
|
-
" ]\n"
|
245
|
-
"}\n"
|
246
|
-
)
|
247
|
-
return messages + [
|
248
|
-
AIMessage(content=(f"Current question: {question}\nExisting answer to challenge: {existing_answer}\n")),
|
249
|
-
AIMessage(content=instructions),
|
250
|
-
]
|
251
|
-
|
252
|
-
|
253
|
-
# ---------------------------------------------------------------------------------
|
254
|
-
# 2) Strict Typed Steps for Pipeline
|
255
|
-
# ---------------------------------------------------------------------------------
|
256
|
-
|
257
|
-
|
258
|
-
class StepName(StrEnum):
|
259
|
-
"""Enum for step names in the pipeline."""
|
260
|
-
|
261
|
-
DOMAIN_DETECTION = "DomainDetection"
|
262
|
-
DECOMPOSITION = "Decomposition"
|
263
|
-
DECOMPOSITION_CRITIQUE = "DecompositionCritique"
|
264
|
-
CONTRACTED_QUESTION = "ContractedQuestion"
|
265
|
-
CONTRACTED_DIRECT_ANSWER = "ContractedDirectAnswer"
|
266
|
-
CONTRACT_CRITIQUE = "ContractCritique"
|
267
|
-
BEST_APPROACH_DECISION = "BestApproachDecision"
|
268
|
-
ENSEMBLE = "Ensemble"
|
269
|
-
FINAL_ANSWER = "FinalAnswer"
|
270
|
-
|
271
|
-
DEVILS_ADVOCATE = "DevilsAdvocate"
|
272
|
-
DEVILS_ADVOCATE_CRITIQUE = "DevilsAdvocateCritique"
|
273
|
-
|
274
|
-
|
275
|
-
class StepRelation(StrEnum):
|
276
|
-
"""Enum for relationship types in the reasoning graph."""
|
277
|
-
|
278
|
-
CRITIQUES = "CRITIQUES"
|
279
|
-
SELECTS = "SELECTS"
|
280
|
-
RESULT_OF = "RESULT_OF"
|
281
|
-
SPLIT_INTO = "SPLIT_INTO"
|
282
|
-
DEPEND_ON = "DEPEND_ON"
|
283
|
-
PRECEDES = "PRECEDES"
|
284
|
-
DECOMPOSED_BY = "DECOMPOSED_BY"
|
285
|
-
|
286
|
-
|
287
|
-
class StepRecord(BaseModel):
|
288
|
-
"""A typed record for each pipeline step."""
|
289
|
-
|
290
|
-
step_name: StepName
|
291
|
-
domain: Optional[str] = None
|
292
|
-
score: Optional[float] = None
|
293
|
-
used: Optional[StepName] = None
|
294
|
-
sub_questions: Optional[list[SubQuestionNode]] = None
|
295
|
-
parent_decomp_step_idx: Optional[int] = None
|
296
|
-
parent_subq_idx: Optional[int] = None
|
297
|
-
question: Optional[str] = None
|
298
|
-
thought: Optional[str] = None
|
299
|
-
answer: Optional[str] = None
|
300
|
-
|
301
|
-
def as_properties(self) -> dict[str, str | float | int | None]:
|
302
|
-
"""Converts the StepRecord to a dictionary of properties."""
|
303
|
-
result: dict[str, str | float | int | None] = {}
|
304
|
-
if self.score is not None:
|
305
|
-
result["score"] = self.score
|
306
|
-
if self.domain:
|
307
|
-
result["domain"] = self.domain
|
308
|
-
if self.question:
|
309
|
-
result["question"] = self.question
|
310
|
-
if self.thought:
|
311
|
-
result["thought"] = self.thought
|
312
|
-
if self.answer:
|
313
|
-
result["answer"] = self.answer
|
314
|
-
return result
|
315
|
-
|
316
|
-
|
317
|
-
# ---------------------------------------------------------------------------------
|
318
|
-
# 3) Logging Setup
|
319
|
-
# ---------------------------------------------------------------------------------
|
320
|
-
|
321
|
-
|
322
|
-
class SimpleColorFormatter(logging.Formatter):
|
323
|
-
"""Simple color-coded logging formatter for console output using ANSI escape codes."""
|
324
|
-
|
325
|
-
BLUE = "\033[94m"
|
326
|
-
GREEN = "\033[92m"
|
327
|
-
YELLOW = "\033[93m"
|
328
|
-
RED = "\033[91m"
|
329
|
-
RESET = "\033[0m"
|
330
|
-
LEVEL_COLORS = {
|
331
|
-
logging.DEBUG: BLUE,
|
332
|
-
logging.INFO: GREEN,
|
333
|
-
logging.WARNING: YELLOW,
|
334
|
-
logging.ERROR: RED,
|
335
|
-
logging.CRITICAL: RED,
|
336
|
-
}
|
337
|
-
|
338
|
-
def format(self, record: logging.LogRecord) -> str:
|
339
|
-
log_color = self.LEVEL_COLORS.get(record.levelno, self.RESET)
|
340
|
-
message = super().format(record)
|
341
|
-
return f"{log_color}{message}{self.RESET}"
|
342
|
-
|
343
|
-
|
344
|
-
logger = logging.getLogger("AoT")
|
345
|
-
logger.setLevel(logging.INFO)
|
346
|
-
handler = logging.StreamHandler()
|
347
|
-
handler.setFormatter(SimpleColorFormatter("%(levelname)s: %(message)s"))
|
348
|
-
logger.handlers = [handler]
|
349
|
-
logger.propagate = False
|
350
|
-
|
351
|
-
|
352
|
-
# ---------------------------------------------------------------------------------
|
353
|
-
# 4) The AoTPipeline Class (now with recursive devil's advocate at each sub-question)
|
354
|
-
# ---------------------------------------------------------------------------------
|
355
|
-
|
356
|
-
T = TypeVar(
|
357
|
-
"T",
|
358
|
-
bound=EnsembleResponse
|
359
|
-
| ContractQuestionResponse
|
360
|
-
| LabelResponse
|
361
|
-
| CritiqueResponse
|
362
|
-
| RecursiveDecomposeResponse
|
363
|
-
| DevilsAdvocateResponse,
|
364
|
-
)
|
365
|
-
|
366
|
-
|
367
|
-
@dataclass
|
368
|
-
class AoTPipeline:
|
369
|
-
"""
|
370
|
-
The pipeline orchestrates:
|
371
|
-
1) Recursive decomposition
|
372
|
-
2) For each sub-question, it tries a main approach + a devil's advocate approach
|
373
|
-
3) Merges sub-answers using an ensemble
|
374
|
-
4) Contracts the question
|
375
|
-
5) Possibly does a direct approach on the contracted question
|
376
|
-
6) Ensembling the final answers
|
377
|
-
"""
|
378
|
-
|
379
|
-
chatterer: Chatterer
|
380
|
-
max_depth: int = 2
|
381
|
-
max_retries: int = 2
|
382
|
-
steps_history: list[StepRecord] = field(default_factory=list[StepRecord])
|
383
|
-
prompter: AoTPrompter = field(default_factory=AoTPrompter)
|
384
|
-
|
385
|
-
# 4.1) Utility for calling the LLM with Pydantic parsing
|
386
|
-
async def _ainvoke_pydantic(
|
387
|
-
self,
|
388
|
-
messages: list[BaseMessage],
|
389
|
-
model_cls: Type[T],
|
390
|
-
fallback: str = "<None>",
|
391
|
-
) -> T:
|
392
|
-
"""
|
393
|
-
Attempts up to max_retries to parse the model_cls from LLM output as JSON.
|
394
|
-
"""
|
395
|
-
for attempt in range(1, self.max_retries + 1):
|
396
|
-
try:
|
397
|
-
return await self.chatterer.agenerate_pydantic(response_model=model_cls, messages=messages)
|
398
|
-
except ValidationError as e:
|
399
|
-
logger.warning(f"ValidationError on attempt {attempt} for {model_cls.__name__}: {e}")
|
400
|
-
if attempt == self.max_retries:
|
401
|
-
# Return a fallback version
|
402
|
-
if issubclass(model_cls, EnsembleResponse):
|
403
|
-
return model_cls(thought=fallback, answer=fallback, confidence=0.0) # type: ignore
|
404
|
-
elif issubclass(model_cls, ContractQuestionResponse):
|
405
|
-
return model_cls(thought=fallback, question=fallback) # type: ignore
|
406
|
-
elif issubclass(model_cls, LabelResponse):
|
407
|
-
return model_cls(thought=fallback, sub_questions=[]) # type: ignore
|
408
|
-
elif issubclass(model_cls, CritiqueResponse):
|
409
|
-
return model_cls(thought=fallback, self_assessment=0.0) # type: ignore
|
410
|
-
elif issubclass(model_cls, DevilsAdvocateResponse):
|
411
|
-
return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
|
412
|
-
else:
|
413
|
-
return model_cls(thought=fallback, final_answer=fallback, sub_questions=[]) # type: ignore
|
414
|
-
# theoretically unreachable
|
415
|
-
raise RuntimeError("Unexpected error in _ainvoke_pydantic")
|
416
|
-
|
417
|
-
# 4.2) Helper method for self-critique
|
418
|
-
async def _ainvoke_critique(
|
419
|
-
self,
|
420
|
-
messages: list[BaseMessage],
|
421
|
-
thought: str,
|
422
|
-
answer: str,
|
423
|
-
) -> CritiqueResponse:
|
424
|
-
"""
|
425
|
-
Instructs the LLM to critique the given thought & answer, returning CritiqueResponse.
|
426
|
-
"""
|
427
|
-
return await self._ainvoke_pydantic(
|
428
|
-
messages=self.prompter.critique_prompt(messages=messages, thought=thought, answer=answer),
|
429
|
-
model_cls=CritiqueResponse,
|
430
|
-
)
|
431
|
-
|
432
|
-
# 4.3) Helper method for devil's advocate approach
|
433
|
-
async def _ainvoke_devils_advocate(
|
434
|
-
self,
|
435
|
-
messages: list[BaseMessage],
|
436
|
-
question: str,
|
437
|
-
existing_answer: str,
|
438
|
-
) -> DevilsAdvocateResponse:
|
439
|
-
"""
|
440
|
-
Instructs the LLM to challenge an existing answer with a devil's advocate approach.
|
441
|
-
"""
|
442
|
-
return await self._ainvoke_pydantic(
|
443
|
-
messages=self.prompter.devils_advocate_prompt(messages, question=question, existing_answer=existing_answer),
|
444
|
-
model_cls=DevilsAdvocateResponse,
|
445
|
-
)
|
446
|
-
|
447
|
-
# 4.4) The main function that recursively decomposes a question and calls sub-steps
|
448
|
-
async def _arecursive_decompose_question(
|
449
|
-
self,
|
450
|
-
messages: list[BaseMessage],
|
451
|
-
question: str,
|
452
|
-
depth: int,
|
453
|
-
parent_decomp_step_idx: Optional[int] = None,
|
454
|
-
parent_subq_idx: Optional[int] = None,
|
455
|
-
) -> RecursiveDecomposeResponse:
|
456
|
-
"""
|
457
|
-
Recursively decompose the given question. For each sub-question:
|
458
|
-
1) Recursively decompose that sub-question if we still have depth left
|
459
|
-
2) After getting a main sub-answer, do a devil's advocate pass
|
460
|
-
3) Combine main sub-answer + devil's advocate alternative via an ensemble
|
461
|
-
"""
|
462
|
-
if depth < 0:
|
463
|
-
logger.info("Max depth reached, returning unknown.")
|
464
|
-
return RecursiveDecomposeResponse(thought=MAX_DEPTH_REACHED, final_answer=UNKNOWN, sub_questions=[])
|
465
|
-
|
466
|
-
# Step 1: Perform the decomposition
|
467
|
-
decompose_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
|
468
|
-
messages=self.prompter.recursive_decompose_prompt(messages=messages, question=question, sub_answers=[]),
|
469
|
-
model_cls=RecursiveDecomposeResponse,
|
470
|
-
)
|
471
|
-
|
472
|
-
# Step 2: Label / refine sub-questions (dependencies, ordering)
|
473
|
-
if decompose_resp.sub_questions:
|
474
|
-
label_resp: LabelResponse = await self._ainvoke_pydantic(
|
475
|
-
messages=self.prompter.label_prompt(messages, question, decompose_resp),
|
476
|
-
model_cls=LabelResponse,
|
477
|
-
)
|
478
|
-
decompose_resp.sub_questions = label_resp.sub_questions
|
479
|
-
|
480
|
-
# Save a pipeline record for this decomposition step
|
481
|
-
current_decomp_step_idx = self._record_decomposition_step(
|
482
|
-
question=question,
|
483
|
-
final_answer=decompose_resp.final_answer,
|
484
|
-
sub_questions=decompose_resp.sub_questions,
|
485
|
-
parent_decomp_step_idx=parent_decomp_step_idx,
|
486
|
-
parent_subq_idx=parent_subq_idx,
|
487
|
-
)
|
488
|
-
|
489
|
-
# Step 3: If sub-questions exist and depth remains, solve them + do devil's advocate
|
490
|
-
if depth > 0 and decompose_resp.sub_questions:
|
491
|
-
solved_subs: list[SubQuestionNode] = await self._aresolve_sub_questions(
|
492
|
-
messages=messages,
|
493
|
-
sub_questions=decompose_resp.sub_questions,
|
494
|
-
depth=depth,
|
495
|
-
parent_decomp_step_idx=current_decomp_step_idx,
|
496
|
-
)
|
497
|
-
# Then we can refine the "final_answer" from those sub-answers
|
498
|
-
# or we do a secondary pass to refine the final answer
|
499
|
-
refined_prompt = self.prompter.recursive_decompose_prompt(
|
500
|
-
messages=messages,
|
501
|
-
question=question,
|
502
|
-
sub_answers=[(sq.question, sq.answer or UNKNOWN) for sq in solved_subs],
|
503
|
-
)
|
504
|
-
refined_resp: RecursiveDecomposeResponse = await self._ainvoke_pydantic(
|
505
|
-
refined_prompt, RecursiveDecomposeResponse
|
506
|
-
)
|
507
|
-
decompose_resp.final_answer = refined_resp.final_answer
|
508
|
-
decompose_resp.sub_questions = solved_subs
|
509
|
-
|
510
|
-
# Update pipeline record
|
511
|
-
self.steps_history[current_decomp_step_idx].answer = refined_resp.final_answer
|
512
|
-
self.steps_history[current_decomp_step_idx].sub_questions = solved_subs
|
513
|
-
|
514
|
-
return decompose_resp
|
515
|
-
|
516
|
-
def _record_decomposition_step(
|
517
|
-
self,
|
518
|
-
question: str,
|
519
|
-
final_answer: str,
|
520
|
-
sub_questions: list[SubQuestionNode],
|
521
|
-
parent_decomp_step_idx: Optional[int],
|
522
|
-
parent_subq_idx: Optional[int],
|
523
|
-
) -> int:
|
524
|
-
"""
|
525
|
-
Save the decomposition step in steps_history, returning the index.
|
526
|
-
"""
|
527
|
-
step_record = StepRecord(
|
528
|
-
step_name=StepName.DECOMPOSITION,
|
529
|
-
question=question,
|
530
|
-
answer=final_answer,
|
531
|
-
sub_questions=sub_questions,
|
532
|
-
parent_decomp_step_idx=parent_decomp_step_idx,
|
533
|
-
parent_subq_idx=parent_subq_idx,
|
534
|
-
)
|
535
|
-
self.steps_history.append(step_record)
|
536
|
-
return len(self.steps_history) - 1
|
537
|
-
|
538
|
-
async def _aresolve_sub_questions(
|
539
|
-
self,
|
540
|
-
messages: list[BaseMessage],
|
541
|
-
sub_questions: list[SubQuestionNode],
|
542
|
-
depth: int,
|
543
|
-
parent_decomp_step_idx: Optional[int],
|
544
|
-
) -> list[SubQuestionNode]:
|
545
|
-
"""
|
546
|
-
Resolve sub-questions in topological order.
|
547
|
-
For each sub-question:
|
548
|
-
1) Recursively decompose (main approach).
|
549
|
-
2) Acquire a devil's advocate alternative.
|
550
|
-
3) Critique or ensemble if needed.
|
551
|
-
4) Finalize sub-question answer.
|
552
|
-
"""
|
553
|
-
n = len(sub_questions)
|
554
|
-
in_degree = [0] * n
|
555
|
-
graph: list[list[int]] = [[] for _ in range(n)]
|
556
|
-
for i, sq in enumerate(sub_questions):
|
557
|
-
for dep in sq.depend:
|
558
|
-
if 0 <= dep < n:
|
559
|
-
in_degree[i] += 1
|
560
|
-
graph[dep].append(i)
|
561
|
-
|
562
|
-
# Kahn's algorithm for topological order
|
563
|
-
queue = [i for i in range(n) if in_degree[i] == 0]
|
564
|
-
topo_order: list[int] = []
|
565
|
-
|
566
|
-
while queue:
|
567
|
-
node = queue.pop(0)
|
568
|
-
topo_order.append(node)
|
569
|
-
for nxt in graph[node]:
|
570
|
-
in_degree[nxt] -= 1
|
571
|
-
if in_degree[nxt] == 0:
|
572
|
-
queue.append(nxt)
|
573
|
-
|
574
|
-
# We'll store the resolved sub-questions
|
575
|
-
final_subs: dict[int, SubQuestionNode] = {}
|
576
|
-
|
577
|
-
async def _resolve_one_subq(idx: int):
|
578
|
-
sq = sub_questions[idx]
|
579
|
-
# 1) Main approach
|
580
|
-
main_resp = await self._arecursive_decompose_question(
|
581
|
-
messages=messages,
|
582
|
-
question=sq.question,
|
583
|
-
depth=depth - 1,
|
584
|
-
parent_decomp_step_idx=parent_decomp_step_idx,
|
585
|
-
parent_subq_idx=idx,
|
586
|
-
)
|
587
|
-
|
588
|
-
main_answer = main_resp.final_answer
|
589
|
-
|
590
|
-
# 2) Devil's Advocate approach
|
591
|
-
devils_resp = await self._ainvoke_devils_advocate(
|
592
|
-
messages=messages, question=sq.question, existing_answer=main_answer
|
593
|
-
)
|
594
|
-
# 3) Ensemble to combine main_answer + devils_alternative
|
595
|
-
ensemble_sub = await self._ainvoke_pydantic(
|
596
|
-
self.prompter.ensemble_prompt(
|
597
|
-
messages=messages,
|
598
|
-
possible_thought_and_answers=[
|
599
|
-
(main_resp.thought, main_answer),
|
600
|
-
(devils_resp.thought, devils_resp.final_answer),
|
601
|
-
],
|
602
|
-
),
|
603
|
-
EnsembleResponse,
|
604
|
-
)
|
605
|
-
sub_best_answer = ensemble_sub.answer
|
606
|
-
|
607
|
-
# Store final subq answer
|
608
|
-
sq.answer = sub_best_answer
|
609
|
-
final_subs[idx] = sq
|
610
|
-
|
611
|
-
# Record pipeline steps for devil's advocate
|
612
|
-
self.steps_history.append(
|
613
|
-
StepRecord(
|
614
|
-
step_name=StepName.DEVILS_ADVOCATE,
|
615
|
-
question=sq.question,
|
616
|
-
answer=devils_resp.final_answer,
|
617
|
-
thought=devils_resp.thought,
|
618
|
-
sub_questions=devils_resp.sub_questions,
|
619
|
-
)
|
620
|
-
)
|
621
|
-
# Possibly critique the devils advocate result
|
622
|
-
dev_adv_crit = await self._ainvoke_critique(
|
623
|
-
messages=messages, thought=devils_resp.thought, answer=devils_resp.final_answer
|
624
|
-
)
|
625
|
-
self.steps_history.append(
|
626
|
-
StepRecord(
|
627
|
-
step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
|
628
|
-
thought=dev_adv_crit.thought,
|
629
|
-
score=dev_adv_crit.self_assessment,
|
630
|
-
)
|
631
|
-
)
|
632
|
-
|
633
|
-
# Solve sub-questions in topological order
|
634
|
-
tasks = [_resolve_one_subq(i) for i in topo_order]
|
635
|
-
await asyncio.gather(*tasks, return_exceptions=False)
|
636
|
-
|
637
|
-
return [final_subs[i] for i in range(n)]
|
638
|
-
|
639
|
-
# 4.5) The primary pipeline method
|
640
|
-
async def arun_pipeline(self, messages: list[BaseMessage]) -> str:
|
641
|
-
"""
|
642
|
-
Execute the pipeline:
|
643
|
-
1) Decompose the main question (recursively).
|
644
|
-
2) Self-critique.
|
645
|
-
3) Provide a devil's advocate approach on the entire main result.
|
646
|
-
4) Contract sub-answers (optional).
|
647
|
-
5) Directly solve the contracted question.
|
648
|
-
6) Self-critique again.
|
649
|
-
7) Final ensemble across main vs devil's vs contracted direct answer.
|
650
|
-
8) Return final answer.
|
651
|
-
"""
|
652
|
-
self.steps_history.clear()
|
653
|
-
|
654
|
-
original_question: str = messages[-1].text()
|
655
|
-
# 1) Recursive decomposition
|
656
|
-
decomp_resp = await self._arecursive_decompose_question(
|
657
|
-
messages=messages,
|
658
|
-
question=original_question,
|
659
|
-
depth=self.max_depth,
|
660
|
-
)
|
661
|
-
logger.info(f"[Main Decomposition] final_answer={decomp_resp.final_answer}")
|
662
|
-
|
663
|
-
# 2) Self-critique of main decomposition
|
664
|
-
decomp_critique = await self._ainvoke_critique(
|
665
|
-
messages=messages, thought=decomp_resp.thought, answer=decomp_resp.final_answer
|
666
|
-
)
|
667
|
-
self.steps_history.append(
|
668
|
-
StepRecord(
|
669
|
-
step_name=StepName.DECOMPOSITION_CRITIQUE,
|
670
|
-
thought=decomp_critique.thought,
|
671
|
-
score=decomp_critique.self_assessment,
|
672
|
-
)
|
673
|
-
)
|
674
|
-
|
675
|
-
# 3) Devil's advocate on the entire main answer
|
676
|
-
devils_on_main = await self._ainvoke_devils_advocate(
|
677
|
-
messages=messages, question=original_question, existing_answer=decomp_resp.final_answer
|
678
|
-
)
|
679
|
-
self.steps_history.append(
|
680
|
-
StepRecord(
|
681
|
-
step_name=StepName.DEVILS_ADVOCATE,
|
682
|
-
question=original_question,
|
683
|
-
answer=devils_on_main.final_answer,
|
684
|
-
thought=devils_on_main.thought,
|
685
|
-
sub_questions=devils_on_main.sub_questions,
|
686
|
-
)
|
687
|
-
)
|
688
|
-
devils_crit_main = await self._ainvoke_critique(
|
689
|
-
messages=messages, thought=devils_on_main.thought, answer=devils_on_main.final_answer
|
690
|
-
)
|
691
|
-
self.steps_history.append(
|
692
|
-
StepRecord(
|
693
|
-
step_name=StepName.DEVILS_ADVOCATE_CRITIQUE,
|
694
|
-
thought=devils_crit_main.thought,
|
695
|
-
score=devils_crit_main.self_assessment,
|
696
|
-
)
|
697
|
-
)
|
698
|
-
|
699
|
-
# 4) Contract sub-answers from main decomposition
|
700
|
-
top_decomp_record: Optional[StepRecord] = next(
|
701
|
-
(
|
702
|
-
s
|
703
|
-
for s in reversed(self.steps_history)
|
704
|
-
if s.step_name == StepName.DECOMPOSITION and s.parent_decomp_step_idx is None
|
705
|
-
),
|
706
|
-
None,
|
707
|
-
)
|
708
|
-
if top_decomp_record and top_decomp_record.sub_questions:
|
709
|
-
sub_answers = [(sq.question, sq.answer or UNKNOWN) for sq in top_decomp_record.sub_questions]
|
710
|
-
else:
|
711
|
-
sub_answers = []
|
712
|
-
|
713
|
-
contract_resp = await self._ainvoke_pydantic(
|
714
|
-
messages=self.prompter.contract_prompt(messages, sub_answers),
|
715
|
-
model_cls=ContractQuestionResponse,
|
716
|
-
)
|
717
|
-
contracted_question = contract_resp.question
|
718
|
-
self.steps_history.append(
|
719
|
-
StepRecord(
|
720
|
-
step_name=StepName.CONTRACTED_QUESTION, question=contracted_question, thought=contract_resp.thought
|
721
|
-
)
|
722
|
-
)
|
723
|
-
|
724
|
-
# 5) Attempt direct approach on contracted question
|
725
|
-
contracted_direct = await self._ainvoke_pydantic(
|
726
|
-
messages=self.prompter.contract_direct_prompt(messages, contracted_question),
|
727
|
-
model_cls=RecursiveDecomposeResponse,
|
728
|
-
fallback="No Contracted Direct Answer",
|
729
|
-
)
|
730
|
-
self.steps_history.append(
|
731
|
-
StepRecord(
|
732
|
-
step_name=StepName.CONTRACTED_DIRECT_ANSWER,
|
733
|
-
answer=contracted_direct.final_answer,
|
734
|
-
thought=contracted_direct.thought,
|
735
|
-
)
|
736
|
-
)
|
737
|
-
logger.info(f"[Contracted Direct] final_answer={contracted_direct.final_answer}")
|
738
|
-
|
739
|
-
# 5.1) Critique the contracted direct approach
|
740
|
-
contract_critique = await self._ainvoke_critique(
|
741
|
-
messages=messages, thought=contracted_direct.thought, answer=contracted_direct.final_answer
|
742
|
-
)
|
743
|
-
self.steps_history.append(
|
744
|
-
StepRecord(
|
745
|
-
step_name=StepName.CONTRACT_CRITIQUE,
|
746
|
-
thought=contract_critique.thought,
|
747
|
-
score=contract_critique.self_assessment,
|
748
|
-
)
|
749
|
-
)
|
750
|
-
|
751
|
-
# 6) Ensemble of (Main decomposition, Devil's advocate on main, Contracted direct)
|
752
|
-
ensemble_resp = await self._ainvoke_pydantic(
|
753
|
-
self.prompter.ensemble_prompt(
|
754
|
-
messages=messages,
|
755
|
-
possible_thought_and_answers=[
|
756
|
-
(decomp_resp.thought, decomp_resp.final_answer),
|
757
|
-
(devils_on_main.thought, devils_on_main.final_answer),
|
758
|
-
(contracted_direct.thought, contracted_direct.final_answer),
|
759
|
-
],
|
760
|
-
),
|
761
|
-
EnsembleResponse,
|
762
|
-
)
|
763
|
-
best_approach_answer = ensemble_resp.answer
|
764
|
-
approach_used = StepName.ENSEMBLE
|
765
|
-
self.steps_history.append(StepRecord(step_name=StepName.BEST_APPROACH_DECISION, used=approach_used))
|
766
|
-
logger.info(f"[Best Approach Decision] => {approach_used}")
|
767
|
-
|
768
|
-
# 7) Final answer
|
769
|
-
self.steps_history.append(
|
770
|
-
StepRecord(step_name=StepName.FINAL_ANSWER, answer=best_approach_answer, score=ensemble_resp.confidence)
|
771
|
-
)
|
772
|
-
logger.info(f"[Final Answer] => {best_approach_answer}")
|
773
|
-
|
774
|
-
return best_approach_answer
|
775
|
-
|
776
|
-
def run_pipeline(self, messages: list[BaseMessage]) -> str:
|
777
|
-
"""Synchronous wrapper around arun_pipeline."""
|
778
|
-
return asyncio.run(self.arun_pipeline(messages))
|
779
|
-
|
780
|
-
# ---------------------------------------------------------------------------------
|
781
|
-
# 4.6) Build or export a reasoning graph
|
782
|
-
# ---------------------------------------------------------------------------------
|
783
|
-
|
784
|
-
# def get_reasoning_graph(self, global_id_prefix: str = "AoT"):
|
785
|
-
# """
|
786
|
-
# Constructs a Graph object (from hypothetical `neo4j_extension`)
|
787
|
-
# capturing the pipeline steps, including devil's advocate steps.
|
788
|
-
# """
|
789
|
-
# from neo4j_extension import Graph, Node, Relationship
|
790
|
-
|
791
|
-
# g = Graph()
|
792
|
-
# step_nodes: dict[int, Node] = {}
|
793
|
-
# subq_nodes: dict[str, Node] = {}
|
794
|
-
|
795
|
-
# # Step A: Create nodes for each pipeline step
|
796
|
-
# for i, record in enumerate(self.steps_history):
|
797
|
-
# # We'll skip nested Decomposition steps only if we want to flatten them.
|
798
|
-
# # But let's keep them for clarity.
|
799
|
-
# step_node = Node(
|
800
|
-
# properties=record.as_properties(), labels={record.step_name}, globalId=f"{global_id_prefix}_step_{i}"
|
801
|
-
# )
|
802
|
-
# g.add_node(step_node)
|
803
|
-
# step_nodes[i] = step_node
|
804
|
-
|
805
|
-
# # Step B: Collect sub-questions from each DECOMPOSITION or DEVILS_ADVOCATE
|
806
|
-
# all_sub_questions: dict[str, tuple[int, int, SubQuestionNode]] = {}
|
807
|
-
# for i, record in enumerate(self.steps_history):
|
808
|
-
# if record.sub_questions:
|
809
|
-
# for sq_idx, sq in enumerate(record.sub_questions):
|
810
|
-
# sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
|
811
|
-
# all_sub_questions[sq_id] = (i, sq_idx, sq)
|
812
|
-
|
813
|
-
# for sq_id, (i, sq_idx, sq) in all_sub_questions.items():
|
814
|
-
# n_subq = Node(
|
815
|
-
# properties={
|
816
|
-
# "question": sq.question,
|
817
|
-
# "answer": sq.answer or "",
|
818
|
-
# },
|
819
|
-
# labels={"SubQuestion"},
|
820
|
-
# globalId=sq_id,
|
821
|
-
# )
|
822
|
-
# g.add_node(n_subq)
|
823
|
-
# subq_nodes[sq_id] = n_subq
|
824
|
-
|
825
|
-
# # Step C: Add relationships. We do a simple approach:
|
826
|
-
# # - If StepRecord is DECOMPOSITION or DEVILS_ADVOCATE with sub_questions, link them via SPLIT_INTO.
|
827
|
-
# for i, record in enumerate(self.steps_history):
|
828
|
-
# if record.sub_questions:
|
829
|
-
# start_node = step_nodes[i]
|
830
|
-
# for sq_idx, sq in enumerate(record.sub_questions):
|
831
|
-
# sq_id = f"{global_id_prefix}_decomp_{i}_sub_{sq_idx}"
|
832
|
-
# end_node = subq_nodes[sq_id]
|
833
|
-
# rel = Relationship(
|
834
|
-
# properties={},
|
835
|
-
# rel_type=StepRelation.SPLIT_INTO,
|
836
|
-
# start_node=start_node,
|
837
|
-
# end_node=end_node,
|
838
|
-
# globalId=f"{global_id_prefix}_split_{i}_{sq_idx}",
|
839
|
-
# )
|
840
|
-
# g.add_relationship(rel)
|
841
|
-
# # Also add sub-question dependencies
|
842
|
-
# for dep in sq.depend:
|
843
|
-
# # The same record i -> sub-question subq
|
844
|
-
# if 0 <= dep < len(record.sub_questions):
|
845
|
-
# dep_id = f"{global_id_prefix}_decomp_{i}_sub_{dep}"
|
846
|
-
# if dep_id in subq_nodes:
|
847
|
-
# dep_node = subq_nodes[dep_id]
|
848
|
-
# rel_dep = Relationship(
|
849
|
-
# properties={},
|
850
|
-
# rel_type=StepRelation.DEPEND_ON,
|
851
|
-
# start_node=end_node,
|
852
|
-
# end_node=dep_node,
|
853
|
-
# globalId=f"{global_id_prefix}_dep_{i}_q_{sq_idx}_on_{dep}",
|
854
|
-
# )
|
855
|
-
# g.add_relationship(rel_dep)
|
856
|
-
|
857
|
-
# # Step D: We add PRECEDES relationships in a linear chain for the pipeline steps
|
858
|
-
# for i in range(len(self.steps_history) - 1):
|
859
|
-
# start_node = step_nodes[i]
|
860
|
-
# end_node = step_nodes[i + 1]
|
861
|
-
# rel = Relationship(
|
862
|
-
# properties={},
|
863
|
-
# rel_type=StepRelation.PRECEDES,
|
864
|
-
# start_node=start_node,
|
865
|
-
# end_node=end_node,
|
866
|
-
# globalId=f"{global_id_prefix}_precede_{i}_to_{i + 1}",
|
867
|
-
# )
|
868
|
-
# g.add_relationship(rel)
|
869
|
-
|
870
|
-
# # Step E: CRITIQUES, SELECTS, RESULT_OF can be similarly added:
|
871
|
-
# # We'll do a simple pass:
|
872
|
-
# # If step_name ends with CRITIQUE => it critiques the step before it
|
873
|
-
# for i, record in enumerate(self.steps_history):
|
874
|
-
# if "CRITIQUE" in record.step_name:
|
875
|
-
# # Let it point to the preceding step
|
876
|
-
# if i > 0:
|
877
|
-
# start_node = step_nodes[i]
|
878
|
-
# end_node = step_nodes[i - 1]
|
879
|
-
# rel = Relationship(
|
880
|
-
# properties={},
|
881
|
-
# rel_type=StepRelation.CRITIQUES,
|
882
|
-
# start_node=start_node,
|
883
|
-
# end_node=end_node,
|
884
|
-
# globalId=f"{global_id_prefix}_crit_{i}",
|
885
|
-
# )
|
886
|
-
# g.add_relationship(rel)
|
887
|
-
|
888
|
-
# # If there's a BEST_APPROACH_DECISION step, link it to the step it uses
|
889
|
-
# best_decision_idx = None
|
890
|
-
# used_step_idx = None
|
891
|
-
# for i, record in enumerate(self.steps_history):
|
892
|
-
# if record.step_name == StepName.BEST_APPROACH_DECISION and record.used:
|
893
|
-
# best_decision_idx = i
|
894
|
-
# # find the step with that name
|
895
|
-
# used_step_idx = next((j for j in step_nodes if self.steps_history[j].step_name == record.used), None)
|
896
|
-
# if used_step_idx is not None:
|
897
|
-
# rel = Relationship(
|
898
|
-
# properties={},
|
899
|
-
# rel_type=StepRelation.SELECTS,
|
900
|
-
# start_node=step_nodes[i],
|
901
|
-
# end_node=step_nodes[used_step_idx],
|
902
|
-
# globalId=f"{global_id_prefix}_select_{i}_use_{used_step_idx}",
|
903
|
-
# )
|
904
|
-
# g.add_relationship(rel)
|
905
|
-
|
906
|
-
# # And link the final answer to the best approach
|
907
|
-
# final_answer_idx = next(
|
908
|
-
# (i for i, r in enumerate(self.steps_history) if r.step_name == StepName.FINAL_ANSWER), None
|
909
|
-
# )
|
910
|
-
# if final_answer_idx is not None and best_decision_idx is not None:
|
911
|
-
# rel = Relationship(
|
912
|
-
# properties={},
|
913
|
-
# rel_type=StepRelation.RESULT_OF,
|
914
|
-
# start_node=step_nodes[final_answer_idx],
|
915
|
-
# end_node=step_nodes[best_decision_idx],
|
916
|
-
# globalId=f"{global_id_prefix}_final_{final_answer_idx}_resultof_{best_decision_idx}",
|
917
|
-
# )
|
918
|
-
# g.add_relationship(rel)
|
919
|
-
|
920
|
-
# return g
|
921
|
-
|
922
|
-
|
923
|
-
# ---------------------------------------------------------------------------------
|
924
|
-
# 5) AoTStrategy class that uses the pipeline
|
925
|
-
# ---------------------------------------------------------------------------------
|
926
|
-
|
927
|
-
|
928
|
-
@dataclass
|
929
|
-
class AoTStrategy(BaseStrategy):
|
930
|
-
"""
|
931
|
-
Strategy using AoTPipeline with a reasoning graph and deeper devil's advocate.
|
932
|
-
"""
|
933
|
-
|
934
|
-
pipeline: AoTPipeline
|
935
|
-
|
936
|
-
async def ainvoke(self, messages: LanguageModelInput) -> str:
|
937
|
-
"""Asynchronously run the pipeline with the given messages."""
|
938
|
-
# Convert your custom input to list[BaseMessage] as needed:
|
939
|
-
msgs = self.pipeline.chatterer.client._convert_input(messages).to_messages() # type: ignore
|
940
|
-
return await self.pipeline.arun_pipeline(msgs)
|
941
|
-
|
942
|
-
def invoke(self, messages: LanguageModelInput) -> str:
|
943
|
-
"""Synchronously run the pipeline with the given messages."""
|
944
|
-
msgs = self.pipeline.chatterer.client._convert_input(messages).to_messages() # type: ignore
|
945
|
-
return self.pipeline.run_pipeline(msgs)
|
946
|
-
|
947
|
-
# def get_reasoning_graph(self):
|
948
|
-
# """Return the AoT reasoning graph from the pipeline’s steps history."""
|
949
|
-
# return self.pipeline.get_reasoning_graph(global_id_prefix="AoT")
|
950
|
-
|
951
|
-
|
952
|
-
# ---------------------------------------------------------------------------------
|
953
|
-
# Example usage (pseudo-code)
|
954
|
-
# ---------------------------------------------------------------------------------
|
955
|
-
# if __name__ == "__main__":
|
956
|
-
# from neo4j_extension import Neo4jConnection # or your actual DB connector
|
957
|
-
|
958
|
-
# # You would create a Chatterer with your chosen LLM backend (OpenAI, etc.)
|
959
|
-
# chatterer = Chatterer.openai() # pseudo-code
|
960
|
-
# pipeline = AoTPipeline(chatterer=chatterer, max_depth=3)
|
961
|
-
# strategy = AoTStrategy(pipeline=pipeline)
|
962
|
-
|
963
|
-
# question = "Solve 5.9 = 5.11 - x. Also compare 9.11 and 9.9."
|
964
|
-
# answer = strategy.invoke(question)
|
965
|
-
# print("Final Answer:", answer)
|
966
|
-
|
967
|
-
# # Build the reasoning graph
|
968
|
-
# graph = strategy.get_reasoning_graph()
|
969
|
-
# print(f"\nGraph has {len(graph.nodes)} nodes and {len(graph.relationships)} relationships.")
|
970
|
-
|
971
|
-
# # Optionally store in Neo4j
|
972
|
-
# with Neo4jConnection() as conn:
|
973
|
-
# conn.clear_all()
|
974
|
-
# conn.upsert_graph(graph)
|
975
|
-
# print("Graph stored in Neo4j.")
|
chatterer/strategies/base.py
DELETED
@@ -1,14 +0,0 @@
|
|
1
|
-
from abc import ABC, abstractmethod
|
2
|
-
|
3
|
-
from ..language_model import LanguageModelInput
|
4
|
-
|
5
|
-
|
6
|
-
class BaseStrategy(ABC):
|
7
|
-
@abstractmethod
|
8
|
-
def invoke(self, messages: LanguageModelInput) -> str:
|
9
|
-
"""
|
10
|
-
Invoke the strategy with the given messages.
|
11
|
-
|
12
|
-
messages: List of messages to be passed to the strategy.
|
13
|
-
e.g. [{"role": "user", "content": "What is the meaning of life?"}]
|
14
|
-
"""
|
File without changes
|
File without changes
|
File without changes
|