planar 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
planar/_version.py CHANGED
@@ -1 +1 @@
1
- VERSION = "0.7.0"
1
+ VERSION = "0.9.0"
planar/ai/agent.py CHANGED
@@ -1,212 +1,49 @@
1
- from __future__ import annotations
2
-
3
- import abc
4
1
  import inspect
5
- from dataclasses import dataclass, field
6
- from typing import (
7
- Any,
8
- Callable,
9
- Coroutine,
10
- Dict,
11
- List,
12
- Type,
13
- Union,
14
- cast,
15
- overload,
16
- )
2
+ from dataclasses import dataclass
3
+ from typing import Any, Type, cast
17
4
 
18
5
  from pydantic import BaseModel
6
+ from pydantic_ai import models
19
7
 
8
+ from planar.ai.agent_base import AgentBase
20
9
  from planar.ai.agent_utils import (
21
- AgentEventEmitter,
22
- AgentEventType,
23
- ToolCallResult,
10
+ ModelSpec,
24
11
  create_tool_definition,
25
12
  extract_files_from_model,
26
13
  get_agent_config,
27
14
  render_template,
28
15
  )
29
16
  from planar.ai.models import (
30
- AgentConfig,
17
+ AgentEventEmitter,
18
+ AgentEventType,
31
19
  AgentRunResult,
32
20
  AssistantMessage,
33
- CompletionResponse,
34
21
  ModelMessage,
35
22
  SystemMessage,
23
+ ToolCallResult,
24
+ ToolDefinition,
25
+ ToolMessage,
36
26
  ToolResponse,
37
27
  UserMessage,
38
28
  )
39
- from planar.ai.providers import Anthropic, Gemini, Model, OpenAI
29
+ from planar.ai.pydantic_ai import ModelRunResponse, model_run
40
30
  from planar.logging import get_logger
41
- from planar.modeling.field_helpers import JsonSchema
42
- from planar.utils import P, R, T, U, utc_now
43
- from planar.workflows import as_step
31
+ from planar.utils import utc_now
44
32
  from planar.workflows.models import StepType
33
+ from planar.workflows.notifications import agent_text, agent_think
45
34
 
46
35
  logger = get_logger(__name__)
47
36
 
48
37
 
49
- def _parse_model_string(model_str: str) -> Model:
50
- """Parse a model string (e.g., 'openai:gpt-4.1') into a Model instance."""
51
- parts = model_str.split(":", 1)
52
- if len(parts) != 2:
53
- raise ValueError(
54
- f"Invalid model format: {model_str}. Expected format: 'provider:model_id'"
55
- )
56
-
57
- provider_id, model_id = parts
58
-
59
- if provider_id.lower() == "openai":
60
- return OpenAI.model(model_id)
61
- elif provider_id.lower() == "anthropic":
62
- return Anthropic.model(model_id)
63
- elif provider_id.lower() == "gemini":
64
- return Gemini.model(model_id)
65
- else:
66
- raise ValueError(f"Unsupported provider: {provider_id}")
67
-
68
-
69
- @dataclass
70
- class AgentBase[
71
- # TODO: add `= str` default when we upgrade to 3.13
72
- TInput: BaseModel | str,
73
- TOutput: BaseModel | str,
74
- ](abc.ABC):
75
- """An LLM-powered agent that can be called directly within workflows."""
76
-
77
- name: str
78
- system_prompt: str
79
- output_type: Type[TOutput] | None = None
80
- input_type: Type[TInput] | None = None
81
- user_prompt: str = ""
82
- tools: List[Callable] = field(default_factory=list)
83
- max_turns: int = 2
84
- model_parameters: Dict[str, Any] = field(default_factory=dict)
85
- event_emitter: AgentEventEmitter | None = None
86
- durable: bool = True
87
-
88
- # TODO: move here to serialize to frontend
89
- #
90
- # built_in_vars: Dict[str, str] = field(default_factory=lambda: {
91
- # "datetime_now": datetime.datetime.now().isoformat(),
92
- # "date_today": datetime.date.today().isoformat(),
93
- # })
94
-
95
- def __post_init__(self):
96
- if self.input_type:
97
- if (
98
- not issubclass(self.input_type, BaseModel)
99
- and self.input_type is not str
100
- ):
101
- raise ValueError(
102
- "input_type must be 'str' or a subclass of a Pydantic model"
103
- )
104
- if self.max_turns < 1:
105
- raise ValueError("Max_turns must be greater than or equal to 1.")
106
- if self.tools and self.max_turns <= 1:
107
- raise ValueError(
108
- "For tool calling to work, max_turns must be greater than 1."
109
- )
110
-
111
- def input_schema(self) -> JsonSchema | None:
112
- if self.input_type is None:
113
- return None
114
- if self.input_type is str:
115
- return None
116
- assert issubclass(self.input_type, BaseModel), (
117
- "input_type must be a subclass of BaseModel or str"
118
- )
119
- return self.input_type.model_json_schema()
120
-
121
- def output_schema(self) -> JsonSchema | None:
122
- if self.output_type is None:
123
- return None
124
- if self.output_type is str:
125
- return None
126
- assert issubclass(self.output_type, BaseModel), (
127
- "output_type must be a subclass of BaseModel or str"
128
- )
129
- return self.output_type.model_json_schema()
130
-
131
- @overload
132
- async def __call__(
133
- self: "AgentBase[TInput, str]",
134
- input_value: TInput,
135
- ) -> AgentRunResult[str]: ...
136
-
137
- @overload
138
- async def __call__(
139
- self: "AgentBase[TInput, TOutput]",
140
- input_value: TInput,
141
- ) -> AgentRunResult[TOutput]: ...
142
-
143
- def as_step_if_durable(
144
- self,
145
- func: Callable[P, Coroutine[T, U, R]],
146
- step_type: StepType,
147
- display_name: str | None = None,
148
- return_type: Type[R] | None = None,
149
- ) -> Callable[P, Coroutine[T, U, R]]:
150
- if not self.durable:
151
- return func
152
- return as_step(
153
- func,
154
- step_type=step_type,
155
- display_name=display_name or self.name,
156
- return_type=return_type,
157
- )
158
-
159
- async def __call__(
160
- self,
161
- input_value: TInput,
162
- ) -> AgentRunResult[Any]:
163
- if self.input_type is not None and not isinstance(input_value, self.input_type):
164
- raise ValueError(
165
- f"Input value must be of type {self.input_type}, but got {type(input_value)}"
166
- )
167
- elif not isinstance(input_value, (str, BaseModel)):
168
- # Should not happen based on type constraints, but just in case
169
- # user does not have type checking enabled
170
- raise ValueError(
171
- "Input value must be a string or a Pydantic model if input_type is not provided"
172
- )
173
-
174
- if self.output_type is None:
175
- run_step = self.as_step_if_durable(
176
- self.run_step,
177
- step_type=StepType.AGENT,
178
- display_name=self.name,
179
- return_type=AgentRunResult[str],
180
- )
181
- else:
182
- run_step = self.as_step_if_durable(
183
- self.run_step,
184
- step_type=StepType.AGENT,
185
- display_name=self.name,
186
- return_type=AgentRunResult[self.output_type],
187
- )
188
-
189
- result = await run_step(input_value=input_value)
190
- # Cast the result to ensure type compatibility
191
- return cast(AgentRunResult[TOutput], result)
192
-
193
- @abc.abstractmethod
194
- async def run_step(
195
- self,
196
- input_value: TInput,
197
- ) -> AgentRunResult[TOutput]: ...
198
-
199
- @abc.abstractmethod
200
- def get_model_str(self) -> str: ...
201
-
202
- def to_config(self) -> AgentConfig:
203
- return AgentConfig(
204
- system_prompt=self.system_prompt,
205
- user_prompt=self.user_prompt,
206
- model=self.get_model_str(),
207
- max_turns=self.max_turns,
208
- model_parameters=self.model_parameters,
209
- )
38
+ class AgentWorkflowNotifier(AgentEventEmitter):
39
+ def emit(self, event_type, data):
40
+ match event_type:
41
+ case AgentEventType.THINK:
42
+ agent_think(str(data))
43
+ case AgentEventType.TEXT:
44
+ agent_text(str(data))
45
+ case _:
46
+ ...
210
47
 
211
48
 
212
49
  @dataclass
@@ -214,7 +51,7 @@ class Agent[
214
51
  TInput: BaseModel | str,
215
52
  TOutput: BaseModel | str,
216
53
  ](AgentBase[TInput, TOutput]):
217
- model: Union[str, Model] = "openai:gpt-4.1"
54
+ model: models.KnownModelName | models.Model = "openai:gpt-4o"
218
55
 
219
56
  async def run_step(
220
57
  self,
@@ -229,12 +66,12 @@ class Agent[
229
66
  Returns:
230
67
  AgentRunResult containing the agent's response
231
68
  """
232
- event_emitter = self.event_emitter
69
+ if self.event_emitter:
70
+ event_emitter = self.event_emitter
71
+ else:
72
+ event_emitter = AgentWorkflowNotifier()
233
73
  logger.debug(
234
- "agent run_step called",
235
- agent_name=self.name,
236
- input_type=type(input_value),
237
- config=self.to_config(),
74
+ "agent run_step called", agent_name=self.name, input_type=type(input_value)
238
75
  )
239
76
  result = None
240
77
 
@@ -297,18 +134,18 @@ class Agent[
297
134
  raise ValueError(f"Missing required parameter for prompt formatting: {e}")
298
135
 
299
136
  # Get the LLM provider and model
300
- model_config = config.model
301
- if isinstance(model_config, str):
302
- model = _parse_model_string(model_config)
137
+ if isinstance(self.model, str):
138
+ model = models.infer_model(self.model)
303
139
  else:
304
- model = model_config
140
+ model = self.model
305
141
 
306
142
  # Apply model parameters if specified
143
+ model_settings = None
307
144
  if config.model_parameters:
308
- model = model.with_parameters(**config.model_parameters)
145
+ model_settings = config.model_parameters
309
146
 
310
147
  # Prepare structured messages
311
- messages: List[ModelMessage] = []
148
+ messages: list[ModelMessage] = []
312
149
  if formatted_system_prompt:
313
150
  messages.append(SystemMessage(content=formatted_system_prompt))
314
151
 
@@ -320,168 +157,182 @@ class Agent[
320
157
  if self.tools:
321
158
  tool_definitions = [create_tool_definition(tool) for tool in self.tools]
322
159
 
323
- # Determine output type for the provider call
160
+ # Determine output type for the agent call
324
161
  # Pass the Pydantic model type if output_type is a subclass of BaseModel,
325
162
  # otherwise pass None (indicating string output is expected).
326
- output_type_for_provider: Type[BaseModel] | None = None
163
+ output_type: Type[BaseModel] | None = None
327
164
  # Use issubclass safely by checking if output_type is a type first
328
165
  if inspect.isclass(self.output_type) and issubclass(
329
166
  self.output_type, BaseModel
330
167
  ):
331
- output_type_for_provider = cast(Type[BaseModel], self.output_type)
168
+ output_type = cast(Type[BaseModel], self.output_type)
332
169
 
333
170
  # Execute the LLM call
334
171
  max_turns = config.max_turns
335
172
 
336
- # Single turn completion (default case)
337
- result = None
338
- if not tool_definitions:
173
+ # We use this inner function to pass "model" and "event_emitter",
174
+ # which are not serializable as step parameters.
175
+ async def agent_run_step(
176
+ model_spec: ModelSpec,
177
+ messages: list[ModelMessage],
178
+ turns_left: int,
179
+ tools: list[ToolDefinition] | None = None,
180
+ output_type: Type[BaseModel] | None = None,
181
+ ):
339
182
  logger.debug(
340
- "agent performing single turn completion",
183
+ "agent running",
341
184
  agent_name=self.name,
342
- model=model.model_spec,
343
- output_type=output_type_for_provider,
185
+ model=model_spec,
186
+ model_settings=model_settings,
187
+ output_type=output_type,
344
188
  )
345
- response = await self.as_step_if_durable(
346
- model.provider_class.complete,
189
+ if output_type is None:
190
+ return await model_run(
191
+ model=model,
192
+ max_extra_turns=turns_left,
193
+ model_settings=model_settings,
194
+ messages=messages,
195
+ tools=tools or [],
196
+ event_handler=cast(Any, event_emitter),
197
+ )
198
+ else:
199
+ return await model_run(
200
+ model=model,
201
+ max_extra_turns=turns_left,
202
+ model_settings=model_settings,
203
+ messages=messages,
204
+ output_type=output_type,
205
+ tools=tools or [],
206
+ event_handler=cast(Any, event_emitter),
207
+ )
208
+
209
+ model_spec = ModelSpec(
210
+ model_id=str(model),
211
+ parameters=config.model_parameters,
212
+ )
213
+ result = None
214
+ logger.debug(
215
+ "agent performing multi-turn completion with tools",
216
+ agent_name=self.name,
217
+ max_turns=max_turns,
218
+ )
219
+ turns_left = max_turns
220
+ while turns_left > 0:
221
+ turns_left -= 1
222
+ logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
223
+
224
+ # Get model response
225
+ run_response = await self.as_step_if_durable(
226
+ agent_run_step,
347
227
  step_type=StepType.AGENT,
348
- return_type=CompletionResponse[output_type_for_provider or str],
228
+ return_type=ModelRunResponse[output_type or str],
349
229
  )(
350
- model_spec=model.model_spec,
230
+ model_spec=model_spec,
351
231
  messages=messages,
352
- output_type=output_type_for_provider,
232
+ turns_left=turns_left,
233
+ output_type=output_type,
234
+ tools=tool_definitions or [],
353
235
  )
354
- result = response.content
236
+ response = run_response.response
237
+ turns_left -= run_response.extra_turns_used
355
238
 
356
239
  # Emit response event if event_emitter is provided
357
240
  if event_emitter:
358
241
  event_emitter.emit(AgentEventType.RESPONSE, response.content)
359
- else:
242
+
243
+ # If no tool calls or last turn, return content
244
+ if not response.tool_calls or turns_left == 0:
245
+ logger.debug(
246
+ "agent completion: no tool calls or last turn",
247
+ agent_name=self.name,
248
+ has_content=response.content is not None,
249
+ )
250
+ result = response.content
251
+ break
252
+
253
+ # Process tool calls
360
254
  logger.debug(
361
- "agent performing multi-turn completion with tools",
255
+ "agent received tool calls",
362
256
  agent_name=self.name,
363
- max_turns=max_turns,
257
+ num_tool_calls=len(response.tool_calls),
364
258
  )
365
- # Multi-turn with tools
366
- turns_left = max_turns
367
- while turns_left > 0:
368
- turns_left -= 1
369
- logger.debug("agent turn", agent_name=self.name, turns_left=turns_left)
370
-
371
- # Get model response
372
- response = await self.as_step_if_durable(
373
- model.provider_class.complete,
374
- step_type=StepType.AGENT,
375
- return_type=CompletionResponse[output_type_for_provider or str],
376
- )(
377
- model_spec=model.model_spec,
378
- messages=messages,
379
- output_type=output_type_for_provider,
380
- tools=tool_definitions,
381
- )
382
-
383
- # Emit response event if event_emitter is provided
384
- if event_emitter:
385
- event_emitter.emit(AgentEventType.RESPONSE, response.content)
386
-
387
- # If no tool calls or last turn, return content
388
- if not response.tool_calls or turns_left == 0:
389
- logger.debug(
390
- "agent completion: no tool calls or last turn",
391
- agent_name=self.name,
392
- has_content=response.content is not None,
393
- )
394
- result = response.content
395
- break
259
+ assistant_message = AssistantMessage(
260
+ content=None,
261
+ tool_calls=response.tool_calls,
262
+ )
263
+ messages.append(assistant_message)
396
264
 
397
- # Process tool calls
265
+ # Execute each tool and add tool responses to messages
266
+ for tool_call_idx, tool_call in enumerate(response.tool_calls):
398
267
  logger.debug(
399
- "agent received tool calls",
268
+ "agent processing tool call",
400
269
  agent_name=self.name,
401
- num_tool_calls=len(response.tool_calls),
270
+ tool_call_index=tool_call_idx + 1,
271
+ tool_call_id=tool_call.id,
272
+ tool_call_name=tool_call.name,
402
273
  )
403
- assistant_message = AssistantMessage(
404
- content=None,
405
- tool_calls=response.tool_calls,
274
+ # Find the matching tool function
275
+ tool_fn = next(
276
+ (t for t in self.tools if t.__name__ == tool_call.name),
277
+ None,
406
278
  )
407
- messages.append(assistant_message)
408
279
 
409
- # Execute each tool and add tool responses to messages
410
- for tool_call_idx, tool_call in enumerate(response.tool_calls):
411
- logger.debug(
412
- "agent processing tool call",
280
+ if not tool_fn:
281
+ tool_result = f"Error: Tool '{tool_call.name}' not found."
282
+ logger.warning(
283
+ "tool not found for agent",
284
+ tool_name=tool_call.name,
413
285
  agent_name=self.name,
414
- tool_call_index=tool_call_idx + 1,
415
- tool_call_id=tool_call.id,
416
- tool_call_name=tool_call.name,
417
286
  )
418
- # Find the matching tool function
419
- tool_fn = next(
420
- (t for t in self.tools if t.__name__ == tool_call.name),
421
- None,
287
+ else:
288
+ # Execute the tool with the provided arguments
289
+ tool_result = await self.as_step_if_durable(
290
+ tool_fn,
291
+ step_type=StepType.TOOL_CALL,
292
+ )(**tool_call.arguments)
293
+ logger.info(
294
+ "tool executed by agent",
295
+ tool_name=tool_call.name,
296
+ agent_name=self.name,
297
+ result_type=type(tool_result),
422
298
  )
423
299
 
424
- if not tool_fn:
425
- tool_result = f"Error: Tool '{tool_call.name}' not found."
426
- logger.warning(
427
- "tool not found for agent",
428
- tool_name=tool_call.name,
429
- agent_name=self.name,
430
- )
431
- else:
432
- # Execute the tool with the provided arguments
433
- tool_result = await self.as_step_if_durable(
434
- tool_fn,
435
- step_type=StepType.TOOL_CALL,
436
- )(**tool_call.arguments)
437
- logger.info(
438
- "tool executed by agent",
439
- tool_name=tool_call.name,
440
- agent_name=self.name,
441
- result_type=type(tool_result),
442
- )
443
-
444
- # Create a tool response
445
- tool_response = ToolResponse(
446
- tool_call_id=tool_call.id or "call_1", content=str(tool_result)
447
- )
300
+ # Create a tool response
301
+ tool_response = ToolResponse(
302
+ tool_call_id=tool_call.id or "call_1", content=str(tool_result)
303
+ )
448
304
 
449
- # Emit tool response event if event_emitter is provided
450
- if event_emitter:
451
- event_emitter.emit(
452
- AgentEventType.TOOL_RESPONSE,
453
- ToolCallResult(
454
- tool_call_id=tool_call.id or "call_1",
455
- tool_call_name=tool_call.name,
456
- content=tool_result,
457
- ),
458
- )
459
-
460
- # Convert the tool response to a message based on provider
461
- tool_message = model.provider_class.format_tool_response(
462
- tool_response
305
+ # Emit tool response event if event_emitter is provided
306
+ if event_emitter:
307
+ event_emitter.emit(
308
+ AgentEventType.TOOL_RESPONSE,
309
+ ToolCallResult(
310
+ tool_call_id=tool_call.id or "call_1",
311
+ tool_call_name=tool_call.name,
312
+ content=tool_result,
313
+ ),
463
314
  )
464
- messages.append(tool_message)
465
-
466
- # Continue to next turn
467
315
 
468
- if result is None:
469
- logger.warning(
470
- "agent completed tool interactions but result is none",
471
- agent_name=self.name,
472
- expected_type=self.output_type,
473
- )
474
- raise ValueError(
475
- f"Reached max turns without the expected result of type {self.output_type}. "
476
- "You may need to increase the max_turns parameter or update the Agent instructions."
316
+ tool_message = ToolMessage(
317
+ content=tool_response.content,
318
+ tool_call_id=tool_response.tool_call_id or "call_1",
477
319
  )
320
+ messages.append(tool_message)
478
321
 
479
- if event_emitter:
480
- event_emitter.emit(AgentEventType.COMPLETED, result)
322
+ # Continue to next turn
481
323
 
482
324
  if result is None:
483
- logger.warning("agent final result is none", agent_name=self.name)
484
- raise ValueError("No result obtained after tool interactions")
325
+ logger.warning(
326
+ "agent completed tool interactions but result is none",
327
+ agent_name=self.name,
328
+ expected_type=self.output_type,
329
+ )
330
+ raise ValueError(
331
+ f"Expected result of type {self.output_type} but got none after tool interactions."
332
+ )
333
+
334
+ if event_emitter:
335
+ event_emitter.emit(AgentEventType.COMPLETED, result)
485
336
 
486
337
  logger.info(
487
338
  "agent completed",