planar 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
planar/_version.py CHANGED
@@ -1 +1 @@
1
- VERSION = "0.7.0"
1
+ VERSION = "0.8.0"
planar/ai/agent.py CHANGED
@@ -1,25 +1,14 @@
1
- from __future__ import annotations
2
-
3
- import abc
4
1
  import inspect
5
- from dataclasses import dataclass, field
6
- from typing import (
7
- Any,
8
- Callable,
9
- Coroutine,
10
- Dict,
11
- List,
12
- Type,
13
- Union,
14
- cast,
15
- overload,
16
- )
2
+ from dataclasses import dataclass
3
+ from typing import Any, Type, cast
17
4
 
18
5
  from pydantic import BaseModel
6
+ from pydantic_ai import models
19
7
 
8
+ from planar.ai.agent_base import AgentBase
20
9
  from planar.ai.agent_utils import (
21
- AgentEventEmitter,
22
10
  AgentEventType,
11
+ ModelSpec,
23
12
  ToolCallResult,
24
13
  create_tool_definition,
25
14
  extract_files_from_model,
@@ -27,194 +16,29 @@ from planar.ai.agent_utils import (
27
16
  render_template,
28
17
  )
29
18
  from planar.ai.models import (
30
- AgentConfig,
31
19
  AgentRunResult,
32
20
  AssistantMessage,
33
- CompletionResponse,
34
21
  ModelMessage,
35
22
  SystemMessage,
23
+ ToolDefinition,
24
+ ToolMessage,
36
25
  ToolResponse,
37
26
  UserMessage,
38
27
  )
39
- from planar.ai.providers import Anthropic, Gemini, Model, OpenAI
28
+ from planar.ai.pydantic_ai import ModelRunResponse, model_run
40
29
  from planar.logging import get_logger
41
- from planar.modeling.field_helpers import JsonSchema
42
- from planar.utils import P, R, T, U, utc_now
43
- from planar.workflows import as_step
30
+ from planar.utils import utc_now
44
31
  from planar.workflows.models import StepType
45
32
 
46
33
  logger = get_logger(__name__)
47
34
 
48
35
 
49
- def _parse_model_string(model_str: str) -> Model:
50
- """Parse a model string (e.g., 'openai:gpt-4.1') into a Model instance."""
51
- parts = model_str.split(":", 1)
52
- if len(parts) != 2:
53
- raise ValueError(
54
- f"Invalid model format: {model_str}. Expected format: 'provider:model_id'"
55
- )
56
-
57
- provider_id, model_id = parts
58
-
59
- if provider_id.lower() == "openai":
60
- return OpenAI.model(model_id)
61
- elif provider_id.lower() == "anthropic":
62
- return Anthropic.model(model_id)
63
- elif provider_id.lower() == "gemini":
64
- return Gemini.model(model_id)
65
- else:
66
- raise ValueError(f"Unsupported provider: {provider_id}")
67
-
68
-
69
- @dataclass
70
- class AgentBase[
71
- # TODO: add `= str` default when we upgrade to 3.13
72
- TInput: BaseModel | str,
73
- TOutput: BaseModel | str,
74
- ](abc.ABC):
75
- """An LLM-powered agent that can be called directly within workflows."""
76
-
77
- name: str
78
- system_prompt: str
79
- output_type: Type[TOutput] | None = None
80
- input_type: Type[TInput] | None = None
81
- user_prompt: str = ""
82
- tools: List[Callable] = field(default_factory=list)
83
- max_turns: int = 2
84
- model_parameters: Dict[str, Any] = field(default_factory=dict)
85
- event_emitter: AgentEventEmitter | None = None
86
- durable: bool = True
87
-
88
- # TODO: move here to serialize to frontend
89
- #
90
- # built_in_vars: Dict[str, str] = field(default_factory=lambda: {
91
- # "datetime_now": datetime.datetime.now().isoformat(),
92
- # "date_today": datetime.date.today().isoformat(),
93
- # })
94
-
95
- def __post_init__(self):
96
- if self.input_type:
97
- if (
98
- not issubclass(self.input_type, BaseModel)
99
- and self.input_type is not str
100
- ):
101
- raise ValueError(
102
- "input_type must be 'str' or a subclass of a Pydantic model"
103
- )
104
- if self.max_turns < 1:
105
- raise ValueError("Max_turns must be greater than or equal to 1.")
106
- if self.tools and self.max_turns <= 1:
107
- raise ValueError(
108
- "For tool calling to work, max_turns must be greater than 1."
109
- )
110
-
111
- def input_schema(self) -> JsonSchema | None:
112
- if self.input_type is None:
113
- return None
114
- if self.input_type is str:
115
- return None
116
- assert issubclass(self.input_type, BaseModel), (
117
- "input_type must be a subclass of BaseModel or str"
118
- )
119
- return self.input_type.model_json_schema()
120
-
121
- def output_schema(self) -> JsonSchema | None:
122
- if self.output_type is None:
123
- return None
124
- if self.output_type is str:
125
- return None
126
- assert issubclass(self.output_type, BaseModel), (
127
- "output_type must be a subclass of BaseModel or str"
128
- )
129
- return self.output_type.model_json_schema()
130
-
131
- @overload
132
- async def __call__(
133
- self: "AgentBase[TInput, str]",
134
- input_value: TInput,
135
- ) -> AgentRunResult[str]: ...
136
-
137
- @overload
138
- async def __call__(
139
- self: "AgentBase[TInput, TOutput]",
140
- input_value: TInput,
141
- ) -> AgentRunResult[TOutput]: ...
142
-
143
- def as_step_if_durable(
144
- self,
145
- func: Callable[P, Coroutine[T, U, R]],
146
- step_type: StepType,
147
- display_name: str | None = None,
148
- return_type: Type[R] | None = None,
149
- ) -> Callable[P, Coroutine[T, U, R]]:
150
- if not self.durable:
151
- return func
152
- return as_step(
153
- func,
154
- step_type=step_type,
155
- display_name=display_name or self.name,
156
- return_type=return_type,
157
- )
158
-
159
- async def __call__(
160
- self,
161
- input_value: TInput,
162
- ) -> AgentRunResult[Any]:
163
- if self.input_type is not None and not isinstance(input_value, self.input_type):
164
- raise ValueError(
165
- f"Input value must be of type {self.input_type}, but got {type(input_value)}"
166
- )
167
- elif not isinstance(input_value, (str, BaseModel)):
168
- # Should not happen based on type constraints, but just in case
169
- # user does not have type checking enabled
170
- raise ValueError(
171
- "Input value must be a string or a Pydantic model if input_type is not provided"
172
- )
173
-
174
- if self.output_type is None:
175
- run_step = self.as_step_if_durable(
176
- self.run_step,
177
- step_type=StepType.AGENT,
178
- display_name=self.name,
179
- return_type=AgentRunResult[str],
180
- )
181
- else:
182
- run_step = self.as_step_if_durable(
183
- self.run_step,
184
- step_type=StepType.AGENT,
185
- display_name=self.name,
186
- return_type=AgentRunResult[self.output_type],
187
- )
188
-
189
- result = await run_step(input_value=input_value)
190
- # Cast the result to ensure type compatibility
191
- return cast(AgentRunResult[TOutput], result)
192
-
193
- @abc.abstractmethod
194
- async def run_step(
195
- self,
196
- input_value: TInput,
197
- ) -> AgentRunResult[TOutput]: ...
198
-
199
- @abc.abstractmethod
200
- def get_model_str(self) -> str: ...
201
-
202
- def to_config(self) -> AgentConfig:
203
- return AgentConfig(
204
- system_prompt=self.system_prompt,
205
- user_prompt=self.user_prompt,
206
- model=self.get_model_str(),
207
- max_turns=self.max_turns,
208
- model_parameters=self.model_parameters,
209
- )
210
-
211
-
212
36
  @dataclass
213
37
  class Agent[
214
38
  TInput: BaseModel | str,
215
39
  TOutput: BaseModel | str,
216
40
  ](AgentBase[TInput, TOutput]):
217
- model: Union[str, Model] = "openai:gpt-4.1"
41
+ model: models.KnownModelName | models.Model = "openai:gpt-4o"
218
42
 
219
43
  async def run_step(
220
44
  self,
@@ -231,10 +55,7 @@ class Agent[
231
55
  """
232
56
  event_emitter = self.event_emitter
233
57
  logger.debug(
234
- "agent run_step called",
235
- agent_name=self.name,
236
- input_type=type(input_value),
237
- config=self.to_config(),
58
+ "agent run_step called", agent_name=self.name, input_type=type(input_value)
238
59
  )
239
60
  result = None
240
61
 
@@ -297,18 +118,18 @@ class Agent[
297
118
  raise ValueError(f"Missing required parameter for prompt formatting: {e}")
298
119
 
299
120
  # Get the LLM provider and model
300
- model_config = config.model
301
- if isinstance(model_config, str):
302
- model = _parse_model_string(model_config)
121
+ if isinstance(self.model, str):
122
+ model = models.infer_model(self.model)
303
123
  else:
304
- model = model_config
124
+ model = self.model
305
125
 
306
126
  # Apply model parameters if specified
127
+ model_settings = None
307
128
  if config.model_parameters:
308
- model = model.with_parameters(**config.model_parameters)
129
+ model_settings = config.model_parameters
309
130
 
310
131
  # Prepare structured messages
311
- messages: List[ModelMessage] = []
132
+ messages: list[ModelMessage] = []
312
133
  if formatted_system_prompt:
313
134
  messages.append(SystemMessage(content=formatted_system_prompt))
314
135
 
@@ -320,168 +141,182 @@ class Agent[
320
141
  if self.tools:
321
142
  tool_definitions = [create_tool_definition(tool) for tool in self.tools]
322
143
 
323
- # Determine output type for the provider call
144
+ # Determine output type for the agent call
324
145
  # Pass the Pydantic model type if output_type is a subclass of BaseModel,
325
146
  # otherwise pass None (indicating string output is expected).
326
- output_type_for_provider: Type[BaseModel] | None = None
147
+ output_type: Type[BaseModel] | None = None
327
148
  # Use issubclass safely by checking if output_type is a type first
328
149
  if inspect.isclass(self.output_type) and issubclass(
329
150
  self.output_type, BaseModel
330
151
  ):
331
- output_type_for_provider = cast(Type[BaseModel], self.output_type)
152
+ output_type = cast(Type[BaseModel], self.output_type)
332
153
 
333
154
  # Execute the LLM call
334
155
  max_turns = config.max_turns
335
156
 
336
- # Single turn completion (default case)
337
- result = None
338
- if not tool_definitions:
157
+ # We use this inner function to pass "model" and "event_emitter",
158
+ # which are not serializable as step parameters.
159
+ async def agent_run_step(
160
+ model_spec: ModelSpec,
161
+ messages: list[ModelMessage],
162
+ turns_left: int,
163
+ tools: list[ToolDefinition] | None = None,
164
+ output_type: Type[BaseModel] | None = None,
165
+ ):
339
166
  logger.debug(
340
- "agent performing single turn completion",
167
+ "agent running",
341
168
  agent_name=self.name,
342
- model=model.model_spec,
343
- output_type=output_type_for_provider,
169
+ model=model_spec,
170
+ model_settings=model_settings,
171
+ output_type=output_type,
344
172
  )
345
- response = await self.as_step_if_durable(
346
- model.provider_class.complete,
173
+ if output_type is None:
174
+ return await model_run(
175
+ model=model,
176
+ max_extra_turns=turns_left,
177
+ model_settings=model_settings,
178
+ messages=messages,
179
+ tools=tools or [],
180
+ event_handler=cast(Any, event_emitter),
181
+ )
182
+ else:
183
+ return await model_run(
184
+ model=model,
185
+ max_extra_turns=turns_left,
186
+ model_settings=model_settings,
187
+ messages=messages,
188
+ output_type=output_type,
189
+ tools=tools or [],
190
+ event_handler=cast(Any, event_emitter),
191
+ )
192
+
193
+ model_spec = ModelSpec(
194
+ model_id=str(model),
195
+ parameters=config.model_parameters,
196
+ )
197
+ result = None
198
+ logger.debug(
199
+ "agent performing multi-turn completion with tools",
200
+ agent_name=self.name,
201
+ max_turns=max_turns,
202
+ )
203
+ turns_left = max_turns
204
+ while turns_left > 0:
205
+ turns_left -= 1
206
+ logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
207
+
208
+ # Get model response
209
+ run_response = await self.as_step_if_durable(
210
+ agent_run_step,
347
211
  step_type=StepType.AGENT,
348
- return_type=CompletionResponse[output_type_for_provider or str],
212
+ return_type=ModelRunResponse[output_type or str],
349
213
  )(
350
- model_spec=model.model_spec,
214
+ model_spec=model_spec,
351
215
  messages=messages,
352
- output_type=output_type_for_provider,
216
+ turns_left=turns_left,
217
+ output_type=output_type,
218
+ tools=tool_definitions or [],
353
219
  )
354
- result = response.content
220
+ response = run_response.response
221
+ turns_left -= run_response.extra_turns_used
355
222
 
356
223
  # Emit response event if event_emitter is provided
357
224
  if event_emitter:
358
225
  event_emitter.emit(AgentEventType.RESPONSE, response.content)
359
- else:
226
+
227
+ # If no tool calls or last turn, return content
228
+ if not response.tool_calls or turns_left == 0:
229
+ logger.debug(
230
+ "agent completion: no tool calls or last turn",
231
+ agent_name=self.name,
232
+ has_content=response.content is not None,
233
+ )
234
+ result = response.content
235
+ break
236
+
237
+ # Process tool calls
360
238
  logger.debug(
361
- "agent performing multi-turn completion with tools",
239
+ "agent received tool calls",
362
240
  agent_name=self.name,
363
- max_turns=max_turns,
241
+ num_tool_calls=len(response.tool_calls),
364
242
  )
365
- # Multi-turn with tools
366
- turns_left = max_turns
367
- while turns_left > 0:
368
- turns_left -= 1
369
- logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
370
-
371
- # Get model response
372
- response = await self.as_step_if_durable(
373
- model.provider_class.complete,
374
- step_type=StepType.AGENT,
375
- return_type=CompletionResponse[output_type_for_provider or str],
376
- )(
377
- model_spec=model.model_spec,
378
- messages=messages,
379
- output_type=output_type_for_provider,
380
- tools=tool_definitions,
381
- )
382
-
383
- # Emit response event if event_emitter is provided
384
- if event_emitter:
385
- event_emitter.emit(AgentEventType.RESPONSE, response.content)
386
-
387
- # If no tool calls or last turn, return content
388
- if not response.tool_calls or turns_left == 0:
389
- logger.debug(
390
- "agent completion: no tool calls or last turn",
391
- agent_name=self.name,
392
- has_content=response.content is not None,
393
- )
394
- result = response.content
395
- break
243
+ assistant_message = AssistantMessage(
244
+ content=None,
245
+ tool_calls=response.tool_calls,
246
+ )
247
+ messages.append(assistant_message)
396
248
 
397
- # Process tool calls
249
+ # Execute each tool and add tool responses to messages
250
+ for tool_call_idx, tool_call in enumerate(response.tool_calls):
398
251
  logger.debug(
399
- "agent received tool calls",
252
+ "agent processing tool call",
400
253
  agent_name=self.name,
401
- num_tool_calls=len(response.tool_calls),
254
+ tool_call_index=tool_call_idx + 1,
255
+ tool_call_id=tool_call.id,
256
+ tool_call_name=tool_call.name,
402
257
  )
403
- assistant_message = AssistantMessage(
404
- content=None,
405
- tool_calls=response.tool_calls,
258
+ # Find the matching tool function
259
+ tool_fn = next(
260
+ (t for t in self.tools if t.__name__ == tool_call.name),
261
+ None,
406
262
  )
407
- messages.append(assistant_message)
408
263
 
409
- # Execute each tool and add tool responses to messages
410
- for tool_call_idx, tool_call in enumerate(response.tool_calls):
411
- logger.debug(
412
- "agent processing tool call",
264
+ if not tool_fn:
265
+ tool_result = f"Error: Tool '{tool_call.name}' not found."
266
+ logger.warning(
267
+ "tool not found for agent",
268
+ tool_name=tool_call.name,
413
269
  agent_name=self.name,
414
- tool_call_index=tool_call_idx + 1,
415
- tool_call_id=tool_call.id,
416
- tool_call_name=tool_call.name,
417
270
  )
418
- # Find the matching tool function
419
- tool_fn = next(
420
- (t for t in self.tools if t.__name__ == tool_call.name),
421
- None,
271
+ else:
272
+ # Execute the tool with the provided arguments
273
+ tool_result = await self.as_step_if_durable(
274
+ tool_fn,
275
+ step_type=StepType.TOOL_CALL,
276
+ )(**tool_call.arguments)
277
+ logger.info(
278
+ "tool executed by agent",
279
+ tool_name=tool_call.name,
280
+ agent_name=self.name,
281
+ result_type=type(tool_result),
422
282
  )
423
283
 
424
- if not tool_fn:
425
- tool_result = f"Error: Tool '{tool_call.name}' not found."
426
- logger.warning(
427
- "tool not found for agent",
428
- tool_name=tool_call.name,
429
- agent_name=self.name,
430
- )
431
- else:
432
- # Execute the tool with the provided arguments
433
- tool_result = await self.as_step_if_durable(
434
- tool_fn,
435
- step_type=StepType.TOOL_CALL,
436
- )(**tool_call.arguments)
437
- logger.info(
438
- "tool executed by agent",
439
- tool_name=tool_call.name,
440
- agent_name=self.name,
441
- result_type=type(tool_result),
442
- )
443
-
444
- # Create a tool response
445
- tool_response = ToolResponse(
446
- tool_call_id=tool_call.id or "call_1", content=str(tool_result)
447
- )
284
+ # Create a tool response
285
+ tool_response = ToolResponse(
286
+ tool_call_id=tool_call.id or "call_1", content=str(tool_result)
287
+ )
448
288
 
449
- # Emit tool response event if event_emitter is provided
450
- if event_emitter:
451
- event_emitter.emit(
452
- AgentEventType.TOOL_RESPONSE,
453
- ToolCallResult(
454
- tool_call_id=tool_call.id or "call_1",
455
- tool_call_name=tool_call.name,
456
- content=tool_result,
457
- ),
458
- )
459
-
460
- # Convert the tool response to a message based on provider
461
- tool_message = model.provider_class.format_tool_response(
462
- tool_response
289
+ # Emit tool response event if event_emitter is provided
290
+ if event_emitter:
291
+ event_emitter.emit(
292
+ AgentEventType.TOOL_RESPONSE,
293
+ ToolCallResult(
294
+ tool_call_id=tool_call.id or "call_1",
295
+ tool_call_name=tool_call.name,
296
+ content=tool_result,
297
+ ),
463
298
  )
464
- messages.append(tool_message)
465
299
 
466
- # Continue to next turn
467
-
468
- if result is None:
469
- logger.warning(
470
- "agent completed tool interactions but result is none",
471
- agent_name=self.name,
472
- expected_type=self.output_type,
473
- )
474
- raise ValueError(
475
- f"Reached max turns without the expected result of type {self.output_type}. "
476
- "You may need to increase the max_turns parameter or update the Agent instructions."
300
+ tool_message = ToolMessage(
301
+ content=tool_response.content,
302
+ tool_call_id=tool_response.tool_call_id or "call_1",
477
303
  )
304
+ messages.append(tool_message)
478
305
 
479
- if event_emitter:
480
- event_emitter.emit(AgentEventType.COMPLETED, result)
306
+ # Continue to next turn
481
307
 
482
308
  if result is None:
483
- logger.warning("agent final result is none", agent_name=self.name)
484
- raise ValueError("No result obtained after tool interactions")
309
+ logger.warning(
310
+ "agent completed tool interactions but result is none",
311
+ agent_name=self.name,
312
+ expected_type=self.output_type,
313
+ )
314
+ raise ValueError(
315
+ f"Expected result of type {self.output_type} but got none after tool interactions."
316
+ )
317
+
318
+ if event_emitter:
319
+ event_emitter.emit(AgentEventType.COMPLETED, result)
485
320
 
486
321
  logger.info(
487
322
  "agent completed",