palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89):
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,637 @@
1
+ import json
2
+ import textwrap
3
+ import time
4
+ from collections.abc import Generator
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from rich.console import Group
8
+ from rich.live import Live
9
+ from rich.markdown import Markdown
10
+ from rich.rule import Rule
11
+ from rich.text import Text
12
+
13
+ if TYPE_CHECKING:
14
+ import PIL.Image
15
+
16
+ from smolagents.agent_types import handle_agent_output_types
17
+ from smolagents.agents import (
18
+ ActionOutput,
19
+ CodeAgent,
20
+ FinalAnswerPromptTemplate,
21
+ ManagedAgentPromptTemplate,
22
+ PlanningPromptTemplate,
23
+ PromptTemplates,
24
+ RunResult,
25
+ ToolOutput,
26
+ populate_template,
27
+ )
28
+ from smolagents.local_python_executor import fix_final_answer_code
29
+ from smolagents.memory import (
30
+ ActionStep,
31
+ FinalAnswerStep,
32
+ PlanningStep,
33
+ SystemPromptStep,
34
+ TaskStep,
35
+ Timing,
36
+ TokenUsage,
37
+ ToolCall,
38
+ )
39
+ from smolagents.models import (
40
+ CODEAGENT_RESPONSE_FORMAT,
41
+ ChatMessage,
42
+ ChatMessageStreamDelta,
43
+ MessageRole,
44
+ agglomerate_stream_deltas,
45
+ )
46
+ from smolagents.monitoring import YELLOW_HEX, LogLevel
47
+ from smolagents.utils import (
48
+ AgentError,
49
+ AgentExecutionError,
50
+ AgentGenerationError,
51
+ AgentMaxStepsError,
52
+ AgentParsingError,
53
+ extract_code_from_text,
54
+ parse_code_blobs,
55
+ truncate_content,
56
+ )
57
+
58
+ from palimpzest.prompts import (
59
+ CODE_AGENT_SYSTEM_PROMPT,
60
+ DATA_DISCOVERY_AGENT_INITIAL_PLAN_PROMPT,
61
+ DATA_DISCOVERY_AGENT_REPORT_PROMPT,
62
+ DATA_DISCOVERY_AGENT_TASK_PROMPT,
63
+ DATA_DISCOVERY_AGENT_UPDATE_PLAN_POST_MESSAGES_PROMPT,
64
+ DATA_DISCOVERY_AGENT_UPDATE_PLAN_PRE_MESSAGES_PROMPT,
65
+ FINAL_ANSWER_POST_MESSAGES_PROMPT,
66
+ FINAL_ANSWER_PRE_MESSAGES_PROMPT,
67
+ )
68
+
69
+
70
# TODO: make this use memory the way you want
class PZBaseAgent(CodeAgent):
    """A ``CodeAgent`` specialization that tags every run with a run id and a
    description of the data context being searched.

    ``context_description`` is injected into the planning prompt templates,
    which expect a template variable of that name.
    """

    def __init__(self, run_id: str, context_description: str, *args, **kwargs):
        """Record run metadata before delegating construction to ``CodeAgent``.

        Args:
            run_id: Identifier shared by all agents participating in one run.
            context_description: Human-readable description of the searchable
                data context; consumed by the planning prompts.
        """
        # NOTE: an earlier draft configured a chroma-backed Memory store here
        # (collection "palimpzest-memory-<class>", path "./pz-chroma"); re-add
        # a `self.pz_memory` if cross-run persistence is wanted.
        self.run_id = run_id
        self.context_description = context_description
        super().__init__(*args, **kwargs)
86
+
87
+ def write_memory_to_messages(
88
+ self,
89
+ summary_mode: bool = False,
90
+ ) -> list[ChatMessage]:
91
+ """
92
+ Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
93
+ that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc) to help
94
+ the LLM.
95
+ """
96
+ messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
97
+ for memory_step in self.memory.steps:
98
+ messages.extend(memory_step.to_messages(summary_mode=summary_mode))
99
+ return messages
100
+
101
    def _generate_planning_step(
        self, task, is_first_step: bool, step: int
    ) -> Generator[ChatMessageStreamDelta | PlanningStep]:
        """Produce a planning step (initial plan or plan update) for the agent.

        Adapted from smolagents' planning logic; the difference is that
        ``self.context_description`` is passed to every planning prompt
        template.

        Args:
            task: The task the agent is trying to solve.
            is_first_step: True when this is the first plan, selecting the
                "initial_plan" template instead of the update templates.
            step: Current step number; used to compute remaining steps.

        Yields:
            ChatMessageStreamDelta events (when streaming is enabled) followed
            by the final PlanningStep.
        """
        start_time = time.time()
        if is_first_step:
            # Initial plan: a single user message built from the initial_plan template.
            input_messages = [
                ChatMessage(
                    role=MessageRole.USER,
                    content=[
                        {
                            "type": "text",
                            "text": populate_template(
                                self.prompt_templates["planning"]["initial_plan"],
                                variables={"task": task, "tools": self.tools, "managed_agents": self.managed_agents, "context_description": self.context_description},
                            ),
                        }
                    ],
                )
            ]
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                plan_message_content = ""
                output_stream = self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"])  # type: ignore
                input_tokens, output_tokens = 0, 0
                # Render the plan live in the console as deltas arrive.
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in output_stream:
                        if event.content is not None:
                            plan_message_content += event.content
                            live.update(Markdown(plan_message_content))
                            # Output tokens accumulate per delta; input tokens are
                            # reported whole on each delta, so overwrite.
                            if event.token_usage:
                                output_tokens += event.token_usage.output_tokens
                                input_tokens = event.token_usage.input_tokens
                        yield event
            else:
                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
                plan_message_content = plan_message.content
                input_tokens, output_tokens = (
                    (
                        plan_message.token_usage.input_tokens,
                        plan_message.token_usage.output_tokens,
                    )
                    if plan_message.token_usage
                    else (None, None)
                )
            plan = textwrap.dedent(
                f"""Here are the facts I know and the plan of action that I will follow to solve the task:\n```\n{plan_message_content}\n```"""
            )
        else:
            # Summary mode removes the system prompt and previous planning messages output by the model.
            # Removing previous planning messages avoids influencing too much the new plan.
            memory_messages = self.write_memory_to_messages(summary_mode=True)
            plan_update_pre = ChatMessage(
                role=MessageRole.SYSTEM,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task, "context_description": self.context_description}
                        ),
                    }
                ],
            )
            plan_update_post = ChatMessage(
                role=MessageRole.USER,
                content=[
                    {
                        "type": "text",
                        "text": populate_template(
                            self.prompt_templates["planning"]["update_plan_post_messages"],
                            variables={
                                "task": task,
                                "tools": self.tools,
                                "managed_agents": self.managed_agents,
                                "remaining_steps": (self.max_steps - step),
                                "context_description": self.context_description,
                            },
                        ),
                    }
                ],
            )
            # Sandwich the (summarized) memory between the pre/post update prompts.
            input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
            if self.stream_outputs and hasattr(self.model, "generate_stream"):
                plan_message_content = ""
                input_tokens, output_tokens = 0, 0
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in self.model.generate_stream(
                        input_messages,
                        stop_sequences=["<end_plan>"],
                    ):  # type: ignore
                        if event.content is not None:
                            plan_message_content += event.content
                            live.update(Markdown(plan_message_content))
                            if event.token_usage:
                                output_tokens += event.token_usage.output_tokens
                                input_tokens = event.token_usage.input_tokens
                        yield event
            else:
                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
                plan_message_content = plan_message.content
                # NOTE(review): if token_usage is None here, input_tokens/output_tokens
                # are left unbound and the TokenUsage(...) below would raise — the
                # initial-plan branch falls back to (None, None) instead; confirm.
                if plan_message.token_usage is not None:
                    input_tokens, output_tokens = (
                        plan_message.token_usage.input_tokens,
                        plan_message.token_usage.output_tokens,
                    )
            plan = textwrap.dedent(
                f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere are the facts I know and my new/updated plan of action to solve the task:\n```\n{plan_message_content}\n```"""
            )
        log_headline = "Initial plan" if is_first_step else "Updated plan"
        self.logger.log(Rule(f"[bold]{log_headline}", style="orange"), Text(plan), level=LogLevel.INFO)
        yield PlanningStep(
            model_input_messages=input_messages,
            plan=plan,
            model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
            token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
            timing=Timing(start_time=start_time, end_time=time.time()),
        )
216
+
217
+ # def _curate_messages(self, input_messages: list[ChatMessage]) -> list[ChatMessage]:
218
+ # """
219
+ # Try returning:
220
+ # - System Prompt + task
221
+ # - Current Plan
222
+ # - Summary of previous conversation
223
+ # """
224
+ # # initialize with the system prompt & original task
225
+ # curated_messages = input_messages[:2]
226
+
227
+ # # find the last planning step message
228
+ # idx = len(self.memory.steps) - 1
229
+ # while idx > -1:
230
+ # step = self.memory.steps[idx]
231
+ # if isinstance(step, PlanningStep):
232
+ # curated_messages.append(step.model_output_message)
233
+ # break
234
+ # idx -= 1
235
+
236
+ # # add summary of chat history
237
+ # history = self.pz_memory.search("A condensed summary of the execution trace of the agent.", run_id=self.run_id)
238
+ # for msg in history["results"]:
239
+ # pass
240
+
241
+ # return curated_messages
242
+
243
    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput | ToolOutput]:
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        Yields ChatMessageStreamDelta during the run if streaming is enabled.
        At the end, yields either None if the step is not final, or the final answer.
        """
        memory_messages = self.write_memory_to_messages()

        input_messages = memory_messages.copy()

        ### Generate model output ###
        memory_step.model_input_messages = input_messages
        try:
            additional_args: dict[str, Any] = {}
            if self.grammar:
                additional_args["grammar"] = self.grammar
            if self._use_structured_outputs_internally:
                # Force the model to answer as {"code": ...} JSON.
                additional_args["response_format"] = CODEAGENT_RESPONSE_FORMAT
            if self.stream_outputs:
                output_stream = self.model.generate_stream(
                    input_messages,
                    stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
                    **additional_args,
                )
                chat_message_stream_deltas: list[ChatMessageStreamDelta] = []
                # Re-render the agglomerated markdown on every delta.
                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                    for event in output_stream:
                        chat_message_stream_deltas.append(event)
                        live.update(
                            Markdown(agglomerate_stream_deltas(chat_message_stream_deltas).render_as_markdown())
                        )
                        yield event
                chat_message = agglomerate_stream_deltas(chat_message_stream_deltas)
                memory_step.model_output_message = chat_message
                output_text = chat_message.content
            else:
                chat_message: ChatMessage = self.model.generate(
                    input_messages,
                    stop_sequences=["<end_code>", "Observation:", "Calling tools:"],
                    **additional_args,
                )
                memory_step.model_output_message = chat_message
                output_text = chat_message.content
                self.logger.log_markdown(
                    content=output_text,
                    title="Output message of the LLM:",
                    level=LogLevel.DEBUG,
                )

            # This adds <end_code> sequence to the history.
            # This will nudge ulterior LLM calls to finish with <end_code>, thus efficiently stopping generation.
            if output_text and output_text.strip().endswith("```"):
                output_text += "<end_code>"
                memory_step.model_output_message.content = output_text

            memory_step.token_usage = chat_message.token_usage
            memory_step.model_output = output_text
        except Exception as e:
            raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e

        ### Parse output ###
        try:
            if self._use_structured_outputs_internally:
                # Structured output: the code lives under the "code" key, possibly
                # still wrapped in a markdown fence.
                code_action = json.loads(output_text)["code"]
                code_action = extract_code_from_text(code_action) or code_action
            else:
                code_action = parse_code_blobs(output_text)
            code_action = fix_final_answer_code(code_action)
            memory_step.code_action = code_action
        except Exception as e:
            error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
            raise AgentParsingError(error_msg, self.logger) from e

        # Record the parsed code as a synthetic python_interpreter tool call.
        memory_step.tool_calls = [
            ToolCall(
                name="python_interpreter",
                arguments=code_action,
                id=f"call_{len(self.memory.steps)}",
            )
        ]

        ### Execute action ###
        self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
        is_final_answer = False
        try:
            output, execution_logs, is_final_answer = self.python_executor(code_action)
            execution_outputs_console = []
            if len(execution_logs) > 0:
                execution_outputs_console += [
                    Text("Execution logs:", style="bold"),
                    Text(execution_logs),
                ]
            observation = "Execution logs:\n" + execution_logs
        except Exception as e:
            # Best effort: salvage whatever the executor printed before failing.
            # NOTE(review): if the executor has no "_print_outputs" state,
            # execution_outputs_console stays unbound here — harmless because
            # logging only happens inside the inner branch and we re-raise below,
            # but worth confirming against upstream smolagents.
            if hasattr(self.python_executor, "state") and "_print_outputs" in self.python_executor.state:
                execution_logs = str(self.python_executor.state["_print_outputs"])
                if len(execution_logs) > 0:
                    execution_outputs_console = [
                        Text("Execution logs:", style="bold"),
                        Text(execution_logs),
                    ]
                    memory_step.observations = "Execution logs:\n" + execution_logs
                    self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
            error_msg = str(e)
            if "Import of " in error_msg and " is not allowed" in error_msg:
                self.logger.log(
                    "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
                    level=LogLevel.INFO,
                )
            raise AgentExecutionError(error_msg, self.logger) from e

        truncated_output = truncate_content(str(output))
        observation += "Last output from code snippet:\n" + truncated_output
        memory_step.observations = observation

        # TODO: add output to a persistent pz_memory store — an earlier draft
        # appended {model output, observations} messages tagged with run_id and
        # agent name here.

        execution_outputs_console += [
            Text(
                f"{('Out - Final answer' if is_final_answer else 'Out')}: {truncated_output}",
                style=(f"bold {YELLOW_HEX}" if is_final_answer else ""),
            ),
        ]
        self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
        memory_step.action_output = output
        yield ActionOutput(output=output, is_final_answer=is_final_answer)
377
+
378
    def _run_stream(
        self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
    ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
        """
        Execute the agent: alternate planning steps (every ``planning_interval``
        steps) with ReAct action steps until a final answer is produced or the
        step budget is exhausted, yielding each step (and streaming deltas) as
        it happens. The last yielded item is always a FinalAnswerStep.

        NOTE(review): the ``max_steps`` parameter is not referenced below — the
        loop bound and the budget check both read ``self.max_steps``. Confirm
        whether the parameter should override it.
        """
        self.step_number = 1
        returned_final_answer = False
        while not returned_final_answer and self.step_number <= self.max_steps:
            # Run a planning step if scheduled
            if self.planning_interval is not None and (
                self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0
            ):
                planning_start_time = time.time()
                planning_step = None
                for element in self._generate_planning_step(
                    self.task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
                ):  # Don't use the attribute step_number here, because there can be steps from previous runs
                    yield element
                    planning_step = element
                assert isinstance(planning_step, PlanningStep)  # Last yielded element should be a PlanningStep
                self.memory.steps.append(planning_step)
                planning_end_time = time.time()
                # Overwrite the step's timing with wall-clock bounds measured here.
                planning_step.timing = Timing(
                    start_time=planning_start_time,
                    end_time=planning_end_time,
                )

            # Start action step!
            action_step_start_time = time.time()
            action_step = ActionStep(
                step_number=self.step_number,
                timing=Timing(start_time=action_step_start_time),
                observations_images=images,
            )
            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
            try:
                for output in self._step_stream(action_step):
                    # Yield streaming deltas
                    if not isinstance(output, (ActionOutput, ToolOutput)):
                        # non-action, non-tool output
                        yield output

                    if isinstance(output, (ActionOutput, ToolOutput)) and output.is_final_answer:
                        if self.final_answer_checks:
                            self._validate_final_answer(output.output)
                        returned_final_answer = True
                        action_step.is_final_answer = True
                        final_answer = output.output
                # handle final step
            except AgentGenerationError as e:
                # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
                raise e
            except AgentError as e:
                # Other AgentError types are caused by the Model, so we should log them and iterate.
                action_step.error = e
            finally:
                # Always record and yield the action step, even on model errors.
                self._finalize_step(action_step)
                self.memory.steps.append(action_step)
                yield action_step
                self.step_number += 1

        if not returned_final_answer and self.step_number == self.max_steps + 1:
            # Budget exhausted: ask the model for a best-effort final answer.
            # NOTE(review): this re-yields the last action_step a second time.
            final_answer = self._handle_max_steps_reached(self.task, images)
            yield action_step
        yield FinalAnswerStep(handle_agent_output_types(final_answer))
444
+
445
    def run(
        self,
        task: str,
        stream: bool = False,
        reset: bool = True,
        images: list["PIL.Image.Image"] | None = None,
        additional_args: dict | None = None,
        max_steps: int | None = None,
    ):
        """
        Run the agent for the given task.

        Args:
            task (`str`): Task to perform.
            stream (`bool`): Whether to run in streaming mode.
                If `True`, returns a generator that yields each step as it is executed. You must iterate over this generator to process the individual steps (e.g., using a for loop or `next()`).
                If `False`, executes all steps internally and returns only the final answer after completion.
            reset (`bool`): Whether to reset the conversation or keep it going from previous run.
            images (`list[PIL.Image.Image]`, *optional*): Image(s) objects.
            additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
            max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value.

        Example:
        ```py
        from smolagents import CodeAgent
        agent = CodeAgent(tools=[])
        agent.run("What is the result of 2 power 3.7384?")
        ```
        """
        max_steps = max_steps or self.max_steps
        self.task = task
        self.interrupt_switch = False
        if additional_args is not None:
            # Expose the extra variables both to the executor state and, in
            # text form, to the model via the task description.
            self.state.update(additional_args)
            self.task += f"""
You have been provided with these additional arguments, that you can access using the keys as variables in your python code:
{str(additional_args)}."""

        self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
        if reset:
            self.memory.reset()
            self.monitor.reset()

        self.logger.log_task(
            content=self.task.strip(),
            subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
            level=LogLevel.INFO,
            title=self.name if hasattr(self, "name") else None,
        )
        self.memory.steps.append(TaskStep(task=self.task, task_images=images))

        # Push current state variables and tools into the python executor.
        if getattr(self, "python_executor", None):
            self.python_executor.send_variables(variables=self.state)
            self.python_executor.send_tools({**self.tools, **self.managed_agents})

        if stream:
            # The steps are returned as they are executed through a generator to iterate on.
            return self._run_stream(task=self.task, max_steps=max_steps, images=images)
        run_start_time = time.time()
        # Outputs are returned only at the end. We only look at the last step.

        steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images))
        assert isinstance(steps[-1], FinalAnswerStep)
        output = steps[-1].output

        if self.return_full_result:
            # Aggregate token usage across all steps; if any step lacks usage
            # data the total would be misleading, so report None instead.
            total_input_tokens = 0
            total_output_tokens = 0
            correct_token_usage = True
            for step in self.memory.steps:
                if isinstance(step, (ActionStep, PlanningStep)):
                    if step.token_usage is None:
                        correct_token_usage = False
                        break
                    else:
                        total_input_tokens += step.token_usage.input_tokens
                        total_output_tokens += step.token_usage.output_tokens
            if correct_token_usage:
                token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens)
            else:
                token_usage = None

            if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError):
                state = "max_steps_error"
            else:
                state = "success"

            messages = self.memory.get_full_steps()

            return RunResult(
                output=output,
                token_usage=token_usage,
                messages=messages,
                timing=Timing(start_time=run_start_time, end_time=time.time()),
                state=state,
            )

        return output
543
+
544
+
545
class PZBaseManagedAgent(PZBaseAgent):
    """Base class for PZ agents that are invoked as managed (sub-)agents."""

    def __call__(self, task: str, **kwargs):
        """Wrap *task* in the managed-agent prompt, run it, and format the report.

        This method is called only when the agent acts as a managed agent for
        a manager agent.
        """
        # Embed the incoming task into the managed-agent "task" template.
        task_variables = {"name": self.name, "task": task, "context_description": self.context_description}
        full_task = populate_template(self.prompt_templates["managed_agent"]["task"], variables=task_variables)

        result = self.run(full_task, **kwargs)
        report = result.output if isinstance(result, RunResult) else result

        # Wrap the raw output in the "report" template for the caller.
        answer = populate_template(
            self.prompt_templates["managed_agent"]["report"],
            variables={"name": self.name, "final_answer": report},
        )
        if not self.provide_run_summary:
            return answer

        # Optionally append a condensed trace of this agent's work.
        parts = [answer, "\n\nFor more detail, find below a summary of this agent's work:\n<summary_of_work>\n"]
        for message in self.write_memory_to_messages(summary_mode=True):
            parts.append("\n" + truncate_content(str(message.content)) + "\n---")
        parts.append("\n</summary_of_work>")
        return "".join(parts)
567
+
568
+
569
class DataDiscoveryAgent(PZBaseManagedAgent):
    """Managed agent that searches a data repository for files relevant to a request."""

    def __init__(self, run_id: str, context_description: str, *args, **kwargs):
        # Description shown to the manager agent when deciding whether (and
        # how) to delegate to this agent.
        self.description = """A team member that will search a data repository to find files which help to answer your question.
Ask him for all your questions that require searching a repository of relevant data.
Provide him as much context as possible, in particular if you need to search on a specific timeframe!
And don't hesitate to provide him with a complex search task, like finding a difference between two files.
Your request must be a real sentence, not a keyword search! Like "Find me this information (...)" rather than a few keywords.
"""
        # Wire the data-discovery prompts into the smolagents PromptTemplates
        # structure consumed by CodeAgent / PZBaseAgent.
        prompt_templates = PromptTemplates(
            system_prompt=CODE_AGENT_SYSTEM_PROMPT,
            planning=PlanningPromptTemplate(
                initial_plan=DATA_DISCOVERY_AGENT_INITIAL_PLAN_PROMPT,
                update_plan_pre_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_PRE_MESSAGES_PROMPT,
                update_plan_post_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_POST_MESSAGES_PROMPT,
            ),
            managed_agent=ManagedAgentPromptTemplate(task=DATA_DISCOVERY_AGENT_TASK_PROMPT, report=DATA_DISCOVERY_AGENT_REPORT_PROMPT),
            final_answer=FinalAnswerPromptTemplate(pre_messages=FINAL_ANSWER_PRE_MESSAGES_PROMPT, post_messages=FINAL_ANSWER_POST_MESSAGES_PROMPT),
        )

        super().__init__(
            *args,
            run_id=run_id,
            context_description=context_description,
            prompt_templates=prompt_templates,
            max_steps=20,
            verbosity_level=2,
            planning_interval=4,
            name="data_discovery_agent",
            description=self.description,
            provide_run_summary=True,
            **kwargs,
        )
        # Let the agent bail out via `final_answer` to ask for clarification
        # instead of searching indefinitely with insufficient information.
        self.prompt_templates["managed_agent"]["task"] += """Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""
602
+
603
+
604
class SearchManagerAgent(PZBaseAgent):
    """Top-level search agent that coordinates managed discovery agents.

    Reuses the data-discovery prompt templates; unlike the managed agents it
    may import any module in its Python executor
    (``additional_authorized_imports=["*"]``) and always returns a full
    ``RunResult``.
    """

    def __init__(self, run_id: str, context_description: str, *args, **kwargs):
        # Same template wiring as DataDiscoveryAgent; only the runtime
        # configuration passed to the base class differs.
        templates = PromptTemplates(
            system_prompt=CODE_AGENT_SYSTEM_PROMPT,
            planning=PlanningPromptTemplate(
                initial_plan=DATA_DISCOVERY_AGENT_INITIAL_PLAN_PROMPT,
                update_plan_pre_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_PRE_MESSAGES_PROMPT,
                update_plan_post_messages=DATA_DISCOVERY_AGENT_UPDATE_PLAN_POST_MESSAGES_PROMPT,
            ),
            managed_agent=ManagedAgentPromptTemplate(
                task=DATA_DISCOVERY_AGENT_TASK_PROMPT,
                report=DATA_DISCOVERY_AGENT_REPORT_PROMPT,
            ),
            final_answer=FinalAnswerPromptTemplate(
                pre_messages=FINAL_ANSWER_PRE_MESSAGES_PROMPT,
                post_messages=FINAL_ANSWER_POST_MESSAGES_PROMPT,
            ),
        )
        super().__init__(
            *args,
            run_id=run_id,
            context_description=context_description,
            prompt_templates=templates,
            max_steps=12,
            verbosity_level=2,
            additional_authorized_imports=["*"],
            planning_interval=4,
            return_full_result=True,
            **kwargs,
        )
628
+
629
+ # class ManagerAgent(CodeAgent):
630
+
631
+ # def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput | ToolOutput]:
632
+ # """
633
+ # Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
634
+ # Yields ChatMessageStreamDelta during the run if streaming is enabled.
635
+ # At the end, yields either None if the step is not final, or the final answer.
636
+ # """
637
+ # raise NotImplementedError("This method should be implemented in child classes")