pydantic-ai-slim 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/agent.py CHANGED
@@ -22,7 +22,17 @@ from . import (
     result,
 )
 from .result import ResultData
-from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams
+from .tools import (
+    AgentDeps,
+    RunContext,
+    Tool,
+    ToolDefinition,
+    ToolFuncContext,
+    ToolFuncEither,
+    ToolFuncPlain,
+    ToolParams,
+    ToolPrepareFunc,
+)
 
 __all__ = ('Agent',)
 
@@ -136,7 +146,10 @@ class Agent(Generic[AgentDeps, ResultData]):
         self._function_tools = {}
         self._default_retries = retries
         for tool in tools:
-            self._register_tool(Tool.infer(tool))
+            if isinstance(tool, Tool):
+                self._register_tool(tool)
+            else:
+                self._register_tool(Tool(tool))
         self._deps_type = deps_type
         self._system_prompt_functions = []
         self._max_result_retries = result_retries if result_retries is not None else retries
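With this change, `Agent(tools=[...])` accepts a mix of plain callables (wrapped as `Tool(func)`, replacing the old `Tool.infer` helper) and pre-built `Tool` instances. A minimal sketch, assuming the `Tool` keyword options shown elsewhere in this diff (`max_retries`, `prepare`) and the built-in `'test'` model:

```python
import random

from pydantic_ai import Agent
from pydantic_ai.tools import Tool


def roll_die() -> str:
    """Roll a six-sided die."""
    return str(random.randint(1, 6))


def flip_coin() -> str:
    """Flip a coin."""
    return random.choice(['heads', 'tails'])


# Plain callables are wrapped as Tool(func); Tool instances are registered as-is,
# which lets per-tool options such as max_retries be set up front.
agent = Agent('test', tools=[roll_die, Tool(flip_coin, max_retries=2)])
```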
@@ -166,29 +179,33 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used, custom_model, agent_model = await self._get_agent_model(model)
+        model_used, mode_selection = await self._get_model(model)
 
         deps = self._get_deps(deps)
 
         with _logfire.span(
-            '{agent.name} run {prompt=}',
+            '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            custom_model=custom_model,
+            mode_selection=mode_selection,
             model_name=model_used.name(),
+            agent_name=self.name or 'agent',
         ) as run_span:
             new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
             self.last_run_messages = messages
 
             for tool in self._function_tools.values():
-                tool.reset()
+                tool.current_retry = 0
 
             cost = result.Cost()
 
             run_step = 0
             while True:
                 run_step += 1
-                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                    agent_model = await self._prepare_model(model_used, deps)
+
+                with _logfire.span('model request', run_step=run_step) as model_req_span:
                     model_response, request_cost = await agent_model.request(messages)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('cost', request_cost)
@@ -197,12 +214,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 messages.append(model_response)
                 cost += request_cost
 
-                with _logfire.span('handle model response') as handle_span:
-                    either = await self._handle_model_response(model_response, deps)
+                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                    final_result, response_messages = await self._handle_model_response(model_response, deps)
 
-                    if isinstance(either, _MarkFinalResult):
-                        # we have a final result, end the conversation
-                        result_data = either.data
+                    # Add all messages to the conversation
+                    messages.extend(response_messages)
+
+                    # Check if we got a final result
+                    if final_result is not None:
+                        result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
                         run_span.set_attribute('cost', cost)
                         handle_span.set_attribute('result', result_data)
@@ -210,11 +230,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                         return result.RunResult(messages, new_message_index, result_data, cost)
                     else:
                         # continue the conversation
-                        tool_responses = either
-                        handle_span.set_attribute('tool_responses', tool_responses)
-                        response_msgs = ' '.join(m.role for m in tool_responses)
+                        handle_span.set_attribute('tool_responses', response_messages)
+                        response_msgs = ' '.join(r.role for r in response_messages)
                         handle_span.message = f'handle model response -> {response_msgs}'
-                        messages.extend(tool_responses)
 
     def run_sync(
         self,
@@ -272,28 +290,33 @@ class Agent(Generic[AgentDeps, ResultData]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used, custom_model, agent_model = await self._get_agent_model(model)
+        model_used, mode_selection = await self._get_model(model)
 
         deps = self._get_deps(deps)
 
         with _logfire.span(
-            '{agent.name} run stream {prompt=}',
+            '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            custom_model=custom_model,
+            mode_selection=mode_selection,
             model_name=model_used.name(),
+            agent_name=self.name or 'agent',
         ) as run_span:
             new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
             self.last_run_messages = messages
 
             for tool in self._function_tools.values():
-                tool.reset()
+                tool.current_retry = 0
 
             cost = result.Cost()
 
             run_step = 0
             while True:
                 run_step += 1
+
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                    agent_model = await self._prepare_model(model_used, deps)
+
                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                     async with agent_model.request_stream(messages) as model_response:
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
@@ -302,10 +325,16 @@ class Agent(Generic[AgentDeps, ResultData]):
                     model_req_span.__exit__(None, None, None)
 
                 with _logfire.span('handle model response') as handle_span:
-                    either = await self._handle_streamed_model_response(model_response, deps)
+                    final_result, response_messages = await self._handle_streamed_model_response(
+                        model_response, deps
+                    )
+
+                    # Add all messages to the conversation
+                    messages.extend(response_messages)
 
-                    if isinstance(either, _MarkFinalResult):
-                        result_stream = either.data
+                    # Check if we got a final result
+                    if final_result is not None:
+                        result_stream = final_result.data
                         run_span.set_attribute('all_messages', messages)
                         handle_span.set_attribute('result_type', result_stream.__class__.__name__)
                         handle_span.message = 'handle model response -> final result'
@@ -321,11 +350,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                         )
                         return
                     else:
-                        tool_responses = either
-                        handle_span.set_attribute('tool_responses', tool_responses)
-                        response_msgs = ' '.join(m.role for m in tool_responses)
+                        # continue the conversation
+                        handle_span.set_attribute('tool_responses', response_messages)
+                        response_msgs = ' '.join(r.role for r in response_messages)
                         handle_span.message = f'handle model response -> {response_msgs}'
-                        messages.extend(tool_responses)
                         # the model_response should have been fully streamed by now, we can add it's cost
                         cost += model_response.cost()
 
@@ -475,7 +503,11 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     @overload
    def tool(
-        self, /, *, retries: int | None = None
+        self,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
     ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
 
     def tool(
@@ -484,9 +516,9 @@ class Agent(Generic[AgentDeps, ResultData]):
         /,
         *,
         retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
     ) -> Any:
-        """Decorator to register a tool function which takes
-        [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
+        """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
 
         Can decorate a sync or async functions.
 
@@ -519,20 +551,23 @@ class Agent(Generic[AgentDeps, ResultData]):
             func: The tool function to register.
             retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                 which defaults to 1.
-        """  # noqa: D205
+            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
+                tool from a given step. This is useful if you want to customise a tool at call time,
+                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+        """
         if func is None:
 
             def tool_decorator(
                 func_: ToolFuncContext[AgentDeps, ToolParams],
             ) -> ToolFuncContext[AgentDeps, ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, True, retries)
+                self._register_function(func_, True, retries, prepare)
                 return func_
 
             return tool_decorator
         else:
             # noinspection PyTypeChecker
-            self._register_function(func, True, retries)
+            self._register_function(func, True, retries, prepare)
             return func
 
     @overload
@@ -540,10 +575,21 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     @overload
     def tool_plain(
-        self, /, *, retries: int | None = None
+        self,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
 
-    def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
+    def tool_plain(
+        self,
+        func: ToolFuncPlain[ToolParams] | None = None,
+        /,
+        *,
+        retries: int | None = None,
+        prepare: ToolPrepareFunc[AgentDeps] | None = None,
+    ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
 
         Can decorate a sync or async functions.
@@ -577,30 +623,38 @@ class Agent(Generic[AgentDeps, ResultData]):
             func: The tool function to register.
             retries: The number of retries to allow for this tool, defaults to the agent's default retries,
                 which defaults to 1.
+            prepare: custom method to prepare the tool definition for each step, return `None` to omit this
+                tool from a given step. This is useful if you want to customise a tool at call time,
+                or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
         """
         if func is None:
 
             def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, False, retries)
+                self._register_function(func_, False, retries, prepare)
                 return func_
 
             return tool_decorator
         else:
-            self._register_function(func, False, retries)
+            self._register_function(func, False, retries, prepare)
             return func
 
     def _register_function(
-        self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None
+        self,
+        func: ToolFuncEither[AgentDeps, ToolParams],
+        takes_ctx: bool,
+        retries: int | None,
+        prepare: ToolPrepareFunc[AgentDeps] | None,
     ) -> None:
         """Private utility to register a function as a tool."""
         retries_ = retries if retries is not None else self._default_retries
-        tool = Tool(func, takes_ctx, max_retries=retries_)
+        tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
         self._register_tool(tool)
 
     def _register_tool(self, tool: Tool[AgentDeps]) -> None:
         """Private utility to register a tool instance."""
         if tool.max_retries is None:
+            # noinspection PyTypeChecker
             tool = dataclasses.replace(tool, max_retries=self._default_retries)
 
         if tool.name in self._function_tools:
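The new `prepare` keyword threads through `tool`/`tool_plain` into `Tool`. A sketch of how it might be used per the docstring above, assuming (as in the later `prepare_tool_def` call) that a `ToolPrepareFunc` receives the run context and the current `ToolDefinition`, and that returning `None` omits the tool from that step; the `deps_type` and condition are purely illustrative:

```python
from pydantic_ai import Agent, RunContext
from pydantic_ai.tools import ToolDefinition

agent = Agent('test', deps_type=int)


async def only_if_42(ctx: RunContext[int], tool_def: ToolDefinition) -> ToolDefinition | None:
    # Returning None omits the tool from this step's request entirely.
    if ctx.deps == 42:
        return tool_def


@agent.tool(prepare=only_if_42)
async def hitchhiker(ctx: RunContext[int], answer: str) -> str:
    return f'{ctx.deps} {answer}'
```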
@@ -611,16 +665,14 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         self._function_tools[tool.name] = tool
 
-    async def _get_agent_model(
-        self, model: models.Model | models.KnownModelName | None
-    ) -> tuple[models.Model, models.Model | None, models.AgentModel]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
         """Create a model configured for this agent.
 
         Args:
             model: model to use for this run, required if `model` was not set when creating the agent.
 
         Returns:
-            a tuple of `(model used, custom_model if any, agent_model)`
+            a tuple of `(model used, how the model was selected)`
         """
         model_: models.Model
         if some_model := self._override_model:
@@ -631,19 +683,35 @@ class Agent(Generic[AgentDeps, ResultData]):
                     '(Even when `override(model=...)` is customizing the model that will actually be called)'
                 )
             model_ = some_model.value
-            custom_model = None
+            mode_selection = 'override-model'
         elif model is not None:
-            custom_model = model_ = models.infer_model(model)
+            model_ = models.infer_model(model)
+            mode_selection = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-            custom_model = None
+            mode_selection = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-        result_tools = list(self._result_schema.tools.values()) if self._result_schema else None
-        agent_model = await model_.agent_model(self._function_tools, self._allow_text_result, result_tools)
-        return model_, custom_model, agent_model
+        return model_, mode_selection
+
+    async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
+        """Create building tools and create an agent model."""
+        function_tools: list[ToolDefinition] = []
+
+        async def add_tool(tool: Tool[AgentDeps]) -> None:
+            ctx = RunContext(deps, tool.current_retry, tool.name)
+            if tool_def := await tool.prepare_tool_def(ctx):
+                function_tools.append(tool_def)
+
+        await asyncio.gather(*map(add_tool, self._function_tools.values()))
+
+        return await model.agent_model(
+            function_tools=function_tools,
+            allow_text_result=self._allow_text_result,
+            result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
+        )
 
     async def _prepare_messages(
         self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
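Note that `_prepare_model` rebuilds the tool definitions concurrently on every step via `tool.prepare_tool_def(ctx)`, so a prepare function can also rewrite a definition rather than drop it. A sketch, assuming `ToolDefinition` is a dataclass with a `description` field (not shown in this diff):

```python
from dataclasses import replace

from pydantic_ai import RunContext
from pydantic_ai.tools import ToolDefinition


async def note_retries(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition:
    # Rewrite the definition each step, e.g. to surface the retry count to the model;
    # ctx.retry comes from RunContext(deps, tool.current_retry, tool.name) above.
    return replace(tool_def, description=f'{tool_def.description} (retry {ctx.retry})')
```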
@@ -663,11 +731,11 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     async def _handle_model_response(
         self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
         """Process a non-streamed response from the model.
 
         Returns:
-            Return `Either` left: final result data, right: list of messages to send back to the model.
+            A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
         """
         if model_response.role == 'model-text-response':
             # plain string response
@@ -677,15 +745,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                     result_data = await self._validate_result(result_data_input, deps, None)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return [e.tool_retry]
+                    return None, [e.tool_retry]
                 else:
-                    return _MarkFinalResult(result_data)
+                    return _MarkFinalResult(result_data), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                return [response]
+                return None, [response]
         elif model_response.role == 'model-structured-response':
             if self._result_schema is not None:
                 # if there's a result schema, and any of the calls match one of its tools, return the result
@@ -697,9 +765,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                     result_data = await self._validate_result(result_data, deps, call)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return [e.tool_retry]
+                    return None, [e.tool_retry]
                 else:
-                    return _MarkFinalResult(result_data)
+                    # Add a ToolReturn message for the schema tool call
+                    tool_return = _messages.ToolReturn(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_id=call.tool_id,
+                    )
+                    return _MarkFinalResult(result_data), [tool_return]
 
             if not model_response.calls:
                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
@@ -714,26 +788,24 @@ class Agent(Generic[AgentDeps, ResultData]):
                     messages.append(self._unknown_tool(call.tool_name))
 
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                messages += await asyncio.gather(*tasks)
-            return messages
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages
         else:
             assert_never(model_response)
 
     async def _handle_streamed_model_response(
         self, model_response: models.EitherStreamedResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
         """Process a streamed response from the model.
 
-        TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
-        (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
-
         Returns:
-            Return `Either` left: final result data, right: list of messages to send back to the model.
+            A tuple of (final_result, messages). If final_result is not None, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response)
+                return _MarkFinalResult(model_response), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
743
815
  async for _ in model_response:
744
816
  pass
745
817
 
746
- return [response]
818
+ return None, [response]
747
819
  else:
748
820
  assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
749
821
  if self._result_schema is not None:
@@ -757,8 +829,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                         break
                 structured_msg = model_response.get()
 
-                if self._result_schema.find_tool(structured_msg):
-                    return _MarkFinalResult(model_response)
+                if match := self._result_schema.find_tool(structured_msg):
+                    call, _ = match
+                    tool_return = _messages.ToolReturn(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_id=call.tool_id,
+                    )
+                    return _MarkFinalResult(model_response), [tool_return]
 
             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
@@ -777,8 +855,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                     messages.append(self._unknown_tool(call.tool_name))
 
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                messages += await asyncio.gather(*tasks)
-            return messages
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages
 
     async def _validate_result(
         self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
@@ -837,6 +916,12 @@ class Agent(Generic[AgentDeps, ResultData]):
                     if item is self:
                         self.name = name
                        return
+                if parent_frame.f_locals != parent_frame.f_globals:
+                    # if we couldn't find the agent in locals and globals are a different dict, try globals
+                    for name, item in parent_frame.f_globals.items():
+                        if item is self:
+                            self.name = name
+                            return
 
 
 @dataclass
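The globals fallback means a module-level agent now gets its name inferred even when `run()` is called from inside a function, where the agent appears in the caller's globals rather than its locals. A small illustration, using the built-in `'test'` model:

```python
import asyncio

from pydantic_ai import Agent

my_agent = Agent('test')  # defined at module level, no explicit name


async def main():
    await my_agent.run('hello')
    # my_agent isn't in main()'s locals; the globals fallback above
    # finds it in module globals and infers the name.
    assert my_agent.name == 'my_agent'


asyncio.run(main())
```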
@@ -844,6 +929,8 @@ class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.
 
     This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+
+    It also avoids problems in the case where the result type is itself `None`, but is set.
     """
 
     data: ResultData
pydantic_ai/messages.py CHANGED
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations
 
-import json
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union
@@ -74,6 +73,9 @@ class ToolReturn:
         return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
 
+ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
+
+
 @dataclass
 class RetryPrompt:
     """A message back to a model asking it to try again.
@@ -109,7 +111,8 @@ class RetryPrompt:
         if isinstance(self.content, str):
             description = self.content
         else:
-            description = f'{len(self.content)} validation errors: {json.dumps(self.content, indent=2)}'
+            json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
+            description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
 
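The retry prompt now serializes validation errors with a pydantic type adapter instead of `json.dumps`, excluding each error's `ctx` entry (whose values may not be JSON-serializable). A standalone sketch of the same idea using pydantic's public `TypeAdapter` (the package itself uses its internal `LazyTypeAdapter` wrapper):

```python
from pydantic import TypeAdapter, ValidationError
from pydantic_core import ErrorDetails

errors_ta = TypeAdapter(list[ErrorDetails])

try:
    TypeAdapter(int).validate_python('not an int')
except ValidationError as exc:
    # drop the per-error `ctx` value, mirroring the RetryPrompt change above
    json_errors = errors_ta.dump_json(exc.errors(include_url=False), exclude={'__all__': {'ctx'}}, indent=2)
    print(json_errors.decode())
```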
pydantic_ai/models/__init__.py CHANGED
@@ -7,11 +7,11 @@ specific LLM being used.
 from __future__ import annotations as _annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
+from collections.abc import AsyncIterator, Iterable, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal, Protocol, Union
+from typing import TYPE_CHECKING, Literal, Union
 
 import httpx
 
@@ -19,8 +19,8 @@ from ..exceptions import UserError
 from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
 
 if TYPE_CHECKING:
-    from .._utils import ObjectJsonSchema
     from ..result import Cost
+    from ..tools import ToolDefinition
 
 
 KnownModelName = Literal[
@@ -49,6 +49,23 @@ KnownModelName = Literal[
     'gemini-1.5-pro',
     'vertexai:gemini-1.5-flash',
     'vertexai:gemini-1.5-pro',
+    'ollama:codellama',
+    'ollama:gemma',
+    'ollama:gemma2',
+    'ollama:llama3',
+    'ollama:llama3.1',
+    'ollama:llama3.2',
+    'ollama:llama3.2-vision',
+    'ollama:llama3.3',
+    'ollama:mistral',
+    'ollama:mistral-nemo',
+    'ollama:mixtral',
+    'ollama:phi3',
+    'ollama:qwq',
+    'ollama:qwen',
+    'ollama:qwen2',
+    'ollama:qwen2.5',
+    'ollama:starcoder2',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -63,11 +80,12 @@ class Model(ABC):
     @abstractmethod
     async def agent_model(
         self,
-        function_tools: Mapping[str, AbstractToolDefinition],
+        *,
+        function_tools: list[ToolDefinition],
         allow_text_result: bool,
-        result_tools: Sequence[AbstractToolDefinition] | None,
+        result_tools: list[ToolDefinition],
     ) -> AgentModel:
-        """Create an agent model.
+        """Create an agent model, this is called for each step of an agent run.
 
         This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
 
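This is a breaking change for custom `Model` implementations: `agent_model` is now called once per run step and takes keyword-only `ToolDefinition` lists. A minimal skeleton against the abstract interface above (the `name()` method is assumed to be the other required override, based on the `model_used.name()` calls in agent.py; the body is elided):

```python
from pydantic_ai.models import AgentModel, Model
from pydantic_ai.tools import ToolDefinition


class MyModel(Model):
    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        # build and return an AgentModel configured with this step's tool definitions
        raise NotImplementedError

    def name(self) -> str:
        return 'my-model'
```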
@@ -87,7 +105,7 @@ class Model(ABC):
 
 
 class AgentModel(ABC):
-    """Model configured for a specific agent."""
+    """Model configured for each step of an Agent run."""
 
     @abstractmethod
     async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
@@ -238,7 +256,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
     elif model.startswith('openai:'):
         from .openai import OpenAIModel
 
-        return OpenAIModel(model[7:])  # pyright: ignore[reportArgumentType]
+        return OpenAIModel(model[7:])
     elif model.startswith('gemini'):
         from .gemini import GeminiModel
 
@@ -252,39 +270,14 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .vertexai import VertexAIModel
 
         return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('ollama:'):
+        from .ollama import OllamaModel
+
+        return OllamaModel(model[7:])
     else:
         raise UserError(f'Unknown model: {model}')
 
 
-class AbstractToolDefinition(Protocol):
-    """Abstract definition of a function/tool.
-
-    This is used for both function tools and result tools.
-    """
-
-    @property
-    def name(self) -> str:
-        """The name of the tool."""
-        ...
-
-    @property
-    def description(self) -> str:
-        """The description of the tool."""
-        ...
-
-    @property
-    def json_schema(self) -> ObjectJsonSchema:
-        """The JSON schema for the tool's arguments."""
-        ...
-
-    @property
-    def outer_typed_dict_key(self) -> str | None:
-        """The key in the outer [TypedDict] that wraps a result tool.
-
-        This will only be set for result tools which don't have an `object` JSON schema.
-        """
-
-
 @cache
 def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.AsyncClient:
     """Cached HTTPX async client so multiple agents and calls can share the same client.