planar 0.5.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211):
  1. planar/_version.py +1 -1
  2. planar/ai/agent.py +155 -283
  3. planar/ai/agent_base.py +170 -0
  4. planar/ai/agent_utils.py +7 -0
  5. planar/ai/pydantic_ai.py +638 -0
  6. planar/ai/test_agent_serialization.py +1 -1
  7. planar/app.py +64 -20
  8. planar/cli.py +39 -27
  9. planar/config.py +45 -36
  10. planar/db/db.py +2 -1
  11. planar/files/storage/azure_blob.py +343 -0
  12. planar/files/storage/base.py +7 -0
  13. planar/files/storage/config.py +70 -7
  14. planar/files/storage/s3.py +6 -6
  15. planar/files/storage/test_azure_blob.py +435 -0
  16. planar/logging/formatter.py +17 -4
  17. planar/logging/test_formatter.py +327 -0
  18. planar/registry_items.py +2 -1
  19. planar/routers/agents_router.py +3 -1
  20. planar/routers/files.py +11 -2
  21. planar/routers/models.py +14 -1
  22. planar/routers/test_agents_router.py +1 -1
  23. planar/routers/test_files_router.py +49 -0
  24. planar/routers/test_routes_security.py +5 -7
  25. planar/routers/test_workflow_router.py +270 -3
  26. planar/routers/workflow.py +95 -36
  27. planar/rules/models.py +36 -39
  28. planar/rules/test_data/account_dormancy_management.json +223 -0
  29. planar/rules/test_data/airline_loyalty_points_calculator.json +262 -0
  30. planar/rules/test_data/applicant_risk_assessment.json +435 -0
  31. planar/rules/test_data/booking_fraud_detection.json +407 -0
  32. planar/rules/test_data/cellular_data_rollover_system.json +258 -0
  33. planar/rules/test_data/clinical_trial_eligibility_screener.json +437 -0
  34. planar/rules/test_data/customer_lifetime_value.json +143 -0
  35. planar/rules/test_data/import_duties_calculator.json +289 -0
  36. planar/rules/test_data/insurance_prior_authorization.json +443 -0
  37. planar/rules/test_data/online_check_in_eligibility_system.json +254 -0
  38. planar/rules/test_data/order_consolidation_system.json +375 -0
  39. planar/rules/test_data/portfolio_risk_monitor.json +471 -0
  40. planar/rules/test_data/supply_chain_risk.json +253 -0
  41. planar/rules/test_data/warehouse_cross_docking.json +237 -0
  42. planar/rules/test_rules.py +750 -6
  43. planar/scaffold_templates/planar.dev.yaml.j2 +6 -6
  44. planar/scaffold_templates/planar.prod.yaml.j2 +9 -5
  45. planar/scaffold_templates/pyproject.toml.j2 +1 -1
  46. planar/security/auth_context.py +21 -0
  47. planar/security/{jwt_middleware.py → auth_middleware.py} +70 -17
  48. planar/security/authorization.py +9 -15
  49. planar/security/tests/test_auth_middleware.py +162 -0
  50. planar/sse/proxy.py +4 -9
  51. planar/test_app.py +92 -1
  52. planar/test_cli.py +81 -59
  53. planar/test_config.py +17 -14
  54. planar/testing/fixtures.py +325 -0
  55. planar/testing/planar_test_client.py +5 -2
  56. planar/utils.py +41 -1
  57. planar/workflows/execution.py +1 -1
  58. planar/workflows/orchestrator.py +5 -0
  59. planar/workflows/serialization.py +12 -6
  60. planar/workflows/step_core.py +3 -1
  61. planar/workflows/test_serialization.py +9 -1
  62. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/METADATA +30 -5
  63. planar-0.8.0.dist-info/RECORD +166 -0
  64. planar/.__init__.py.un~ +0 -0
  65. planar/._version.py.un~ +0 -0
  66. planar/.app.py.un~ +0 -0
  67. planar/.cli.py.un~ +0 -0
  68. planar/.config.py.un~ +0 -0
  69. planar/.context.py.un~ +0 -0
  70. planar/.db.py.un~ +0 -0
  71. planar/.di.py.un~ +0 -0
  72. planar/.engine.py.un~ +0 -0
  73. planar/.files.py.un~ +0 -0
  74. planar/.log_context.py.un~ +0 -0
  75. planar/.log_metadata.py.un~ +0 -0
  76. planar/.logging.py.un~ +0 -0
  77. planar/.object_registry.py.un~ +0 -0
  78. planar/.otel.py.un~ +0 -0
  79. planar/.server.py.un~ +0 -0
  80. planar/.session.py.un~ +0 -0
  81. planar/.sqlalchemy.py.un~ +0 -0
  82. planar/.task_local.py.un~ +0 -0
  83. planar/.test_app.py.un~ +0 -0
  84. planar/.test_config.py.un~ +0 -0
  85. planar/.test_object_config.py.un~ +0 -0
  86. planar/.test_sqlalchemy.py.un~ +0 -0
  87. planar/.test_utils.py.un~ +0 -0
  88. planar/.util.py.un~ +0 -0
  89. planar/.utils.py.un~ +0 -0
  90. planar/ai/.__init__.py.un~ +0 -0
  91. planar/ai/._models.py.un~ +0 -0
  92. planar/ai/.agent.py.un~ +0 -0
  93. planar/ai/.agent_utils.py.un~ +0 -0
  94. planar/ai/.events.py.un~ +0 -0
  95. planar/ai/.files.py.un~ +0 -0
  96. planar/ai/.models.py.un~ +0 -0
  97. planar/ai/.providers.py.un~ +0 -0
  98. planar/ai/.pydantic_ai.py.un~ +0 -0
  99. planar/ai/.pydantic_ai_agent.py.un~ +0 -0
  100. planar/ai/.pydantic_ai_provider.py.un~ +0 -0
  101. planar/ai/.step.py.un~ +0 -0
  102. planar/ai/.test_agent.py.un~ +0 -0
  103. planar/ai/.test_agent_serialization.py.un~ +0 -0
  104. planar/ai/.test_providers.py.un~ +0 -0
  105. planar/ai/.utils.py.un~ +0 -0
  106. planar/ai/providers.py +0 -1088
  107. planar/ai/test_agent.py +0 -1298
  108. planar/ai/test_providers.py +0 -463
  109. planar/db/.db.py.un~ +0 -0
  110. planar/files/.config.py.un~ +0 -0
  111. planar/files/.local.py.un~ +0 -0
  112. planar/files/.local_filesystem.py.un~ +0 -0
  113. planar/files/.model.py.un~ +0 -0
  114. planar/files/.models.py.un~ +0 -0
  115. planar/files/.s3.py.un~ +0 -0
  116. planar/files/.storage.py.un~ +0 -0
  117. planar/files/.test_files.py.un~ +0 -0
  118. planar/files/storage/.__init__.py.un~ +0 -0
  119. planar/files/storage/.base.py.un~ +0 -0
  120. planar/files/storage/.config.py.un~ +0 -0
  121. planar/files/storage/.context.py.un~ +0 -0
  122. planar/files/storage/.local_directory.py.un~ +0 -0
  123. planar/files/storage/.test_local_directory.py.un~ +0 -0
  124. planar/files/storage/.test_s3.py.un~ +0 -0
  125. planar/human/.human.py.un~ +0 -0
  126. planar/human/.test_human.py.un~ +0 -0
  127. planar/logging/.__init__.py.un~ +0 -0
  128. planar/logging/.attributes.py.un~ +0 -0
  129. planar/logging/.formatter.py.un~ +0 -0
  130. planar/logging/.logger.py.un~ +0 -0
  131. planar/logging/.otel.py.un~ +0 -0
  132. planar/logging/.tracer.py.un~ +0 -0
  133. planar/modeling/.mixin.py.un~ +0 -0
  134. planar/modeling/.storage.py.un~ +0 -0
  135. planar/modeling/orm/.planar_base_model.py.un~ +0 -0
  136. planar/object_config/.object_config.py.un~ +0 -0
  137. planar/routers/.__init__.py.un~ +0 -0
  138. planar/routers/.agents_router.py.un~ +0 -0
  139. planar/routers/.crud.py.un~ +0 -0
  140. planar/routers/.decision.py.un~ +0 -0
  141. planar/routers/.event.py.un~ +0 -0
  142. planar/routers/.file_attachment.py.un~ +0 -0
  143. planar/routers/.files.py.un~ +0 -0
  144. planar/routers/.files_router.py.un~ +0 -0
  145. planar/routers/.human.py.un~ +0 -0
  146. planar/routers/.info.py.un~ +0 -0
  147. planar/routers/.models.py.un~ +0 -0
  148. planar/routers/.object_config_router.py.un~ +0 -0
  149. planar/routers/.rule.py.un~ +0 -0
  150. planar/routers/.test_object_config_router.py.un~ +0 -0
  151. planar/routers/.test_workflow_router.py.un~ +0 -0
  152. planar/routers/.workflow.py.un~ +0 -0
  153. planar/rules/.decorator.py.un~ +0 -0
  154. planar/rules/.runner.py.un~ +0 -0
  155. planar/rules/.test_rules.py.un~ +0 -0
  156. planar/security/.jwt_middleware.py.un~ +0 -0
  157. planar/sse/.constants.py.un~ +0 -0
  158. planar/sse/.example.html.un~ +0 -0
  159. planar/sse/.hub.py.un~ +0 -0
  160. planar/sse/.model.py.un~ +0 -0
  161. planar/sse/.proxy.py.un~ +0 -0
  162. planar/testing/.client.py.un~ +0 -0
  163. planar/testing/.memory_storage.py.un~ +0 -0
  164. planar/testing/.planar_test_client.py.un~ +0 -0
  165. planar/testing/.predictable_tracer.py.un~ +0 -0
  166. planar/testing/.synchronizable_tracer.py.un~ +0 -0
  167. planar/testing/.test_memory_storage.py.un~ +0 -0
  168. planar/testing/.workflow_observer.py.un~ +0 -0
  169. planar/workflows/.__init__.py.un~ +0 -0
  170. planar/workflows/.builtin_steps.py.un~ +0 -0
  171. planar/workflows/.concurrency_tracing.py.un~ +0 -0
  172. planar/workflows/.context.py.un~ +0 -0
  173. planar/workflows/.contrib.py.un~ +0 -0
  174. planar/workflows/.decorators.py.un~ +0 -0
  175. planar/workflows/.durable_test.py.un~ +0 -0
  176. planar/workflows/.errors.py.un~ +0 -0
  177. planar/workflows/.events.py.un~ +0 -0
  178. planar/workflows/.exceptions.py.un~ +0 -0
  179. planar/workflows/.execution.py.un~ +0 -0
  180. planar/workflows/.human.py.un~ +0 -0
  181. planar/workflows/.lock.py.un~ +0 -0
  182. planar/workflows/.misc.py.un~ +0 -0
  183. planar/workflows/.model.py.un~ +0 -0
  184. planar/workflows/.models.py.un~ +0 -0
  185. planar/workflows/.notifications.py.un~ +0 -0
  186. planar/workflows/.orchestrator.py.un~ +0 -0
  187. planar/workflows/.runtime.py.un~ +0 -0
  188. planar/workflows/.serialization.py.un~ +0 -0
  189. planar/workflows/.step.py.un~ +0 -0
  190. planar/workflows/.step_core.py.un~ +0 -0
  191. planar/workflows/.sub_workflow_runner.py.un~ +0 -0
  192. planar/workflows/.sub_workflow_scheduler.py.un~ +0 -0
  193. planar/workflows/.test_concurrency.py.un~ +0 -0
  194. planar/workflows/.test_concurrency_detection.py.un~ +0 -0
  195. planar/workflows/.test_human.py.un~ +0 -0
  196. planar/workflows/.test_lock_timeout.py.un~ +0 -0
  197. planar/workflows/.test_orchestrator.py.un~ +0 -0
  198. planar/workflows/.test_race_conditions.py.un~ +0 -0
  199. planar/workflows/.test_serialization.py.un~ +0 -0
  200. planar/workflows/.test_suspend_deserialization.py.un~ +0 -0
  201. planar/workflows/.test_workflow.py.un~ +0 -0
  202. planar/workflows/.tracing.py.un~ +0 -0
  203. planar/workflows/.types.py.un~ +0 -0
  204. planar/workflows/.util.py.un~ +0 -0
  205. planar/workflows/.utils.py.un~ +0 -0
  206. planar/workflows/.workflow.py.un~ +0 -0
  207. planar/workflows/.workflow_wrapper.py.un~ +0 -0
  208. planar/workflows/.wrappers.py.un~ +0 -0
  209. planar-0.5.0.dist-info/RECORD +0 -289
  210. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/WHEEL +0 -0
  211. {planar-0.5.0.dist-info → planar-0.8.0.dist-info}/entry_points.txt +0 -0
planar/_version.py CHANGED
@@ -1 +1 @@
1
- VERSION = "0.5.0"
1
+ VERSION = "0.8.0"
planar/ai/agent.py CHANGED
@@ -1,23 +1,14 @@
1
- from __future__ import annotations
2
-
3
1
  import inspect
4
- from dataclasses import dataclass, field
5
- from typing import (
6
- Any,
7
- Callable,
8
- Dict,
9
- List,
10
- Type,
11
- Union,
12
- cast,
13
- overload,
14
- )
2
+ from dataclasses import dataclass
3
+ from typing import Any, Type, cast
15
4
 
16
5
  from pydantic import BaseModel
6
+ from pydantic_ai import models
17
7
 
8
+ from planar.ai.agent_base import AgentBase
18
9
  from planar.ai.agent_utils import (
19
- AgentEventEmitter,
20
10
  AgentEventType,
11
+ ModelSpec,
21
12
  ToolCallResult,
22
13
  create_tool_definition,
23
14
  extract_files_from_model,
@@ -25,171 +16,33 @@ from planar.ai.agent_utils import (
25
16
  render_template,
26
17
  )
27
18
  from planar.ai.models import (
28
- AgentConfig,
29
19
  AgentRunResult,
30
20
  AssistantMessage,
31
- CompletionResponse,
32
21
  ModelMessage,
33
22
  SystemMessage,
23
+ ToolDefinition,
24
+ ToolMessage,
34
25
  ToolResponse,
35
26
  UserMessage,
36
27
  )
37
- from planar.ai.providers import Anthropic, Gemini, Model, OpenAI
28
+ from planar.ai.pydantic_ai import ModelRunResponse, model_run
38
29
  from planar.logging import get_logger
39
- from planar.modeling.field_helpers import JsonSchema
40
30
  from planar.utils import utc_now
41
- from planar.workflows import as_step
42
31
  from planar.workflows.models import StepType
43
32
 
44
33
  logger = get_logger(__name__)
45
34
 
46
35
 
47
- def _parse_model_string(model_str: str) -> Model:
48
- """Parse a model string (e.g., 'openai:gpt-4.1') into a Model instance."""
49
- parts = model_str.split(":", 1)
50
- if len(parts) != 2:
51
- raise ValueError(
52
- f"Invalid model format: {model_str}. Expected format: 'provider:model_id'"
53
- )
54
-
55
- provider_id, model_id = parts
56
-
57
- if provider_id.lower() == "openai":
58
- return OpenAI.model(model_id)
59
- elif provider_id.lower() == "anthropic":
60
- return Anthropic.model(model_id)
61
- elif provider_id.lower() == "gemini":
62
- return Gemini.model(model_id)
63
- else:
64
- raise ValueError(f"Unsupported provider: {provider_id}")
65
-
66
-
67
36
  @dataclass
68
37
  class Agent[
69
- # TODO: add `= str` default when we upgrade to 3.13
70
38
  TInput: BaseModel | str,
71
39
  TOutput: BaseModel | str,
72
- ]:
73
- """An LLM-powered agent that can be called directly within workflows."""
74
-
75
- name: str
76
- system_prompt: str
77
- output_type: Type[TOutput] | None = None
78
- input_type: Type[TInput] | None = None
79
- user_prompt: str = ""
80
- model: Union[str, Model] = "openai:gpt-4.1"
81
- tools: List[Callable] = field(default_factory=list)
82
- max_turns: int = 2
83
- model_parameters: Dict[str, Any] = field(default_factory=dict)
84
-
85
- # TODO: move here to serialize to frontend
86
- #
87
- # built_in_vars: Dict[str, str] = field(default_factory=lambda: {
88
- # "datetime_now": datetime.datetime.now().isoformat(),
89
- # "date_today": datetime.date.today().isoformat(),
90
- # })
91
-
92
- def __post_init__(self):
93
- if self.input_type:
94
- if (
95
- not issubclass(self.input_type, BaseModel)
96
- and self.input_type is not str
97
- ):
98
- raise ValueError(
99
- "input_type must be 'str' or a subclass of a Pydantic model"
100
- )
101
- if self.max_turns < 1:
102
- raise ValueError("Max_turns must be greater than or equal to 1.")
103
- if self.tools and self.max_turns <= 1:
104
- raise ValueError(
105
- "For tool calling to work, max_turns must be greater than 1."
106
- )
107
-
108
- def input_schema(self) -> JsonSchema | None:
109
- if self.input_type is None:
110
- return None
111
- if self.input_type is str:
112
- return None
113
- assert issubclass(self.input_type, BaseModel), (
114
- "input_type must be a subclass of BaseModel or str"
115
- )
116
- return self.input_type.model_json_schema()
117
-
118
- def output_schema(self) -> JsonSchema | None:
119
- if self.output_type is None:
120
- return None
121
- if self.output_type is str:
122
- return None
123
- assert issubclass(self.output_type, BaseModel), (
124
- "output_type must be a subclass of BaseModel or str"
125
- )
126
- return self.output_type.model_json_schema()
127
-
128
- def to_config(self) -> AgentConfig:
129
- return AgentConfig(
130
- system_prompt=self.system_prompt,
131
- user_prompt=self.user_prompt,
132
- model=str(self.model),
133
- max_turns=self.max_turns,
134
- model_parameters=self.model_parameters,
135
- )
136
-
137
- @overload
138
- async def __call__(
139
- self: "Agent[TInput, str]",
140
- input_value: TInput,
141
- event_emitter: AgentEventEmitter | None = None,
142
- ) -> AgentRunResult[str]: ...
143
-
144
- @overload
145
- async def __call__(
146
- self: "Agent[TInput, TOutput]",
147
- input_value: TInput,
148
- event_emitter: AgentEventEmitter | None = None,
149
- ) -> AgentRunResult[TOutput]: ...
150
-
151
- async def __call__(
152
- self,
153
- input_value: TInput,
154
- event_emitter: AgentEventEmitter | None = None,
155
- ) -> AgentRunResult[Any]:
156
- if self.input_type is not None and not isinstance(input_value, self.input_type):
157
- raise ValueError(
158
- f"Input value must be of type {self.input_type}, but got {type(input_value)}"
159
- )
160
- elif not isinstance(input_value, (str, BaseModel)):
161
- # Should not happen based on type constraints, but just in case
162
- # user does not have type checking enabled
163
- raise ValueError(
164
- "Input value must be a string or a Pydantic model if input_type is not provided"
165
- )
166
-
167
- if self.output_type is None:
168
- run_step = as_step(
169
- self.run_step,
170
- step_type=StepType.AGENT,
171
- display_name=self.name,
172
- return_type=AgentRunResult[str],
173
- )
174
- else:
175
- run_step = as_step(
176
- self.run_step,
177
- step_type=StepType.AGENT,
178
- display_name=self.name,
179
- return_type=AgentRunResult[self.output_type],
180
- )
181
-
182
- result = await run_step(
183
- input_value=input_value,
184
- event_emitter=event_emitter,
185
- )
186
- # Cast the result to ensure type compatibility
187
- return cast(AgentRunResult[TOutput], result)
40
+ ](AgentBase[TInput, TOutput]):
41
+ model: models.KnownModelName | models.Model = "openai:gpt-4o"
188
42
 
189
43
  async def run_step(
190
44
  self,
191
45
  input_value: TInput,
192
- event_emitter: AgentEventEmitter | None = None,
193
46
  ) -> AgentRunResult[TOutput]:
194
47
  """Execute the agent with the provided inputs.
195
48
 
@@ -200,6 +53,7 @@ class Agent[
200
53
  Returns:
201
54
  AgentRunResult containing the agent's response
202
55
  """
56
+ event_emitter = self.event_emitter
203
57
  logger.debug(
204
58
  "agent run_step called", agent_name=self.name, input_type=type(input_value)
205
59
  )
@@ -264,18 +118,18 @@ class Agent[
264
118
  raise ValueError(f"Missing required parameter for prompt formatting: {e}")
265
119
 
266
120
  # Get the LLM provider and model
267
- model_config = config.model
268
- if isinstance(model_config, str):
269
- model = _parse_model_string(model_config)
121
+ if isinstance(self.model, str):
122
+ model = models.infer_model(self.model)
270
123
  else:
271
- model = model_config
124
+ model = self.model
272
125
 
273
126
  # Apply model parameters if specified
127
+ model_settings = None
274
128
  if config.model_parameters:
275
- model = model.with_parameters(**config.model_parameters)
129
+ model_settings = config.model_parameters
276
130
 
277
131
  # Prepare structured messages
278
- messages: List[ModelMessage] = []
132
+ messages: list[ModelMessage] = []
279
133
  if formatted_system_prompt:
280
134
  messages.append(SystemMessage(content=formatted_system_prompt))
281
135
 
@@ -287,167 +141,182 @@ class Agent[
287
141
  if self.tools:
288
142
  tool_definitions = [create_tool_definition(tool) for tool in self.tools]
289
143
 
290
- # Determine output type for the provider call
144
+ # Determine output type for the agent call
291
145
  # Pass the Pydantic model type if output_type is a subclass of BaseModel,
292
146
  # otherwise pass None (indicating string output is expected).
293
- output_type_for_provider: Type[BaseModel] | None = None
147
+ output_type: Type[BaseModel] | None = None
294
148
  # Use issubclass safely by checking if output_type is a type first
295
149
  if inspect.isclass(self.output_type) and issubclass(
296
150
  self.output_type, BaseModel
297
151
  ):
298
- output_type_for_provider = cast(Type[BaseModel], self.output_type)
152
+ output_type = cast(Type[BaseModel], self.output_type)
299
153
 
300
154
  # Execute the LLM call
301
155
  max_turns = config.max_turns
302
156
 
303
- # Single turn completion (default case)
304
- result = None
305
- if not tool_definitions:
157
+ # We use this inner function to pass "model" and "event_emitter",
158
+ # which are not serializable as step parameters.
159
+ async def agent_run_step(
160
+ model_spec: ModelSpec,
161
+ messages: list[ModelMessage],
162
+ turns_left: int,
163
+ tools: list[ToolDefinition] | None = None,
164
+ output_type: Type[BaseModel] | None = None,
165
+ ):
306
166
  logger.debug(
307
- "agent performing single turn completion",
167
+ "agent running",
308
168
  agent_name=self.name,
309
- model=model.model_spec,
310
- output_type=output_type_for_provider,
169
+ model=model_spec,
170
+ model_settings=model_settings,
171
+ output_type=output_type,
311
172
  )
312
- response = await as_step(
313
- model.provider_class.complete,
173
+ if output_type is None:
174
+ return await model_run(
175
+ model=model,
176
+ max_extra_turns=turns_left,
177
+ model_settings=model_settings,
178
+ messages=messages,
179
+ tools=tools or [],
180
+ event_handler=cast(Any, event_emitter),
181
+ )
182
+ else:
183
+ return await model_run(
184
+ model=model,
185
+ max_extra_turns=turns_left,
186
+ model_settings=model_settings,
187
+ messages=messages,
188
+ output_type=output_type,
189
+ tools=tools or [],
190
+ event_handler=cast(Any, event_emitter),
191
+ )
192
+
193
+ model_spec = ModelSpec(
194
+ model_id=str(model),
195
+ parameters=config.model_parameters,
196
+ )
197
+ result = None
198
+ logger.debug(
199
+ "agent performing multi-turn completion with tools",
200
+ agent_name=self.name,
201
+ max_turns=max_turns,
202
+ )
203
+ turns_left = max_turns
204
+ while turns_left > 0:
205
+ turns_left -= 1
206
+ logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
207
+
208
+ # Get model response
209
+ run_response = await self.as_step_if_durable(
210
+ agent_run_step,
314
211
  step_type=StepType.AGENT,
315
- return_type=CompletionResponse[output_type_for_provider or str],
212
+ return_type=ModelRunResponse[output_type or str],
316
213
  )(
317
- model_spec=model.model_spec,
214
+ model_spec=model_spec,
318
215
  messages=messages,
319
- output_type=output_type_for_provider,
216
+ turns_left=turns_left,
217
+ output_type=output_type,
218
+ tools=tool_definitions or [],
320
219
  )
321
- result = response.content
220
+ response = run_response.response
221
+ turns_left -= run_response.extra_turns_used
322
222
 
323
223
  # Emit response event if event_emitter is provided
324
224
  if event_emitter:
325
225
  event_emitter.emit(AgentEventType.RESPONSE, response.content)
326
- else:
226
+
227
+ # If no tool calls or last turn, return content
228
+ if not response.tool_calls or turns_left == 0:
229
+ logger.debug(
230
+ "agent completion: no tool calls or last turn",
231
+ agent_name=self.name,
232
+ has_content=response.content is not None,
233
+ )
234
+ result = response.content
235
+ break
236
+
237
+ # Process tool calls
327
238
  logger.debug(
328
- "agent performing multi-turn completion with tools",
239
+ "agent received tool calls",
329
240
  agent_name=self.name,
330
- max_turns=max_turns,
241
+ num_tool_calls=len(response.tool_calls),
331
242
  )
332
- # Multi-turn with tools
333
- turns_left = max_turns
334
- while turns_left > 0:
335
- turns_left -= 1
336
- logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
337
-
338
- # Get model response
339
- response = await as_step(
340
- model.provider_class.complete,
341
- step_type=StepType.AGENT,
342
- return_type=CompletionResponse[output_type_for_provider or str],
343
- )(
344
- model_spec=model.model_spec,
345
- messages=messages,
346
- output_type=output_type_for_provider,
347
- tools=tool_definitions,
348
- )
349
-
350
- # Emit response event if event_emitter is provided
351
- if event_emitter:
352
- event_emitter.emit(AgentEventType.RESPONSE, response.content)
353
-
354
- # If no tool calls or last turn, return content
355
- if not response.tool_calls or turns_left == 0:
356
- logger.debug(
357
- "agent completion: no tool calls or last turn",
358
- agent_name=self.name,
359
- has_content=response.content is not None,
360
- )
361
- result = response.content
362
- break
243
+ assistant_message = AssistantMessage(
244
+ content=None,
245
+ tool_calls=response.tool_calls,
246
+ )
247
+ messages.append(assistant_message)
363
248
 
364
- # Process tool calls
249
+ # Execute each tool and add tool responses to messages
250
+ for tool_call_idx, tool_call in enumerate(response.tool_calls):
365
251
  logger.debug(
366
- "agent received tool calls",
252
+ "agent processing tool call",
367
253
  agent_name=self.name,
368
- num_tool_calls=len(response.tool_calls),
254
+ tool_call_index=tool_call_idx + 1,
255
+ tool_call_id=tool_call.id,
256
+ tool_call_name=tool_call.name,
369
257
  )
370
- assistant_message = AssistantMessage(
371
- content=None,
372
- tool_calls=response.tool_calls,
258
+ # Find the matching tool function
259
+ tool_fn = next(
260
+ (t for t in self.tools if t.__name__ == tool_call.name),
261
+ None,
373
262
  )
374
- messages.append(assistant_message)
375
263
 
376
- # Execute each tool and add tool responses to messages
377
- for tool_call_idx, tool_call in enumerate(response.tool_calls):
378
- logger.debug(
379
- "agent processing tool call",
264
+ if not tool_fn:
265
+ tool_result = f"Error: Tool '{tool_call.name}' not found."
266
+ logger.warning(
267
+ "tool not found for agent",
268
+ tool_name=tool_call.name,
380
269
  agent_name=self.name,
381
- tool_call_index=tool_call_idx + 1,
382
- tool_call_id=tool_call.id,
383
- tool_call_name=tool_call.name,
384
270
  )
385
- # Find the matching tool function
386
- tool_fn = next(
387
- (t for t in self.tools if t.__name__ == tool_call.name),
388
- None,
271
+ else:
272
+ # Execute the tool with the provided arguments
273
+ tool_result = await self.as_step_if_durable(
274
+ tool_fn,
275
+ step_type=StepType.TOOL_CALL,
276
+ )(**tool_call.arguments)
277
+ logger.info(
278
+ "tool executed by agent",
279
+ tool_name=tool_call.name,
280
+ agent_name=self.name,
281
+ result_type=type(tool_result),
389
282
  )
390
283
 
391
- if not tool_fn:
392
- tool_result = f"Error: Tool '{tool_call.name}' not found."
393
- logger.warning(
394
- "tool not found for agent",
395
- tool_name=tool_call.name,
396
- agent_name=self.name,
397
- )
398
- else:
399
- # Execute the tool with the provided arguments
400
- tool_result = await as_step(
401
- tool_fn,
402
- step_type=StepType.TOOL_CALL,
403
- )(**tool_call.arguments)
404
- logger.info(
405
- "tool executed by agent",
406
- tool_name=tool_call.name,
407
- agent_name=self.name,
408
- result_type=type(tool_result),
409
- )
410
-
411
- # Create a tool response
412
- tool_response = ToolResponse(
413
- tool_call_id=tool_call.id or "call_1", content=str(tool_result)
414
- )
284
+ # Create a tool response
285
+ tool_response = ToolResponse(
286
+ tool_call_id=tool_call.id or "call_1", content=str(tool_result)
287
+ )
415
288
 
416
- # Emit tool response event if event_emitter is provided
417
- if event_emitter:
418
- event_emitter.emit(
419
- AgentEventType.TOOL_RESPONSE,
420
- ToolCallResult(
421
- tool_call_id=tool_call.id or "call_1",
422
- tool_call_name=tool_call.name,
423
- content=tool_result,
424
- ),
425
- )
426
-
427
- # Convert the tool response to a message based on provider
428
- tool_message = model.provider_class.format_tool_response(
429
- tool_response
289
+ # Emit tool response event if event_emitter is provided
290
+ if event_emitter:
291
+ event_emitter.emit(
292
+ AgentEventType.TOOL_RESPONSE,
293
+ ToolCallResult(
294
+ tool_call_id=tool_call.id or "call_1",
295
+ tool_call_name=tool_call.name,
296
+ content=tool_result,
297
+ ),
430
298
  )
431
- messages.append(tool_message)
432
299
 
433
- # Continue to next turn
434
-
435
- if result is None:
436
- logger.warning(
437
- "agent completed tool interactions but result is none",
438
- agent_name=self.name,
439
- expected_type=self.output_type,
440
- )
441
- raise ValueError(
442
- f"Expected result of type {self.output_type} but got none after tool interactions."
300
+ tool_message = ToolMessage(
301
+ content=tool_response.content,
302
+ tool_call_id=tool_response.tool_call_id or "call_1",
443
303
  )
304
+ messages.append(tool_message)
444
305
 
445
- if event_emitter:
446
- event_emitter.emit(AgentEventType.COMPLETED, result)
306
+ # Continue to next turn
447
307
 
448
308
  if result is None:
449
- logger.warning("agent final result is none", agent_name=self.name)
450
- raise ValueError("No result obtained after tool interactions")
309
+ logger.warning(
310
+ "agent completed tool interactions but result is none",
311
+ agent_name=self.name,
312
+ expected_type=self.output_type,
313
+ )
314
+ raise ValueError(
315
+ f"Expected result of type {self.output_type} but got none after tool interactions."
316
+ )
317
+
318
+ if event_emitter:
319
+ event_emitter.emit(AgentEventType.COMPLETED, result)
451
320
 
452
321
  logger.info(
453
322
  "agent completed",
@@ -455,3 +324,6 @@ class Agent[
455
324
  final_result_type=type(result),
456
325
  )
457
326
  return AgentRunResult[TOutput](output=cast(TOutput, result))
327
+
328
+ def get_model_str(self) -> str:
329
+ return str(self.model)