palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
palimpzest/agents/search_agents.py (added, +637 lines)

@@ -0,0 +1,637 @@
+import json
+import textwrap
+import time
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Any
+
+from rich.console import Group
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.rule import Rule
+from rich.text import Text
+
+if TYPE_CHECKING:
+    import PIL.Image
+
+from smolagents.agent_types import handle_agent_output_types
+from smolagents.agents import (
+    ActionOutput,
+    CodeAgent,
+    FinalAnswerPromptTemplate,
+    ManagedAgentPromptTemplate,
+    PlanningPromptTemplate,
+    PromptTemplates,
+    RunResult,
+    ToolOutput,
+    populate_template,
+)
+from smolagents.local_python_executor import fix_final_answer_code
+from smolagents.memory import (
+    ActionStep,
+    FinalAnswerStep,
+    PlanningStep,
+    SystemPromptStep,
+    TaskStep,
+    Timing,
+    TokenUsage,
+    ToolCall,
+)
+from smolagents.models import (
+    CODEAGENT_RESPONSE_FORMAT,
+    ChatMessage,
+    ChatMessageStreamDelta,
+    MessageRole,
+    agglomerate_stream_deltas,
+)
+from smolagents.monitoring import YELLOW_HEX, LogLevel
+from smolagents.utils import (
+    AgentError,
+    AgentExecutionError,
+    AgentGenerationError,
+    AgentMaxStepsError,
+    AgentParsingError,
+    extract_code_from_text,
+    parse_code_blobs,
+    truncate_content,
+)
+
+from palimpzest.prompts import (
+    CODE_AGENT_SYSTEM_PROMPT,
+    DATA_DISCOVERY_AGENT_INITIAL_PLAN_PROMPT,
+    DATA_DISCOVERY_AGENT_REPORT_PROMPT,
+    DATA_DISCOVERY_AGENT_TASK_PROMPT,
+    DATA_DISCOVERY_AGENT_UPDATE_PLAN_POST_MESSAGES_PROMPT,
+    DATA_DISCOVERY_AGENT_UPDATE_PLAN_PRE_MESSAGES_PROMPT,
+    FINAL_ANSWER_POST_MESSAGES_PROMPT,
+    FINAL_ANSWER_PRE_MESSAGES_PROMPT,
+)
+
+
+# TODO: make this use memory the way you want
+class PZBaseAgent(CodeAgent):
+    def __init__(self, run_id: str, context_description: str, *args, **kwargs):
+        # memory_config = {
+        #     "vector_store": {
+        #         "provider": "chroma",
+        #         "config": {
+        #             "collection_name": f"palimpzest-memory-{self.__class__.__name__}",
+        #             "path": "./pz-chroma",
+        #         }
+        #     }
+        # }
+        # self.pz_memory = Memory.from_config(memory_config)
+        self.run_id = run_id
+        self.context_description = context_description
+        super().__init__(*args, **kwargs)
+
+    def write_memory_to_messages(
+        self,
+        summary_mode: bool = False,
+    ) -> list[ChatMessage]:
+        """
+        Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
+        that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc) to help
+        the LLM.
+        """
+        messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
+        for memory_step in self.memory.steps:
+            messages.extend(memory_step.to_messages(summary_mode=summary_mode))
+        return messages
+
+    def _generate_planning_step(
+        self, task, is_first_step: bool, step: int
+    ) -> Generator[ChatMessageStreamDelta | PlanningStep]:
+        start_time = time.time()
+        if is_first_step:
+            input_messages = [
+                ChatMessage(
+                    role=MessageRole.USER,
+                    content=[
+                        {
+                            "type": "text",
+                            "text": populate_template(
+                                self.prompt_templates["planning"]["initial_plan"],
+                                variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents, "context_description": self.context_description},
+                            ),
+                        }
+                    ],
+                )
+            ]
+            if self.stream_outputs and hasattr(self.model, "generate_stream"):
+                plan_message_content = ""
+                output_stream = self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"])  # type: ignore
+                input_tokens, output_tokens = 0, 0
+                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
+                    for event in output_stream:
+                        if event.content is not None:
+                            plan_message_content += event.content
+                            live.update(Markdown(plan_message_content))
+                            if event.token_usage:
+                                output_tokens += event.token_usage.output_tokens
+                                input_tokens = event.token_usage.input_tokens
+                        yield event
+            else:
+                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
+                plan_message_content = plan_message.content
+                input_tokens, output_tokens = (
+                    (
+                        plan_message.token_usage.input_tokens,
+                        plan_message.token_usage.output_tokens,
+                    )
+                    if plan_message.token_usage
+                    else (None, None)
+                )
+            plan = textwrap.dedent(
+                f"""Here are the facts I know and the plan of action that I will follow to solve the task:\n```\n{plan_message_content}\n```"""
+            )
+        else:
+            # Summary mode removes the system prompt and previous planning messages output by the model.
+            # Removing previous planning messages avoids influencing too much the new plan.
+            memory_messages = self.write_memory_to_messages(summary_mode=True)
+            plan_update_pre = ChatMessage(
+                role=MessageRole.SYSTEM,
+                content=[
+                    {
+                        "type": "text",
+                        "text": populate_template(
+                            self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task, "context_description": self.context_description}
+                        ),
+                    }
+                ],
+            )
+            plan_update_post = ChatMessage(
+                role=MessageRole.USER,
+                content=[
+                    {
+                        "type": "text",
+                        "text": populate_template(
+                            self.prompt_templates["planning"]["update_plan_post_messages"],
+                            variables={
+                                "task": task,
+                                "tools": self.tools,
+                                "managed_agents": self.managed_agents,
+                                "remaining_steps": (self.max_steps - step),
+                                "context_description": self.context_description,
+                            },
+                        ),
+                    }
+                ],
+            )
+            input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
+            if self.stream_outputs and hasattr(self.model, "generate_stream"):
+                plan_message_content = ""
+                input_tokens, output_tokens = 0, 0
+                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
+                    for event in self.model.generate_stream(
+                        input_messages,
+                        stop_sequences=["<end_plan>"],
+                    ):  # type: ignore
+                        if event.content is not None:
+                            plan_message_content += event.content
+                            live.update(Markdown(plan_message_content))
+                            if event.token_usage:
+                                output_tokens += event.token_usage.output_tokens
+                                input_tokens = event.token_usage.input_tokens
+                        yield event
+            else:
+                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
+                plan_message_content = plan_message.content
+                if plan_message.token_usage is not None:
+                    input_tokens, output_tokens = (
+                        plan_message.token_usage.input_tokens,
+                        plan_message.token_usage.output_tokens,
+                    )
+            plan = textwrap.dedent(
+                f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere are the facts I know and my new/updated plan of action to solve the task:\n```\n{plan_message_content}\n```"""
+            )
+        log_headline = "Initial plan" if is_first_step else "Updated plan"
+        self.logger.log(Rule(f"[bold]{log_headline}", style="orange"), Text(plan), level=LogLevel.INFO)
+        yield PlanningStep(
+            model_input_messages=input_messages,
+            plan=plan,
+            model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
+            token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
+            timing=Timing(start_time=start_time, end_time=time.time()),
+        )
+
+    # def _curate_messages(self, input_messages: list[ChatMessage]) -> list[ChatMessage]:
+    #     """
+    #     Try returning:
+    #     - System Prompt + task
+    #     - Current Plan
+    #     - Summary of previous conversation
+    #     """
+    #     # initialize with the system prompt & original task
+    #     curated_messages = input_messages[:2]
+
+    #     # find the last planning step message
+    #     idx = len(self.memory.steps) - 1
+    #     while idx > -1:
+    #         step = self.memory.steps[idx]
+    #         if isinstance(step, PlanningStep):
+    #             curated_messages.append(step.model_output_message)
+    #             break
+    #         idx -= 1
+
+    #     # add summary of chat history
+    #     history = self.pz_memory.search("A condensed summary of the execution trace of the agent.", run_id=self.run_id)
+    #     for msg in history["results"]:
+    #         pass
+
+    #     return curated_messages
+
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput | ToolOutput]:
+        """
+        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
+        Yields ChatMessageStreamDelta during the run if streaming is enabled.
+        At the end, yields either None if the step is not final, or the final answer.
+        """
+        memory_messages = self.write_memory_to_messages()
+
+        input_messages = memory_messages.copy()
+
+        ### Generate model output ###
+        memory_step.model_input_messages = input_messages
+        try:
+            additional_args: dict[str, Any] = {}
+            if self.grammar:
+                additional_args["grammar"] = self.grammar
+            if self._use_structured_outputs_internally:
+                additional_args["response_format"] = CODEAGENT_RESPONSE_FORMAT
+            if self.stream_outputs:
+                output_stream = self.model.generate_stream(
+                    input_messages,
+                    stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
+                    **additional_args,
+                )
+                chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
+                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
+                    for event in output_stream:
+                        chat_message_stream_deltas.append(event)
+                        live.update(
+                            Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
+                        )
+                        yield event
+                chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
+                memory_step.model_output_message = chat_message
+                output_text = chat_message.content
+            else:
+                chat_message: ChatMessage = self.model.generate(
+                    input_messages,
+                    stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
+                    **additional_args,
+                )
+                memory_step.model_output_message = chat_message
+                output_text = chat_message.content
+                self.logger.log_markdown(
+                    content=output_text,
+                    title="Output message of the LLM:",
+                    level=LogLevel.DEBUG,
+                )
+
+            # This adds <end_code> sequence to the history.
+            # This will nudge ulterior LLM calls to finish with <end_code>, thus efficiently stopping generation.
+            if output_text and output_text.strip().endswith("```"):
+                output_text += "<end_code>"
+                memory_step.model_output_message.content = output_text
+
+            memory_step.token_usage = chat_message.token_usage
+            memory_step.model_output = output_text
+        except Exception as e:
+            raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
+
+        ### Parse output ###
+        try:
+            if self._use_structured_outputs_internally:
+                code_action = json.loads(output_text)["code"]
+                code_action = extract_code_from_text(code_action) or code_action
+            else:
+                code_action = parse_code_blobs(output_text)
+            code_action = fix_final_answer_code(code_action)
+            memory_step.code_action = code_action
+        except Exception as e:
+            error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
+            raise AgentParsingError(error_msg, self.logger) from e
+
+        memory_step.tool_calls = [
+            ToolCall(
+                name="python_interpreter",
+                arguments=code_action,
+                id=f"call_{len(self.memory.steps)}",
+            )
+        ]
+
+        ### Execute action ###
+        self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
+        is_final_answer = False
+        try:
+            output, execution_logs, is_final_answer = self.python_executor(code_action)
+            execution_outputs_console = []
+            if len(execution_logs) > 0:
+                execution_outputs_console += [
+                    Text("Execution logs:", style="bold"),
+                    Text(execution_logs),
+                ]
+            observation = "Execution logs:\n" + execution_logs
+        except Exception as e:
+            if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state:
+                execution_logs = str(self.python_executor.state["_print_outputs"])
+                if len(execution_logs) > 0:
+                    execution_outputs_console = [
+                        Text("Execution logs:", style="bold"),
+                        Text(execution_logs),
+                    ]
+                    memory_step.observations = "Execution logs:\n" + execution_logs
+                    self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
+            error_msg = str(e)
+            if "Import of " in error_msg and " is not allowed" in error_msg:
+                self.logger.log(
+                    "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
+                    level=LogLevel.INFO,
+                )
+            raise AgentExecutionError(error_msg, self.logger) from e
+
+        truncated_output = truncate_content(str(output))
+        observation += "Last output from code snippet:\n" + truncated_output
+        memory_step.observations = observation
+
+        # # TODO: add output to self.pz_memory
+        # def get_role(msg_role):
+        #     return str(msg_role).split(".")[-1].lower()
+
+        # messages = [
+        #     {"role": get_role(memory_step.model_output_message.role), "content": memory_step.model_output_message.content},
+        #     {"role": "user", "content": memory_step.observations},
+        # ]
+        # self.pz_memory.add(messages, run_id=self.run_id, agent_id=self.name)
+
+        execution_outputs_console += [
+            Text(
+                f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
+                style=(f"bold {YELLOW_HEX}" if is_final_answer else ""),
+            ),
+        ]
+        self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
+        memory_step.action_output = output
+        yield ActionOutput(output=output, is_final_answer=is_final_answer)
+
+    def _run_stream(
+        self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
+    ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
+        """
+        Execute the agent.
+        """
+        self.step_number = 1
+        returned_final_answer = False
+        while not returned_final_answer and self.step_number <= self.max_steps:
+            # Run a planning step if scheduled
+            if self.planning_interval is not None and (
+                self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
+            ):
+                planning_start_time = time.time()
+                planning_step = None
+                for element in self._generate_planning_step(
+                    self.task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
+                ):  # Don't use the attribute step_number here, because there can be steps from previous runs
+                    yield element
+                    planning_step = element
+                assert isinstance(planning_step, PlanningStep)  # Last yielded element should be a PlanningStep
+                self.memory.steps.append(planning_step)
+                planning_end_time = time.time()
+                planning_step.timing = Timing(
+                    start_time=planning_start_time,
+                    end_time=planning_end_time,
+                )
+
+            # Start action step!
+            action_step_start_time = time.time()
+            action_step = ActionStep(
+                step_number=self.step_number,
+                timing=Timing(start_time=action_step_start_time),
+                observations_images=images,
+            )
+            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
+            try:
+                for output in self._step_stream(action_step):
+                    # Yield streaming deltas
+                    if not isinstance(output, (ActionOutput, ToolOutput)):
+                        # non-action, non-tool output
+                        yield output
+
+                    if isinstance(output, (ActionOutput, ToolOutput)) and output.is_final_answer:
+                        if self.final_answer_checks:
+                            self._validate_final_answer(output.output)
+                        returned_final_answer = True
+                        action_step.is_final_answer = True
+                        final_answer = output.output
+                # handle final step
+            except AgentGenerationError as e:
+                # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
+                raise e
+            except AgentError as e:
+                # Other AgentError types are caused by the Model, so we should log them and iterate.
+                action_step.error = e
+            finally:
+                self._finalize_step(action_step)
+                self.memory.steps.append(action_step)
+                yield action_step
+                self.step_number += 1
+
+        if not returned_final_answer and self.step_number == self.max_steps + 1:
+            final_answer = self._handle_max_steps_reached(self.task, images)
+            yield action_step
+        yield FinalAnswerStep(handle_agent_output_types(final_answer))
+
+    def run(
+        self,
+        task: str,
+        stream: bool = False,
+        reset: bool = True,
+        images: list["PIL.Image.Image"] | None = None,
+        additional_args: dict | None = None,
+        max_steps: int | None = None,
+    ):
+        """
+        Run the agent for the given task.
+
+        Args:
+            task (`str`): Task to perform.
+            stream (`bool`): Whether to run in streaming mode.
+                If `True`, returns a generator that yields each step as it is executed. You must iterate over this generator to process the individual steps (e.g., using a for loop or `next()`).
+                If `False`, executes all steps internally and returns only the final answer after completion.
+            reset (`bool`): Whether to reset the conversation or keep it going from previous run.
+            images (`list[PIL.Image.Image]`, *optional*): Image(s) objects.
+            additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
+            max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value.
+
+        Example:
+        ```py
+        from smolagents import CodeAgent
+        agent = CodeAgent(tools=[])
+        agent.run("What is the result of 2 power 3.7384?")
+        ```
+        """
+        max_steps = max_steps or self.max_steps
+        self.task = task
+        self.interrupt_switch = False
+        if additional_args is not None:
+            self.state.update(additional_args)
+            self.task += f"""
+You have been provided with these additional arguments, that you can access using the keys as variables in your python code:
+{str(additional_args)}."""
+
+        self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
+        if reset:
+            self.memory.reset()
+            self.monitor.reset()
+
+        self.logger.log_task(
+            content=self.task.strip(),
+            subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
+            level=LogLevel.INFO,
+            title=self.name if hasattr(self, "name") else None,
+        )
+        self.memory.steps.append(TaskStep(task=self.task, task_images=images))
+
+        if getattr(self, "python_executor", None):
+            self.python_executor.send_variables(variables=self.state)
+            self.python_executor.send_tools({**self.tools, **self.managed_agents})
+
+        if stream:
+            # The steps are returned as they are executed through a generator to iterate on.
+            return self._run_stream(task=self.task, max_steps=max_steps, images=images)
+        run_start_time = time.time()
+        # Outputs are returned only at the end. We only look at the last step.
+
+        steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images))
+        assert isinstance(steps[-1], FinalAnswerStep)
+        output = steps[-1].output
+
+        if self.return_full_result:
+            total_input_tokens = 0
+            total_output_tokens = 0
+            correct_token_usage = True
+            for step in self.memory.steps:
+                if isinstance(step, (ActionStep, PlanningStep)):
+                    if step.token_usage is None:
+                        correct_token_usage = False
+                        break
+                    else:
+                        total_input_tokens += step.token_usage.input_tokens
+                        total_output_tokens += step.token_usage.output_tokens
+            if correct_token_usage:
+                token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens)
+            else:
+                token_usage = None
+
+            if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError):
+                state = "max_steps_error"
+            else:
+                state = "success"
+
+            messages = self.memory.get_full_steps()
+
+            return RunResult(
+                output=output,
+                token_usage=token_usage,
+                messages=messages,
+                timing=Timing(start_time=run_start_time, end_time=time.time()),
+                state=state,
+            )
+
+        return output
+
+
+class PZBaseManagedAgent(PZBaseAgent):
+
+    def __call__(self, task: str, **kwargs):
+        """Adds additional prompting for the managed agent, runs it, and wraps the output.
+        This method is called only by a managed agent.
+        """
+        full_task = populate_template(
+            self.prompt_templates["managed_agent"]["task"],
+            variables=dict(name=self.name, task=task, context_description=self.context_description),
+        )
+        result = self.run(full_task, **kwargs)
+        report = result.output if isinstance(result, RunResult) else result
+        answer = populate_template(
+            self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report)
+        )
+        if self.provide_run_summary:
+            answer += "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"
+            for message in self.write_memory_to_messages(summary_mode=True):
+                content = message.content
+                answer += "\n" + truncate_content(str(content)) + "\n---"
+            answer += "\n</summary_of_work>"
+        return answer
+
+
+class DataDiscoveryAgent(PZBaseManagedAgent):
+    def __init__(self, run_id: str, context_description: str, *args, **kwargs):
+        self.description = """A team member that will search a data repository to find files which help to answer your question.
+        Ask him for all your questions that require searching a repository of relevant data.
+        Provide him as much context as possible, in particular if you need to search on a specific timeframe!
+        And don't hesitate to provide him with a complex search task, like finding a difference between two files.
+        Your request must be a real sentence, not a keyword search! Like "Find me this information (...)" rather than a few keywords.
+        """
+        prompt_templates = PromptTemplates(
+            system_prompt=CODE_AGENT_SYSTEM_PROMPT,
+            planning=PlanningPromptTemplate(
+                initial_plan=DATA_DISCOVERY_AGENT_INITIAL_PLAN_PROMPT,
+                update_plan_pre_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_PRE_MESSAGES_PROMPT,
+                update_plan_post_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_POST_MESSAGES_PROMPT,
+            ),
+            managed_agent=ManagedAgentPromptTemplate(task=DATA_DISCOVERY_AGENT_TASK_PROMPT, report=DATA_DISCOVERY_AGENT_REPORT_PROMPT),
+            final_answer=FinalAnswerPromptTemplate(pre_messages=FINAL_ANSWER_PRE_MESSAGES_PROMPT, post_messages=FINAL_ANSWER_POST_MESSAGES_PROMPT),
+        )
+
+        super().__init__(
+            *args,
+            run_id=run_id,
+            context_description=context_description,
+            prompt_templates=prompt_templates,
+            max_steps=20,
+            verbosity_level=2,
+            planning_interval=4,
+            name="data_discovery_agent",
+            description=self.description,
+            provide_run_summary=True,
+            **kwargs,
+        )
+        self.prompt_templates["managed_agent"]["task"] += """Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""
+
+
+class SearchManagerAgent(PZBaseAgent):
+    def __init__(self, run_id: str, context_description: str, *args, **kwargs):
+        prompt_templates = PromptTemplates(
+            system_prompt=CODE_AGENT_SYSTEM_PROMPT,
+            planning=PlanningPromptTemplate(
+                initial_plan=DATA_DISCOVERY_AGENT_INITIAL_PLAN_PROMPT,
+                update_plan_pre_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_PRE_MESSAGES_PROMPT,
+                update_plan_post_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_POST_MESSAGES_PROMPT,
+            ),
+            managed_agent=ManagedAgentPromptTemplate(task=DATA_DISCOVERY_AGENT_TASK_PROMPT, report=DATA_DISCOVERY_AGENT_REPORT_PROMPT),
+            final_answer=FinalAnswerPromptTemplate(pre_messages=FINAL_ANSWER_PRE_MESSAGES_PROMPT, post_messages=FINAL_ANSWER_POST_MESSAGES_PROMPT),
+        )
+        super().__init__(
+            *args,
+            run_id=run_id,
+            context_description=context_description,
+            prompt_templates=prompt_templates,
+            max_steps=12,
+            verbosity_level=2,
+            additional_authorized_imports=["*"],
+            planning_interval=4,
+            return_full_result=True,
+            **kwargs,
+        )
+
+# class ManagerAgent(CodeAgent):
+
+#     def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput | ToolOutput]:
+#         """
+#         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
+#         Yields ChatMessageStreamDelta during the run if streaming is enabled.
+#         At the end, yields either None if the step is not final, or the final answer.
+#         """
+#         raise NotImplementedError("This method should be implemented in child classes")
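
For orientation, here is a minimal usage sketch (not part of the diff) of how the two concrete agents added in this file might be wired together. The constructor keywords `run_id` and `context_description`, the manager/managed-agent split, and the fact that `SearchManagerAgent` hard-codes `return_full_result=True` and `planning_interval=4` all come from the code above; the `InferenceClientModel` choice, the empty tool lists, and the example `context_description` string are assumptions for illustration only.

```py
import uuid

# Assumption: any smolagents-compatible model class works here.
from smolagents import InferenceClientModel

from palimpzest.agents.search_agents import DataDiscoveryAgent, SearchManagerAgent

model = InferenceClientModel()
run_id = str(uuid.uuid4())
context_description = "A repository of quarterly sales reports."  # hypothetical

# Managed worker: searches the data repository on the manager's behalf.
# Its name ("data_discovery_agent") and description are fixed in __init__ above.
discovery_agent = DataDiscoveryAgent(
    run_id=run_id,
    context_description=context_description,
    tools=[],  # assumption: real search tools are supplied elsewhere in palimpzest
    model=model,
)

# Manager: re-plans every 4 steps and returns a RunResult rather than a bare
# answer, because return_full_result=True is set in its __init__.
manager = SearchManagerAgent(
    run_id=run_id,
    context_description=context_description,
    tools=[],
    model=model,
    managed_agents=[discovery_agent],
)

result = manager.run("Which report shows the largest Q3 revenue?")
print(result.output)
```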