pydantic-ai-slim 0.0.14__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic.

Files changed (27)
  1. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/.gitignore +1 -1
  2. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/PKG-INFO +1 -2
  3. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/__init__.py +2 -1
  4. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/_griffe.py +1 -2
  5. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/_result.py +2 -2
  6. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/agent.py +130 -65
  7. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/gemini.py +11 -4
  8. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/mistral.py +10 -15
  9. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/ollama.py +4 -1
  10. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/test.py +18 -7
  11. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/result.py +43 -19
  12. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/settings.py +5 -1
  13. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/tools.py +16 -23
  14. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pyproject.toml +2 -2
  15. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/README.md +0 -0
  16. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/_pydantic.py +0 -0
  17. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/_system_prompt.py +0 -0
  18. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/_utils.py +0 -0
  19. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/exceptions.py +0 -0
  20. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/messages.py +0 -0
  21. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/__init__.py +0 -0
  22. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/anthropic.py +0 -0
  23. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/function.py +0 -0
  24. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/groq.py +0 -0
  25. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/openai.py +0 -0
  26. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/models/vertexai.py +0 -0
  27. {pydantic_ai_slim-0.0.14 → pydantic_ai_slim-0.0.16}/pydantic_ai/py.typed +0 -0
.gitignore

@@ -10,6 +10,6 @@ env*/
 /TODO.md
 /postgres-data/
 .DS_Store
-/pydantic_ai_examples/.chat_app_messages.sqlite
+examples/pydantic_ai_examples/.chat_app_messages.sqlite
 .cache/
 .vscode/
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.14
+Version: 0.0.16
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -36,7 +36,6 @@ Requires-Dist: groq>=0.12.0; extra == 'groq'
 Provides-Extra: logfire
 Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
-Requires-Dist: json-repair>=0.30.3; extra == 'mistral'
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
 Requires-Dist: openai>=1.54.3; extra == 'openai'
pydantic_ai/__init__.py

@@ -1,11 +1,12 @@
 from importlib.metadata import version

-from .agent import Agent
+from .agent import Agent, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
 from .tools import RunContext, Tool

 __all__ = (
     'Agent',
+    'capture_run_messages',
     'RunContext',
     'Tool',
     'AgentRunError',
pydantic_ai/_griffe.py

@@ -4,8 +4,7 @@ import re
 from inspect import Signature
 from typing import Any, Callable, Literal, cast

-from _griffe.enumerations import DocstringSectionKind
-from _griffe.models import Docstring, Object as GriffeObject
+from griffe import Docstring, DocstringSectionKind, Object as GriffeObject

 DocstringStyle = Literal['google', 'numpy', 'sphinx']

pydantic_ai/_result.py

@@ -12,8 +12,8 @@ from typing_extensions import Self, TypeAliasType, TypedDict

 from . import _utils, messages as _messages
 from .exceptions import ModelRetry
-from .result import ResultData
-from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
+from .result import ResultData, ResultValidatorFunc
+from .tools import AgentDeps, RunContext, ToolDefinition


 @dataclass
pydantic_ai/agent.py

@@ -5,12 +5,12 @@ import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass, field
+from contextvars import ContextVar
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
-from typing_extensions import assert_never
+from typing_extensions import assert_never, deprecated

 from . import (
     _result,
@@ -35,10 +35,20 @@ from .tools import (
     ToolPrepareFunc,
 )

-__all__ = ('Agent',)
+__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

+# while waiting for https://github.com/pydantic/logfire/issues/745
+try:
+    import logfire._internal.stack_info
+except ImportError:
+    pass
+else:
+    from pathlib import Path
+
+    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
+
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -49,7 +59,7 @@ EndStrategy = Literal['early', 'exhaustive']


 @final
-@dataclass(init=False)
+@dataclasses.dataclass(init=False)
 class Agent(Generic[AgentDeps, ResultData]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.

@@ -89,23 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
     be merged with this value, with the runtime argument taking priority.
     """

-    last_run_messages: list[_messages.ModelMessage] | None
-    """The messages from the last run, useful when a run raised an exception.
-
-    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
-    """
-
-    _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
-    _allow_text_result: bool = field(repr=False)
-    _system_prompts: tuple[str, ...] = field(repr=False)
-    _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
-    _default_retries: int = field(repr=False)
-    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
-    _deps_type: type[AgentDeps] = field(repr=False)
-    _max_result_retries: int = field(repr=False)
-    _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
-    _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
+    _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
+    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
+    _allow_text_result: bool = dataclasses.field(repr=False)
+    _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
+    _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
+    _default_retries: int = dataclasses.field(repr=False)
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
+    _max_result_retries: int = dataclasses.field(repr=False)
+    _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
+    _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

     def __init__(
         self,
@@ -161,7 +165,6 @@ class Agent(Generic[AgentDeps, ResultData]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
-        self.last_run_messages = None
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
@@ -190,6 +193,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -212,6 +216,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
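The new `usage` argument lets a caller seed a run with usage accumulated elsewhere, for example when one agent is invoked from another agent's tool. A minimal sketch of that pattern, assuming the `'test'` model; the agents, tool, and prompt are illustrative, not part of this release:

```python
from pydantic_ai import Agent, RunContext

parent_agent = Agent('test')
child_agent = Agent('test')


@parent_agent.tool
async def ask_child(ctx: RunContext[None], question: str) -> str:
    # pass the parent run's usage so the child agent's requests and tokens
    # accumulate into the same totals
    result = await child_agent.run(question, usage=ctx.usage)
    return result.data


result = parent_agent.run_sync('delegate this question')
print(result.data)
```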
@@ -219,7 +224,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -228,40 +233,36 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
-            self.last_run_messages = run_context.messages = messages
+            run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage(requests=0)
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()

-            run_step = 0
             while True:
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)

-                run_step += 1
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                run_context.run_step += 1
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                     agent_model = await self._prepare_model(run_context)

-                with _logfire.span('model request', run_step=run_step) as model_req_span:
+                with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('usage', request_usage)

                     messages.append(model_response)
-                    usage += request_usage
-                    usage.requests += 1
-                    usage_limits.check_tokens(request_usage)
+                    run_context.usage.incr(request_usage, requests=1)
+                    usage_limits.check_tokens(run_context.usage)

-                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
                     final_result, tool_responses = await self._handle_model_response(model_response, run_context)

                     if tool_responses:
@@ -272,10 +273,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                     if final_result is not None:
                         result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('usage', usage)
+                        run_span.set_attribute('usage', run_context.usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data, usage)
+                        return result.RunResult(messages, new_message_index, result_data, run_context.usage)
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
@@ -291,6 +292,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -317,6 +319,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -332,6 +335,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 deps=deps,
                 model_settings=model_settings,
                 usage_limits=usage_limits,
+                usage=usage,
                 infer_name=False,
             )
         )
@@ -346,6 +350,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -369,6 +374,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -378,7 +384,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -387,32 +393,29 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
-            self.last_run_messages = run_context.messages = messages
+            run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage()
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()

-            run_step = 0
             while True:
-                run_step += 1
-                usage_limits.check_before_request(usage)
+                run_context.run_step += 1
+                usage_limits.check_before_request(run_context.usage)

-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                     agent_model = await self._prepare_model(run_context)

-                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
+                        run_context.usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
@@ -448,7 +451,6 @@ class Agent(Generic[AgentDeps, ResultData]):
                         yield result.StreamedRunResult(
                             messages,
                             new_message_index,
-                            usage,
                             usage_limits,
                             result_stream,
                             self._result_schema,
@@ -472,8 +474,8 @@ class Agent(Generic[AgentDeps, ResultData]):
                         handle_span.message = f'handle model response -> {tool_responses_str}'
                         # the model_response should have been fully streamed by now, we can add its usage
                         model_response_usage = model_response.usage()
-                        usage += model_response_usage
-                        usage_limits.check_tokens(usage)
+                        run_context.usage.incr(model_response_usage)
+                        usage_limits.check_tokens(run_context.usage)

     @contextmanager
     def override(
@@ -614,7 +616,7 @@ class Agent(Generic[AgentDeps, ResultData]):
        #> success (no tool calls)
        ```
        """
-        self._result_validators.append(_result.ResultValidator(func))
+        self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
         return func

     @overload
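The result-validator plumbing is now explicitly parametrized (`_result.ResultValidator[AgentDeps, Any]`), matching the public `ResultValidatorFunc` alias that moved into `pydantic_ai.result`. A sketch of registering a validator; the validator logic below is made up:

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('test', result_type=str)


@agent.result_validator
async def not_empty(ctx: RunContext[None], data: str) -> str:
    # validators may be sync or async, and may omit the RunContext argument
    if not data.strip():
        raise ModelRetry('response must not be empty')
    return data
```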
@@ -784,14 +786,14 @@ class Agent(Generic[AgentDeps, ResultData]):

         self._function_tools[tool.name] = tool

-    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
         """Create a model configured for this agent.

         Args:
             model: model to use for this run, required if `model` was not set when creating the agent.

         Returns:
-            a tuple of `(model used, how the model was selected)`
+            The model used
         """
         model_: models.Model
         if some_model := self._override_model:
@@ -802,18 +804,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                     '(Even when `override(model=...)` is customizing the model that will actually be called)'
                 )
             model_ = some_model.value
-            mode_selection = 'override-model'
         elif model is not None:
             model_ = models.infer_model(model)
-            mode_selection = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-            mode_selection = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

-        return model_, mode_selection
+        return model_

     async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
         """Build tools and create an agent model."""
@@ -835,14 +834,25 @@ class Agent(Generic[AgentDeps, ResultData]):
     async def _prepare_messages(
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
+        try:
+            ctx_messages = _messages_ctx_var.get()
+        except LookupError:
+            messages: list[_messages.ModelMessage] = []
+        else:
+            if ctx_messages.used:
+                messages = []
+            else:
+                messages = ctx_messages.messages
+                ctx_messages.used = True
+
         if message_history:
             # shallow copy messages
-            messages = message_history.copy()
+            messages.extend(message_history)
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
             parts = await self._sys_parts(run_context)
             parts.append(_messages.UserPromptPart(user_prompt))
-            messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
+            messages.append(_messages.ModelRequest(parts))

         return messages

@@ -864,11 +874,15 @@ class Agent(Generic[AgentDeps, ResultData]):
             else:
                 tool_calls.append(part)

-        if texts:
+        # At the moment, we prioritize at least executing tool calls if they are present.
+        # In the future, we'd consider making this configurable at the agent or run level.
+        # This accounts for cases like anthropic returns that might contain a text response
+        # and a tool call response, where the text response just indicates the tool call will happen.
+        if tool_calls:
+            return await self._handle_structured_response(tool_calls, run_context)
+        elif texts:
             text = '\n\n'.join(texts)
             return await self._handle_text_response(text, run_context)
-        elif tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')

@@ -1115,8 +1129,59 @@ class Agent(Generic[AgentDeps, ResultData]):
                 self.name = name
                 return

+    @property
+    @deprecated(
+        'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
+    )
+    def last_run_messages(self) -> list[_messages.ModelMessage]:
+        raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
+
+
+@dataclasses.dataclass
+class _RunMessages:
+    messages: list[_messages.ModelMessage]
+    used: bool = False
+
+
+_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
+
+
+@contextmanager
+def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
+    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
+
+    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
+
+    Examples:
+    ```python
+    from pydantic_ai import Agent, capture_run_messages
+
+    agent = Agent('test')
+
+    with capture_run_messages() as messages:
+        try:
+            result = agent.run_sync('foobar')
+        except Exception:
+            print(messages)
+            raise
+    ```
+
+    !!! note
+        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
+        `messages` will represent the messages exchanged during the first call only.
+    """
+    try:
+        yield _messages_ctx_var.get().messages
+    except LookupError:
+        messages: list[_messages.ModelMessage] = []
+        token = _messages_ctx_var.set(_RunMessages(messages))
+        try:
+            yield messages
+        finally:
+            _messages_ctx_var.reset(token)
+

-@dataclass
+@dataclasses.dataclass
 class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.

pydantic_ai/models/gemini.py

@@ -444,7 +444,8 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         if isinstance(item, ToolCallPart):
             parts.append(_function_call_part_from_call(item))
         elif isinstance(item, TextPart):
-            parts.append(_GeminiTextPart(text=item.content))
+            if item.content:
+                parts.append(_GeminiTextPart(text=item.content))
         else:
             assert_never(item)
     return _GeminiContent(role='model', parts=parts)
@@ -701,7 +702,7 @@ class _GeminiJsonSchema:

     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-        schema.pop('default', None)
+        default = schema.pop('default', _utils.UNSET)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -714,8 +715,14 @@ class _GeminiJsonSchema:
             return

         if any_of := schema.get('anyOf'):
-            for schema in any_of:
-                self._simplify(schema, refs_stack)
+            for item_schema in any_of:
+                self._simplify(item_schema, refs_stack)
+            if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
+                for item_schema in any_of:
+                    if item_schema != {'type': 'null'}:
+                        schema.clear()
+                        schema.update(item_schema)
+                        return

         type_ = schema.get('type')

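For context, the new `anyOf` branch collapses the two-entry `anyOf` containing `{'type': 'null'}` that pydantic emits for optional fields defaulting to `None`. A standalone sketch of the transformation, not the library's internal code path:

```python
# what pydantic generates for a field like `city: str | None = None`
field_schema = {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None}

default = field_schema.pop('default', ...)
any_of = field_schema.get('anyOf')
if any_of and len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
    non_null = next(s for s in any_of if s != {'type': 'null'})
    field_schema.clear()
    field_schema.update(non_null)

assert field_schema == {'type': 'string'}  # Gemini receives a plain string schema
```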
pydantic_ai/models/mistral.py

@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from itertools import chain
 from typing import Any, Callable, Literal, Union

+import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout
 from typing_extensions import assert_never

@@ -39,7 +40,6 @@ from . import (
 )

 try:
-    from json_repair import repair_json
     from mistralai import (
         UNSET,
         CompletionChunk as MistralCompletionChunk,
@@ -198,11 +198,10 @@ class MistralAgentModel(AgentModel):
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-
         model_settings = model_settings or {}

         if self.result_tools and self.function_tools or self.function_tools:
-            # Function Calling Mode
+            # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -218,9 +217,9 @@ class MistralAgentModel(AgentModel):
         elif self.result_tools:
             # Json Mode
             parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
-
             user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
             mistral_messages.append(user_output_format_message)
+
             response = await self.client.chat.stream_async(
                 model=str(self.model_name),
                 messages=mistral_messages,
@@ -270,12 +269,13 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
+        assert response.choices, 'Unexpected empty response choice.'
+
         if response.created:
             timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
         else:
             timestamp = _now_utc()

-        assert response.choices, 'Unexpected empty response choice.'
         choice = response.choices[0]
         content = choice.message.content
         tool_calls = choice.message.tool_calls
@@ -546,20 +546,15 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
                     calls.append(tool)

         elif self._delta_content and self._result_tools:
-            # NOTE: Params set for the most efficient and fastest way.
-            output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
-            assert isinstance(
-                output_json, dict
-            ), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
+            output_json: dict[str, Any] | None = pydantic_core.from_json(
+                self._delta_content, allow_partial='trailing-strings'
+            )

             if output_json:
                 for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `result.py`
+                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
                     # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    # For example, `return_type=list[str]` expects a 'response' key with value type array of str.
-                    # when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
-                    # when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
-                    # This ensures it's corrected to `{"response": {}}` and other required parameters and type.
+                    # Example with BaseModel and required fields.
                     if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                         continue

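The `json-repair` dependency is replaced by `pydantic_core.from_json` in partial mode, which tolerates the incomplete JSON that arrives in streamed deltas. A small illustration; the payload is made up:

```python
import pydantic_core

# an incomplete streamed delta; 'trailing-strings' keeps the half-finished string value
chunk = '{"city": "Par'
print(pydantic_core.from_json(chunk, allow_partial='trailing-strings'))
#> {'city': 'Par'}
```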
pydantic_ai/models/ollama.py

@@ -71,6 +71,7 @@ class OllamaModel(Model):
         model_name: OllamaModelName,
         *,
         base_url: str | None = 'http://localhost:11434/v1/',
+        api_key: str = 'ollama',
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
     ):
@@ -83,6 +84,8 @@ class OllamaModel(Model):
             model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
                 You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
             base_url: The base url for the ollama requests. The default value is the ollama default
+            api_key: The API key to use for authentication. Defaults to 'ollama' for local instances,
+                but can be customized for proxy setups that require authentication
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use, if provided, `base_url` and `http_client` must be `None`.
@@ -96,7 +99,7 @@ class OllamaModel(Model):
         else:
             # API key is not required for ollama but a value is required to create the client
             http_client_ = http_client or cached_async_http_client()
-            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_)
         self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)

     async def agent_model(
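This makes it straightforward to point the Ollama model class at a remote or proxied endpoint that requires a token. A sketch; the URL, token, and model name are placeholders:

```python
from pydantic_ai import Agent
from pydantic_ai.models.ollama import OllamaModel

model = OllamaModel(
    'llama3.2',
    base_url='https://ollama.example.com/v1/',  # hypothetical authenticated proxy
    api_key='my-proxy-token',
)
agent = Agent(model)
```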
pydantic_ai/models/test.py

@@ -16,6 +16,7 @@ from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     RetryPromptPart,
     TextPart,
     ToolCallPart,
@@ -177,13 +178,23 @@ class TestAgentModel(AgentModel):
         # check if there are any retry prompts, if so retry them
         new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
         if new_retry_names:
-            return ModelResponse(
-                parts=[
-                    ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
-                    for name, args in self.tool_calls
-                    if name in new_retry_names
-                ]
-            )
+            # Handle retries for both function tools and result tools
+            # Check function tools first
+            retry_parts: list[ModelResponsePart] = [
+                ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                for name, args in self.tool_calls
+                if name in new_retry_names
+            ]
+            # Check result tools
+            if self.result_tools:
+                retry_parts.extend(
+                    [
+                        ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
+                        for tool in self.result_tools
+                        if tool.name in new_retry_names
+                    ]
+                )
+            return ModelResponse(parts=retry_parts)

         if response_text := self.result.left:
             if response_text.value is None:
pydantic_ai/result.py

@@ -2,11 +2,13 @@ from __future__ import annotations as _annotations

 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
+from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Generic, TypeVar, cast
+from typing import Generic, Union, cast

 import logfire_api
+from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
 from .settings import UsageLimits
@@ -14,21 +16,37 @@ from .tools import AgentDeps, RunContext

 __all__ = (
     'ResultData',
+    'ResultValidatorFunc',
     'Usage',
     'RunResult',
     'StreamedRunResult',
 )


-ResultData = TypeVar('ResultData')
+ResultData = TypeVar('ResultData', default=str)
 """Type variable for the result data of a run."""

+ResultValidatorFunc = Union[
+    Callable[[RunContext[AgentDeps], ResultData], ResultData],
+    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
+    Callable[[ResultData], ResultData],
+    Callable[[ResultData], Awaitable[ResultData]],
+]
+"""
+A function that always takes `ResultData` and returns `ResultData` and:
+
+* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
+* may or may not be async
+
+Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
+"""
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


 @dataclass
 class Usage:
-    """LLM usage associated to a request or run.
+    """LLM usage associated with a request or run.

     Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

@@ -36,7 +54,7 @@ class Usage:
     """

     requests: int = 0
-    """Number of requests made."""
+    """Number of requests made to the LLM API."""
     request_tokens: int | None = None
     """Tokens used in processing requests."""
     response_tokens: int | None = None
@@ -46,25 +64,33 @@ class Usage:
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""

-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.

-        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
-        counts: dict[str, int] = {}
+        self.requests += requests
         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
-            other_value = getattr(other, f)
+            other_value = getattr(incr_usage, f)
             if self_value is not None or other_value is not None:
-                counts[f] = (self_value or 0) + (other_value or 0)
+                setattr(self, f, (self_value or 0) + (other_value or 0))
+
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value

-        details = self.details.copy() if self.details is not None else None
-        if other.details is not None:
-            details = details or {}
-            for key, value in other.details.items():
-                details[key] = details.get(key, 0) + value
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.

-        return Usage(**counts, details=details or None)
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage


 @dataclass
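In other words, `incr` mutates the receiving `Usage` in place while `__add__` now copies and increments. A quick illustration with made-up numbers:

```python
from pydantic_ai.result import Usage

run_usage = Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15)
step_usage = Usage(request_tokens=7, response_tokens=3, total_tokens=10)

# in-place accumulation, counting one extra request on top of `step_usage.requests`
run_usage.incr(step_usage, requests=1)
assert run_usage.requests == 2
assert run_usage.total_tokens == 25

# `+` still works, returning a new Usage built from a copy
combined = run_usage + step_usage
assert combined.total_tokens == 35
assert run_usage.total_tokens == 25  # unchanged
```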
@@ -119,8 +145,6 @@ class RunResult(_BaseRunResult[ResultData]):
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    usage_so_far: Usage
-    """Usage of the run up until the last request."""
     _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
@@ -289,7 +313,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self.usage_so_far + self._stream_response.usage()
+        return self._run_ctx.usage + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai/settings.py

@@ -22,6 +22,7 @@ class ModelSettings(TypedDict, total=False):
     """The maximum number of tokens to generate before stopping.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -37,6 +38,7 @@ class ModelSettings(TypedDict, total=False):
     Note that even with `temperature` of `0.0`, the results will not be fully deterministic.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -51,6 +53,7 @@ class ModelSettings(TypedDict, total=False):
     You should either alter `temperature` or `top_p`, but not both.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -61,6 +64,7 @@ class ModelSettings(TypedDict, total=False):
     """Override the client-level default timeout for a request, in seconds.

     Supported by:
+
     * Gemini
     * Anthropic
     * OpenAI
@@ -132,6 +136,6 @@ class UsageLimits:
                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
             )

-        total_tokens = request_tokens + response_tokens
+        total_tokens = usage.total_tokens or 0
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
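With this fix the total-token check compares against the run's accumulated `usage.total_tokens` rather than only the latest request. A sketch of enforcing a limit on a run; the limit value is arbitrary:

```python
from pydantic_ai import Agent
from pydantic_ai.settings import UsageLimits

agent = Agent('test')

# raises UsageLimitExceeded once the run's summed total_tokens crosses the limit
result = agent.run_sync('hello', usage_limits=UsageLimits(total_tokens_limit=500))
print(result.data)
```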
pydantic_ai/tools.py

@@ -4,7 +4,7 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast

 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
@@ -13,10 +13,12 @@ from typing_extensions import Concatenate, ParamSpec, TypeAlias
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior

+if TYPE_CHECKING:
+    from .result import Usage
+
 __all__ = (
     'AgentDeps',
     'RunContext',
-    'ResultValidatorFunc',
     'SystemPromptFunc',
     'ToolFuncContext',
     'ToolFuncPlain',
@@ -38,14 +40,20 @@ class RunContext(Generic[AgentDeps]):

     deps: AgentDeps
     """Dependencies for the agent."""
-    retry: int
-    """Number of retries so far."""
-    messages: list[_messages.ModelMessage]
-    """Messages exchanged in the conversation so far."""
-    tool_name: str | None
-    """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""
+    prompt: str
+    """The original user prompt passed to the run."""
+    messages: list[_messages.ModelMessage] = field(default_factory=list)
+    """Messages exchanged in the conversation so far."""
+    tool_name: str | None = None
+    """Name of the tool being called."""
+    retry: int = 0
+    """Number of retries so far."""
+    run_step: int = 0
+    """The current step in the run."""

     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
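Tools can read the new fields directly off the context. A sketch of a tool that inspects them; the tool itself is illustrative:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test')


@agent.tool
def run_info(ctx: RunContext[None]) -> str:
    # `prompt`, `run_step` and `usage` are populated by the agent during the run
    return f'step={ctx.run_step} prompt={ctx.prompt!r} requests={ctx.usage.requests}'
```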
@@ -73,21 +81,6 @@ SystemPromptFunc = Union[
 Usage `SystemPromptFunc[AgentDeps]`.
 """

-ResultData = TypeVar('ResultData')
-
-ResultValidatorFunc = Union[
-    Callable[[RunContext[AgentDeps], ResultData], ResultData],
-    Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
-    Callable[[ResultData], ResultData],
-    Callable[[ResultData], Awaitable[ResultData]],
-]
-"""
-A function that always takes `ResultData` and returns `ResultData`,
-but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
-
-Usage `ResultValidator[AgentDeps, ResultData]`.
-"""
-
 ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
 """A tool function that takes `RunContext` as the first argument.

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "pydantic-ai-slim"
-version = "0.0.14"
+version = "0.0.16"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },
@@ -46,7 +46,7 @@ openai = ["openai>=1.54.3"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
 anthropic = ["anthropic>=0.40.0"]
 groq = ["groq>=0.12.0"]
-mistral = ["mistralai>=1.2.5", "json-repair>=0.30.3"]
+mistral = ["mistralai>=1.2.5"]
 logfire = ["logfire>=2.3"]

 [dependency-groups]