contextagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentz/agent/base.py +262 -0
- agentz/artifacts/__init__.py +5 -0
- agentz/artifacts/artifact_writer.py +538 -0
- agentz/artifacts/reporter.py +235 -0
- agentz/artifacts/terminal_writer.py +100 -0
- agentz/context/__init__.py +6 -0
- agentz/context/context.py +91 -0
- agentz/context/conversation.py +205 -0
- agentz/context/data_store.py +208 -0
- agentz/llm/llm_setup.py +156 -0
- agentz/mcp/manager.py +142 -0
- agentz/mcp/patches.py +88 -0
- agentz/mcp/servers/chrome_devtools/server.py +14 -0
- agentz/profiles/base.py +108 -0
- agentz/profiles/data/data_analysis.py +38 -0
- agentz/profiles/data/data_loader.py +35 -0
- agentz/profiles/data/evaluation.py +43 -0
- agentz/profiles/data/model_training.py +47 -0
- agentz/profiles/data/preprocessing.py +47 -0
- agentz/profiles/data/visualization.py +47 -0
- agentz/profiles/manager/evaluate.py +51 -0
- agentz/profiles/manager/memory.py +62 -0
- agentz/profiles/manager/observe.py +48 -0
- agentz/profiles/manager/routing.py +66 -0
- agentz/profiles/manager/writer.py +51 -0
- agentz/profiles/mcp/browser.py +21 -0
- agentz/profiles/mcp/chrome.py +21 -0
- agentz/profiles/mcp/notion.py +21 -0
- agentz/runner/__init__.py +74 -0
- agentz/runner/base.py +28 -0
- agentz/runner/executor.py +320 -0
- agentz/runner/hooks.py +110 -0
- agentz/runner/iteration.py +142 -0
- agentz/runner/patterns.py +215 -0
- agentz/runner/tracker.py +188 -0
- agentz/runner/utils.py +45 -0
- agentz/runner/workflow.py +250 -0
- agentz/tools/__init__.py +20 -0
- agentz/tools/data_tools/__init__.py +17 -0
- agentz/tools/data_tools/data_analysis.py +152 -0
- agentz/tools/data_tools/data_loading.py +92 -0
- agentz/tools/data_tools/evaluation.py +175 -0
- agentz/tools/data_tools/helpers.py +120 -0
- agentz/tools/data_tools/model_training.py +192 -0
- agentz/tools/data_tools/preprocessing.py +229 -0
- agentz/tools/data_tools/visualization.py +281 -0
- agentz/utils/__init__.py +69 -0
- agentz/utils/config.py +708 -0
- agentz/utils/helpers.py +10 -0
- agentz/utils/parsers.py +142 -0
- agentz/utils/printer.py +539 -0
- contextagent-0.1.0.dist-info/METADATA +269 -0
- contextagent-0.1.0.dist-info/RECORD +66 -0
- contextagent-0.1.0.dist-info/WHEEL +5 -0
- contextagent-0.1.0.dist-info/licenses/LICENSE +21 -0
- contextagent-0.1.0.dist-info/top_level.txt +2 -0
- pipelines/base.py +972 -0
- pipelines/data_scientist.py +97 -0
- pipelines/data_scientist_memory.py +151 -0
- pipelines/experience_learner.py +0 -0
- pipelines/prompt_generator.py +0 -0
- pipelines/simple.py +78 -0
- pipelines/simple_browser.py +145 -0
- pipelines/simple_chrome.py +75 -0
- pipelines/simple_notion.py +103 -0
- pipelines/tool_builder.py +0 -0
@@ -0,0 +1,235 @@
|
|
1
|
+
"""Shared data models and RunReporter facade for pipeline runs."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import threading
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, Optional
|
9
|
+
|
10
|
+
from rich.console import Console
|
11
|
+
|
12
|
+
from agentz.artifacts.artifact_writer import ArtifactWriter
|
13
|
+
from agentz.artifacts.terminal_writer import TerminalWriter
|
14
|
+
|
15
|
+
|
16
|
+
@dataclass
class PanelRecord:
    """Snapshot of a single panel emitted while a run is in progress.

    All fields are required and positional order is part of the generated
    constructor signature, so the declaration order below must not change.
    """

    # Heading shown on the panel; ``None`` for an untitled panel.
    title: Optional[str]
    # Body text of the panel.
    content: str
    # Border style name; ``None`` leaves the choice to whoever renders it.
    border_style: Optional[str]
    # Iteration number the panel was recorded for, if any.
    iteration: Optional[int]
    # Identifier of the group the panel was recorded under, if any.
    group_id: Optional[str]
    # Timestamp string captured at the moment the panel was recorded.
    recorded_at: str
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
class AgentStepRecord:
    """Telemetry describing one execution of an agent.

    The first five fields are required; the remaining ones carry defaults
    that describe a step which has started but not yet finished.
    """

    # Name of the agent that ran.
    agent_name: str
    # Name of the span associated with the step.
    span_name: str
    # Iteration number the step belongs to, if any.
    iteration: Optional[int]
    # Identifier of the group the step belongs to, if any.
    group_id: Optional[str]
    # Timestamp string for when the step began.
    started_at: str
    # Timestamp string for when the step ended; ``None`` by default.
    finished_at: Optional[str] = None
    # Wall-clock duration in seconds; ``None`` by default.
    duration_seconds: Optional[float] = None
    # Free-form status label; a new record starts out as "running".
    status: str = "running"
    # Error description, if one was reported; ``None`` by default.
    error: Optional[str] = None
|
41
|
+
|
42
|
+
|
43
|
+
class RunReporter:
    """Facade combining terminal display and artifact persistence.

    Every public method forwards to an ``ArtifactWriter`` and/or a
    ``TerminalWriter``. All forwarding happens while holding one re-entrant
    lock so that calls arriving from different threads are serialised; this
    includes the read-side helpers ``print_terminal_report`` and
    ``ensure_started``, which would otherwise observe the writers while
    another thread is mutating them.
    """

    def __init__(
        self,
        *,
        base_dir: Path,
        pipeline_slug: str,
        workflow_name: str,
        experiment_id: str,
        console: Optional[Console] = None,
    ) -> None:
        """Store run identifiers, derive output paths and build both writers.

        Args:
            base_dir: Root directory under which run directories are created.
            pipeline_slug: Directory-name component identifying the pipeline.
            workflow_name: Workflow name forwarded to the artifact writer.
            experiment_id: Directory-name component identifying this run.
            console: Optional rich console forwarded to the terminal writer.
        """
        self.base_dir = base_dir
        self.pipeline_slug = pipeline_slug
        self.workflow_name = workflow_name
        self.experiment_id = experiment_id
        self.console = console

        self.run_dir = base_dir / pipeline_slug / experiment_id
        self.terminal_md_path = self.run_dir / "terminal_log.md"
        self.terminal_html_path = self.run_dir / "terminal_log.html"
        self.final_report_md_path = self.run_dir / "final_report.md"
        self.final_report_html_path = self.run_dir / "final_report.html"

        # Re-entrant so a delegate may call back into the reporter on the
        # same thread without deadlocking.
        self._lock = threading.RLock()

        # Delegate to specialized reporters
        self._artifact_writer = ArtifactWriter(
            base_dir=base_dir,
            pipeline_slug=pipeline_slug,
            workflow_name=workflow_name,
            experiment_id=experiment_id,
        )
        self._terminal_writer = TerminalWriter(
            run_dir=self.run_dir,
            console=console,
        )

    # ------------------------------------------------------------------ basics

    def start(self, config: Any) -> None:
        """Prepare filesystem layout and capture start metadata."""
        with self._lock:
            self._artifact_writer.start(config)

    def set_final_result(self, result: Any) -> None:
        """Store pipeline result for later persistence."""
        with self._lock:
            self._artifact_writer.set_final_result(result)

    # ----------------------------------------------------------------- logging

    def record_status_update(
        self,
        *,
        item_id: str,
        content: str,
        is_done: bool,
        title: Optional[str],
        border_style: Optional[str],
        group_id: Optional[str],
    ) -> None:
        """Currently unused; maintained for interface compatibility."""
        with self._lock:
            self._artifact_writer.record_status_update(
                item_id=item_id,
                content=content,
                is_done=is_done,
                title=title,
                border_style=border_style,
                group_id=group_id,
            )

    def record_group_start(
        self,
        *,
        group_id: str,
        title: Optional[str],
        border_style: Optional[str],
        iteration: Optional[int] = None,
    ) -> None:
        """Record the start of an iteration/group."""
        with self._lock:
            self._artifact_writer.record_group_start(
                group_id=group_id,
                title=title,
                border_style=border_style,
                iteration=iteration,
            )

    def record_group_end(
        self,
        *,
        group_id: str,
        is_done: bool = True,
        title: Optional[str] = None,
    ) -> None:
        """Record the end of an iteration/group."""
        with self._lock:
            self._artifact_writer.record_group_end(
                group_id=group_id,
                is_done=is_done,
                title=title,
            )

    def record_agent_step_start(
        self,
        *,
        step_id: str,
        agent_name: str,
        span_name: str,
        iteration: Optional[int],
        group_id: Optional[str],
        printer_title: Optional[str],
    ) -> None:
        """Capture metadata when an agent step begins."""
        with self._lock:
            self._artifact_writer.record_agent_step_start(
                step_id=step_id,
                agent_name=agent_name,
                span_name=span_name,
                iteration=iteration,
                group_id=group_id,
                printer_title=printer_title,
            )

    def record_agent_step_end(
        self,
        *,
        step_id: str,
        status: str,
        duration_seconds: float,
        error: Optional[str] = None,
    ) -> None:
        """Update agent step telemetry on completion."""
        with self._lock:
            self._artifact_writer.record_agent_step_end(
                step_id=step_id,
                status=status,
                duration_seconds=duration_seconds,
                error=error,
            )

    def record_panel(
        self,
        *,
        title: str,
        content: str,
        border_style: Optional[str],
        iteration: Optional[int],
        group_id: Optional[str],
    ) -> None:
        """Persist panel meta for terminal & HTML artefacts.

        The panel is handed to the artifact writer as keyword arguments and
        to the terminal writer as a timestamped ``PanelRecord``.
        """
        # Resolved per call, as in the original code; done before taking the
        # lock so the import machinery never runs while the lock is held.
        from agentz.artifacts.artifact_writer import _utc_timestamp

        with self._lock:
            record = PanelRecord(
                title=title,
                content=content,
                border_style=border_style,
                iteration=iteration,
                group_id=group_id,
                recorded_at=_utc_timestamp(),
            )
            # Record in both reporters
            self._artifact_writer.record_panel(
                title=title,
                content=content,
                border_style=border_style,
                iteration=iteration,
                group_id=group_id,
            )
            self._terminal_writer.record_panel(record)

    # ------------------------------------------------------------- finalisation

    def finalize(self) -> None:
        """Persist markdown + HTML artefacts."""
        with self._lock:
            self._artifact_writer.finalize()

    # ---------------------------------------------------------- terminal flush

    def print_terminal_report(self) -> None:
        """Stream captured panel content back to the console."""
        # Held so the replay does not iterate the panel list while another
        # thread is appending to it through ``record_panel``.
        with self._lock:
            self._terminal_writer.print_terminal_report()

    # ----------------------------------------------------------------- helpers

    def ensure_started(self) -> None:
        """Raise if reporter not initialised."""
        # Same lock discipline as every other delegate call on this class.
        with self._lock:
            self._artifact_writer.ensure_started()
|
235
|
+
|
@@ -0,0 +1,100 @@
|
|
1
|
+
"""TerminalWriter handles real-time console output and panel display."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import re
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import TYPE_CHECKING, List, Optional
|
8
|
+
|
9
|
+
from rich.console import Console
|
10
|
+
from rich.markdown import Markdown
|
11
|
+
from rich.panel import Panel
|
12
|
+
from rich.text import Text
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from agentz.artifacts.reporter import PanelRecord
|
16
|
+
|
17
|
+
|
18
|
+
class TerminalWriter:
    """Buffers panel records during a run and replays them on a console."""

    def __init__(
        self,
        *,
        run_dir: Path,
        console: Optional[Console] = None,
    ) -> None:
        """Remember the run directory and console; start with no panels."""
        self.console = console
        self.run_dir = run_dir
        self._panels: List[PanelRecord] = []

    def record_panel(self, record: PanelRecord) -> None:
        """Append ``record`` to the replay buffer."""
        self._panels.append(record)

    def print_terminal_report(self) -> None:
        """Replay the selected panels on the console.

        Does nothing when there is no console or nothing was recorded.
        Otherwise prints a line naming the run directory followed by one
        rich ``Panel`` per selected record.
        """
        if not self.console or not self._panels:
            return

        selected = self._select_terminal_panels()
        if not selected:
            return

        banner = Text(f"Run artefacts saved to {self.run_dir}", style="bold cyan")
        self.console.print(banner)

        for entry in selected:
            body = self._panel_renderable(entry.content)
            self.console.print(
                Panel(
                    body,
                    title=entry.title,
                    # Untitled border styles fall back to cyan.
                    border_style=entry.border_style or "cyan",
                    padding=(1, 2),
                )
            )

    # ----------------------------------------------------------------- helpers

    def _select_terminal_panels(self) -> List[PanelRecord]:
        """Pick the panels to replay: final ones, else just the latest."""
        finals = []
        for candidate in self._panels:
            if self._is_final_panel(candidate):
                finals.append(candidate)
        # An empty list is falsy, so this falls back to the newest panel
        # (or to an empty list when nothing was recorded).
        return finals or self._panels[-1:]

    @staticmethod
    def _is_final_panel(record: PanelRecord) -> bool:
        """Return True when the group id or title marks a final panel.

        A panel counts as final when its group id contains "final", or its
        title contains "final" or "writer" (all compared case-insensitively).
        """
        group = (record.group_id or "").lower()
        if "final" in group:
            return True
        heading = (record.title or "").lower()
        return "final" in heading or "writer" in heading

    def _panel_renderable(self, content: str):
        """Wrap ``content`` in ``Markdown`` if it looks like it, else ``Text``."""
        wrapper = Markdown if self._looks_like_markdown(content) else Text
        return wrapper(content)

    @staticmethod
    def _looks_like_markdown(content: str) -> bool:
        """Cheap regex heuristic: does ``content`` contain Markdown syntax?"""
        if not content:
            return False

        markers = (
            r"^#{1,6}\s",  # headings
            r"^\s*[-*+]\s+\S",  # bullet lists
            r"^\s*\d+\.\s+\S",  # numbered lists
            r"`{1,3}.+?`{1,3}",  # inline or fenced code
            r"\*\*.+\*\*",  # bold text
            r"_{1,2}.+_{1,2}",  # italic/underline emphasis
        )
        for marker in markers:
            if re.search(marker, content, re.MULTILINE):
                return True
        return False
|
100
|
+
|
@@ -0,0 +1,91 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Dict, List, Optional, Tuple, Union
|
4
|
+
|
5
|
+
from agentz.context.conversation import BaseIterationRecord, ConversationState, create_conversation_state
|
6
|
+
from agentz.profiles.base import Profile, load_all_profiles
|
7
|
+
|
8
|
+
|
9
|
+
class Context:
    """Owns the conversation state and hands out iteration group ids."""

    # Group id building blocks: per-iteration ids are "<prefix>-<index>".
    ITERATION_GROUP_PREFIX = "iter"
    FINAL_GROUP_ID = "iter-final"

    def __init__(
        self,
        components: Union[ConversationState, List[str]]
    ) -> None:
        """Build the context from a ready state or a list of component names.

        Args:
            components: Either an existing ``ConversationState`` (used
                as-is), or a list of component names:
                - "profiles": populate ``self.profiles`` via
                  ``load_all_profiles()``
                - "states": build the state via
                  ``create_conversation_state()``; needs "profiles" too

        Raises:
            ValueError: "states" was requested without "profiles".
            TypeError: ``components`` is neither a state nor a list.

        Examples:
            # From component names
            context = Context(["profiles", "states"])

            # From an existing state
            state = create_conversation_state(profiles)
            context = Context(state)
        """
        self.profiles: Optional[Dict[str, Profile]] = None

        # A pre-built state is adopted directly.
        if isinstance(components, ConversationState):
            self._state = components
            return

        if not isinstance(components, list):
            raise TypeError(f"components must be ConversationState or list, got {type(components)}")

        wants_profiles = "profiles" in components
        if wants_profiles:
            self.profiles = load_all_profiles()

        if "states" in components:
            if self.profiles is None:
                raise ValueError("'states' requires 'profiles' to be initialized first. Include 'profiles' in the component list.")
            self._state = create_conversation_state(self.profiles)
        elif not hasattr(self, '_state'):
            # No state was asked for: fall back to a bare one.
            self._state = ConversationState()

    @property
    def state(self) -> ConversationState:
        """The conversation state this context wraps."""
        return self._state

    def begin_iteration(self) -> Tuple[BaseIterationRecord, str]:
        """Open a new iteration on the state and name its group.

        The state's timer is started here if it has not been started yet.

        Returns:
            ``(iteration_record, group_id)`` with ``group_id`` formatted as
            "iter-{index}".
        """
        state = self._state
        if state.started_at is None:
            state.start_timer()

        record = state.begin_iteration()
        return record, f"{self.ITERATION_GROUP_PREFIX}-{record.index}"

    def mark_iteration_complete(self) -> None:
        """Flag the state's current iteration as complete."""
        self._state.mark_iteration_complete()

    def begin_final_report(self) -> Tuple[None, str]:
        """Enter the final-report phase.

        Returns:
            ``(None, group_id)`` where ``group_id`` is ``FINAL_GROUP_ID``.
        """
        return None, self.FINAL_GROUP_ID

    def mark_final_complete(self) -> None:
        """Finish the final-report phase; intentionally a no-op."""
        return None
|
@@ -0,0 +1,205 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import time
|
4
|
+
from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple, Type
|
5
|
+
from pydantic import BaseModel, Field, PrivateAttr, ValidationError, create_model
|
6
|
+
from agentz.profiles.base import Profile, ToolAgentOutput
|
7
|
+
|
8
|
+
|
9
|
+
class BaseIterationRecord(BaseModel):
    """State captured for a single iteration of the research loop."""

    index: int
    observation: Optional[str] = None
    tools: List[ToolAgentOutput] = Field(default_factory=list)
    payloads: List[Any] = Field(default_factory=list)
    status: str = Field(default="pending", description="Iteration status: pending or complete")
    summarized: bool = Field(default=False, description="Whether this iteration has been summarised")
    # Schema(s) accepted by ``add_payload``; assigned on dynamically created
    # subclasses. Holds a union of models, or a single model class when only
    # one schema is registered.
    _output_union: ClassVar[Optional[Type[BaseModel]]] = None  # type: ignore[var-annotated]

    def mark_complete(self) -> None:
        """Set the status to "complete"."""
        self.status = "complete"

    def is_complete(self) -> bool:
        """Return True once the status is "complete"."""
        return self.status == "complete"

    def mark_summarized(self) -> None:
        """Flag this iteration as already folded into a summary."""
        self.summarized = True

    def history_block(self) -> str:
        """Render this iteration as a formatted history block for prompts.

        The block starts with an ``[ITERATION n]`` header, followed by
        optional ``<thought>``, ``<payloads>`` and ``<findings>`` sections,
        separated by blank lines.
        """
        lines: List[str] = [f"[ITERATION {self.index}]"]

        if self.observation:
            lines.append(f"<thought>\n{self.observation}\n</thought>")

        # Render structured payloads generically
        if self.payloads:
            payload_lines = [
                payload.model_dump_json(indent=2) if isinstance(payload, BaseModel) else str(payload)
                for payload in self.payloads
            ]
            if payload_lines:
                joined_payloads = "\n".join(payload_lines)
                lines.append(f"<payloads>\n{joined_payloads}\n</payloads>")

        # Render tool execution results
        if self.tools:
            joined_findings = "\n".join(tool.output for tool in self.tools)
            lines.append(f"<findings>\n{joined_findings}\n</findings>")

        return "\n\n".join(lines).strip()

    def add_payload(self, value: Any) -> BaseModel:
        """Attach a payload to this iteration, coercing it when necessary.

        Args:
            value: A pydantic model instance or a ``dict``. A model instance
                matching one of the registered union members (or any model
                instance when no union is registered) is stored unchanged.
                Anything else is validated against each registered schema in
                order and the first successful result is stored.

        Returns:
            The payload object that was appended to ``self.payloads``.

        Raises:
            TypeError: ``value`` is neither a model nor a dict, or coercion is
                needed but no schema is registered.
            ValidationError: ``value`` matched none of the registered schemas.
        """
        expected_union = getattr(self.__class__, "_output_union", None)
        union_args: Tuple[Type[BaseModel], ...] = ()
        if expected_union is not None:
            union_args = getattr(expected_union, "__args__", ()) or ()

        # When exactly one schema is registered it is stored as the bare
        # class, which has no ``__args__``. Treat it as a one-member candidate
        # list for coercion so dict payloads validate against it instead of
        # failing with "No output schemas are registered".
        candidates: Tuple[Type[BaseModel], ...] = union_args
        if (
            not candidates
            and isinstance(expected_union, type)
            and issubclass(expected_union, BaseModel)
        ):
            candidates = (expected_union,)

        if isinstance(value, BaseModel):
            payload = value
            if union_args and not isinstance(payload, union_args):
                # Wrong model type for the union: re-validate its data below.
                data = payload.model_dump()
            else:
                self.payloads.append(payload)
                return payload
        elif isinstance(value, dict):
            data = value
        else:
            if union_args:
                raise TypeError(
                    f"Payload type {type(value)!r} is incompatible with expected schemas {union_args}"
                )
            raise TypeError(f"Payload type {type(value)!r} is not supported")

        if not candidates:
            raise TypeError("No output schemas are registered for payload coercion")

        errors: List[ValidationError] = []
        for candidate in candidates:
            try:
                payload = candidate.model_validate(data)
                self.payloads.append(payload)
                return payload
            except ValidationError as exc:
                errors.append(exc)

        # Every candidate rejected the data: surface all collected line errors
        # in a single ValidationError chained to the last failure.
        raise ValidationError.from_exception_data(
            title="Iteration payload validation failed",
            line_errors=[err for exc in errors for err in exc.errors()],
        ) from (errors[-1] if errors else None)
|
92
|
+
|
93
|
+
|
94
|
+
|
95
|
+
class ConversationState(BaseModel):
    """Mutable record of a whole run: iterations, query, summary and timing."""

    iterations: List[BaseIterationRecord] = Field(default_factory=list)
    final_report: Optional[str] = None
    started_at: Optional[float] = None
    complete: bool = False
    summary: Optional[str] = None
    query: Optional[str] = None

    # Model class used to create new iterations. Defaults to the base record
    # so a state constructed directly (``ConversationState()``) can still
    # begin iterations; ``create_conversation_state`` overrides it with a
    # schema-aware subclass.
    _iteration_model: Type[BaseIterationRecord] = PrivateAttr(default=BaseIterationRecord)

    def start_timer(self) -> None:
        """Stamp ``started_at`` with the current ``time.time()`` value."""
        self.started_at = time.time()

    def elapsed_minutes(self) -> float:
        """Return minutes since ``start_timer``; 0.0 if it never ran."""
        if self.started_at is None:
            return 0.0
        return (time.time() - self.started_at) / 60

    def begin_iteration(self) -> BaseIterationRecord:
        """Append and return a new iteration with a 1-based index."""
        iteration = self._iteration_model(index=len(self.iterations) + 1)
        self.iterations.append(iteration)
        return iteration

    @property
    def current_iteration(self) -> BaseIterationRecord:
        """The most recently started iteration.

        Raises:
            ValueError: No iteration has been started.
        """
        if not self.iterations:
            raise ValueError("No iteration has been started yet.")
        return self.iterations[-1]

    def mark_iteration_complete(self) -> None:
        """Mark the current iteration complete."""
        self.current_iteration.mark_complete()

    def mark_research_complete(self) -> None:
        """Mark the whole run and its current iteration complete."""
        self.complete = True
        self.current_iteration.mark_complete()

    def get_history_blocks(self, include_current: bool, only_unsummarized: bool = False) -> str:
        """Join the history blocks of the selected iterations.

        Args:
            include_current: Also include the latest iteration even if it is
                not complete yet.
            only_unsummarized: Skip iterations already marked as summarised.

        Returns:
            The selected blocks separated by blank lines ("" if none).
        """
        latest = self.iterations[-1] if self.iterations else None
        blocks: List[str] = []
        for iteration in self.iterations:
            selected = iteration.is_complete() or (include_current and iteration is latest)
            if not selected:
                continue
            if only_unsummarized and iteration.summarized:
                continue
            # Render once and reuse, rather than rendering to test and again
            # to collect.
            block = iteration.history_block()
            if block:
                blocks.append(block)
        return "\n\n".join(blocks).strip()

    def iteration_history(self, include_current: bool = False) -> str:
        """History of completed iterations, optionally with the current one."""
        return self.get_history_blocks(include_current, only_unsummarized=False)

    def unsummarized_history(self, include_current: bool = True) -> str:
        """History restricted to iterations not yet marked as summarised."""
        return self.get_history_blocks(include_current, only_unsummarized=True)

    def set_query(self, query: str) -> None:
        """Store the query for this run."""
        self.query = query

    def record_payload(self, payload: Any) -> BaseModel:
        """Attach a structured payload to the current iteration.

        Starts a first iteration if none exists yet.
        """
        iteration = self.current_iteration if self.iterations else self.begin_iteration()
        return iteration.add_payload(payload)

    def all_findings(self) -> List[str]:
        """Collect every tool output across all iterations, in order."""
        findings: List[str] = []
        for iteration in self.iterations:
            findings.extend(tool.output for tool in iteration.tools)
        return findings

    def findings_text(self) -> str:
        """All findings joined by blank lines ("" when there are none)."""
        findings = self.all_findings()
        return "\n\n".join(findings).strip() if findings else ""

    def update_summary(self, summary: str) -> None:
        """Store ``summary`` and mark every iteration as summarised."""
        self.summary = summary
        for iteration in self.iterations:
            iteration.mark_summarized()
|
169
|
+
|
170
|
+
|
171
|
+
def create_conversation_state(profiles: Dict[str, Profile]) -> "ConversationState":
    """Build a ``ConversationState`` whose iterations know the profile schemas.

    Each profile's ``output_schema`` is collected when it is a ``BaseModel``
    subclass, de-duplicated by ``module.qualname`` in first-seen order. If
    none qualify, ``ToolAgentOutput`` is used. The schemas are combined with
    ``|`` and used as the element type of ``payloads`` on a dynamically
    created ``IterationRecord`` model, which the returned state then uses to
    create its iterations.
    """
    # Insertion-ordered dict doubles as the "seen" set and the ordered result.
    unique_schemas: Dict[str, Type[BaseModel]] = {}
    for profile in profiles.values():
        schema = getattr(profile, "output_schema", None)
        if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
            continue
        unique_schemas.setdefault(f"{schema.__module__}.{schema.__qualname__}", schema)

    schemas: List[Type[BaseModel]] = list(unique_schemas.values())
    if not schemas:
        schemas = [ToolAgentOutput]

    # Fold the schemas into one union; a single schema stays a bare class.
    payload_union: Type[BaseModel] = schemas[0]
    for extra in schemas[1:]:
        payload_union = payload_union | extra  # type: ignore[operator]

    record_model: Type[BaseIterationRecord] = create_model(
        "IterationRecord",
        __base__=BaseIterationRecord,
        __module__=BaseIterationRecord.__module__,
        payloads=(List[payload_union], Field(default_factory=list)),
    )
    record_model._output_union = payload_union  # type: ignore[attr-defined]

    state = ConversationState()
    # Set through ``object.__setattr__`` to bypass the model's own attribute
    # handling, exactly as the original code does.
    object.__setattr__(state, "_iteration_model", record_model)
    return state
|