pydantic-ai-slim 0.0.15__tar.gz → 0.0.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of pydantic-ai-slim might be problematic.

Files changed (29)
  1. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/PKG-INFO +1 -1
  2. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/_griffe.py +1 -2
  3. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/agent.py +81 -66
  4. pydantic_ai_slim-0.0.17/pydantic_ai/format_as_xml.py +115 -0
  5. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/gemini.py +23 -25
  6. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/ollama.py +4 -1
  7. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/test.py +18 -7
  8. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/vertexai.py +1 -1
  9. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/result.py +93 -63
  10. pydantic_ai_slim-0.0.17/pydantic_ai/settings.py +81 -0
  11. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/tools.py +22 -13
  12. pydantic_ai_slim-0.0.15/pydantic_ai/settings.py → pydantic_ai_slim-0.0.17/pydantic_ai/usage.py +47 -74
  13. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pyproject.toml +1 -1
  14. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/.gitignore +0 -0
  15. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/README.md +0 -0
  16. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/__init__.py +0 -0
  17. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/_pydantic.py +0 -0
  18. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/_result.py +0 -0
  19. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/_system_prompt.py +0 -0
  20. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/_utils.py +0 -0
  21. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/exceptions.py +0 -0
  22. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/messages.py +0 -0
  23. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/__init__.py +0 -0
  24. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/anthropic.py +0 -0
  25. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/function.py +0 -0
  26. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/groq.py +0 -0
  27. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/mistral.py +0 -0
  28. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/openai.py +0 -0
  29. {pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/py.typed +0 -0
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.15
+Version: 0.0.17
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/_griffe.py
@@ -4,8 +4,7 @@ import re
 from inspect import Signature
 from typing import Any, Callable, Literal, cast

-from _griffe.enumerations import DocstringSectionKind
-from _griffe.models import Docstring, Object as GriffeObject
+from griffe import Docstring, DocstringSectionKind, Object as GriffeObject

 DocstringStyle = Literal['google', 'numpy', 'sphinx']

{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/agent.py
@@ -6,7 +6,6 @@ import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
-from dataclasses import dataclass, field
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

@@ -21,9 +20,10 @@ from . import (
     messages as _messages,
     models,
     result,
+    usage as _usage,
 )
 from .result import ResultData
-from .settings import ModelSettings, UsageLimits, merge_model_settings
+from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -40,6 +40,16 @@ __all__ = 'Agent', 'capture_run_messages', 'EndStrategy'

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

+# while waiting for https://github.com/pydantic/logfire/issues/745
+try:
+    import logfire._internal.stack_info
+except ImportError:
+    pass
+else:
+    from pathlib import Path
+
+    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
+
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -50,7 +60,7 @@ EndStrategy = Literal['early', 'exhaustive']


 @final
-@dataclass(init=False)
+@dataclasses.dataclass(init=False)
 class Agent(Generic[AgentDeps, ResultData]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.

@@ -90,17 +100,17 @@ class Agent(Generic[AgentDeps, ResultData]):
         be merged with this value, with the runtime argument taking priority.
     """

-    _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
-    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
-    _allow_text_result: bool = field(repr=False)
-    _system_prompts: tuple[str, ...] = field(repr=False)
-    _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
-    _default_retries: int = field(repr=False)
-    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
-    _deps_type: type[AgentDeps] = field(repr=False)
-    _max_result_retries: int = field(repr=False)
-    _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
-    _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
+    _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
+    _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
+    _allow_text_result: bool = dataclasses.field(repr=False)
+    _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
+    _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
+    _default_retries: int = dataclasses.field(repr=False)
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
+    _max_result_retries: int = dataclasses.field(repr=False)
+    _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
+    _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

     def __init__(
         self,
@@ -183,7 +193,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -206,6 +217,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
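
The new `usage` parameter composes with `message_history` for multi-run conversations. A minimal sketch of the idea, assuming the documented signature above; the model name and prompts are illustrative placeholders, not from the diff:

```python
# Sketch: carrying Usage across two runs via the new `usage` parameter.
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # placeholder model


async def main():
    first = await agent.run('What is the capital of France?')
    second = await agent.run(
        'And of Germany?',
        message_history=first.all_messages(),
        usage=first.usage(),  # the second run's accounting starts from the first run's totals
    )
    print(second.usage())  # covers both runs
```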
@@ -213,7 +225,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -222,40 +234,36 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage(requests=0)
             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()

-            run_step = 0
             while True:
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)

-                run_step += 1
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                run_context.run_step += 1
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                     agent_model = await self._prepare_model(run_context)

-                with _logfire.span('model request', run_step=run_step) as model_req_span:
+                with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('usage', request_usage)

                     messages.append(model_response)
-                    usage += request_usage
-                    usage.requests += 1
-                    usage_limits.check_tokens(request_usage)
+                    run_context.usage.incr(request_usage, requests=1)
+                    usage_limits.check_tokens(run_context.usage)

-                with _logfire.span('handle model response', run_step=run_step) as handle_span:
+                with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
                     final_result, tool_responses = await self._handle_model_response(model_response, run_context)

                     if tool_responses:
@@ -265,11 +273,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                     # Check if we got a final result
                     if final_result is not None:
                         result_data = final_result.data
+                        result_tool_name = final_result.tool_name
                         run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('usage', usage)
+                        run_span.set_attribute('usage', run_context.usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data, usage)
+                        return result.RunResult(
+                            messages, new_message_index, result_data, result_tool_name, run_context.usage
+                        )
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
@@ -284,7 +295,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -311,6 +323,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -326,6 +339,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 deps=deps,
                 model_settings=model_settings,
                 usage_limits=usage_limits,
+                usage=usage,
                 infer_name=False,
             )
         )
@@ -339,7 +353,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -363,6 +378,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
@@ -372,7 +388,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used, mode_selection = await self._get_model(model)
+        model_used = await self._get_model(model)

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
@@ -381,32 +397,29 @@ class Agent(Generic[AgentDeps, ResultData]):
             '{agent_name} run stream {prompt=}',
             prompt=user_prompt,
             agent=self,
-            mode_selection=mode_selection,
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

             for tool in self._function_tools.values():
                 tool.current_retry = 0

-            usage = result.Usage()
             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()

-            run_step = 0
             while True:
-                run_step += 1
-                usage_limits.check_before_request(usage)
+                run_context.run_step += 1
+                usage_limits.check_before_request(run_context.usage)

-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
                     agent_model = await self._prepare_model(run_context)

-                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
+                        run_context.usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
@@ -442,7 +455,6 @@ class Agent(Generic[AgentDeps, ResultData]):
                         yield result.StreamedRunResult(
                             messages,
                             new_message_index,
-                            usage,
                             usage_limits,
                             result_stream,
                             self._result_schema,
@@ -466,8 +478,8 @@ class Agent(Generic[AgentDeps, ResultData]):
                     handle_span.message = f'handle model response -> {tool_responses_str}'
                 # the model_response should have been fully streamed by now, we can add its usage
                 model_response_usage = model_response.usage()
-                usage += model_response_usage
-                usage_limits.check_tokens(usage)
+                run_context.usage.incr(model_response_usage)
+                usage_limits.check_tokens(run_context.usage)

     @contextmanager
     def override(
@@ -778,14 +790,14 @@ class Agent(Generic[AgentDeps, ResultData]):

         self._function_tools[tool.name] = tool

-    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
+    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
         """Create a model configured for this agent.

         Args:
             model: model to use for this run, required if `model` was not set when creating the agent.

         Returns:
-            a tuple of `(model used, how the model was selected)`
+            The model used
         """
         model_: models.Model
         if some_model := self._override_model:
@@ -796,18 +808,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 '(Even when `override(model=...)` is customizing the model that will actually be called)'
             )
             model_ = some_model.value
-            mode_selection = 'override-model'
         elif model is not None:
             model_ = models.infer_model(model)
-            mode_selection = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-            mode_selection = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

-        return model_, mode_selection
+        return model_

     async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
         """Build tools and create an agent model."""
@@ -830,15 +839,15 @@ class Agent(Generic[AgentDeps, ResultData]):
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
         try:
-            messages = _messages_ctx_var.get()
+            ctx_messages = _messages_ctx_var.get()
         except LookupError:
-            messages = []
+            messages: list[_messages.ModelMessage] = []
         else:
-            if messages:
-                raise exceptions.UserError(
-                    'The capture_run_messages() context manager may only be used to wrap '
-                    'one call to run(), run_sync(), or run_stream().'
-                )
+            if ctx_messages.used:
+                messages = []
+            else:
+                messages = ctx_messages.messages
+                ctx_messages.used = True

         if message_history:
             # shallow copy messages
@@ -1132,7 +1141,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')


-_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
+@dataclasses.dataclass
+class _RunMessages:
+    messages: list[_messages.ModelMessage]
+    used: bool = False
+
+
+_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')


 @contextmanager
@@ -1156,21 +1171,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
     ```

     !!! note
-        You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
-        If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
+        If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
+        `messages` will represent the messages exchanged during the first call only.
     """
     try:
-        yield _messages_ctx_var.get()
+        yield _messages_ctx_var.get().messages
     except LookupError:
         messages: list[_messages.ModelMessage] = []
-        token = _messages_ctx_var.set(messages)
+        token = _messages_ctx_var.set(_RunMessages(messages))
         try:
             yield messages
         finally:
             _messages_ctx_var.reset(token)


-@dataclass
+@dataclasses.dataclass
 class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.

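
A short sketch of the revised `capture_run_messages` semantics, assuming the top-level re-export and using the built-in `'test'` model so it runs offline; prompts are placeholders:

```python
# Sketch: a second run inside the context no longer raises UserError;
# `messages` reflects only the first call's exchange.
from pydantic_ai import Agent, capture_run_messages

agent = Agent('test')  # TestModel backend, no network required

with capture_run_messages() as messages:
    agent.run_sync('first question')
    agent.run_sync('second question')  # previously raised UserError

print(len(messages))  # messages from the first run only
```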
pydantic_ai_slim-0.0.17/pydantic_ai/format_as_xml.py
@@ -0,0 +1,115 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import Iterable, Iterator, Mapping
+from dataclasses import asdict, dataclass, is_dataclass
+from datetime import date
+from typing import Any
+from xml.etree import ElementTree
+
+from pydantic import BaseModel
+
+__all__ = ('format_as_xml',)
+
+
+def format_as_xml(
+    obj: Any,
+    root_tag: str = 'examples',
+    item_tag: str = 'example',
+    include_root_tag: bool = True,
+    none_str: str = 'null',
+    indent: str | None = '  ',
+) -> str:
+    """Format a Python object as XML.
+
+    This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
+    rather than JSON etc.
+
+    Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
+    `Iterable`, `dataclass`, and `BaseModel`.
+
+    Args:
+        obj: Python Object to serialize to XML.
+        root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
+        item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
+            for dataclasses and Pydantic models.
+        include_root_tag: Whether to include the root tag in the output
+            (The root tag is always included if it includes a body - e.g. when the input is a simple value).
+        none_str: String to use for `None` values.
+        indent: Indentation string to use for pretty printing.
+
+    Returns: XML representation of the object.
+
+    Example:
+        ```python {title="format_as_xml_example.py" lint="skip"}
+        from pydantic_ai.format_as_xml import format_as_xml
+
+        print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
+        '''
+        <user>
+          <name>John</name>
+          <height>6</height>
+          <weight>200</weight>
+        </user>
+        '''
+        ```
+    """
+    el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
+    if not include_root_tag and el.text is None:
+        join = '' if indent is None else '\n'
+        return join.join(_rootless_xml_elements(el, indent))
+    else:
+        if indent is not None:
+            ElementTree.indent(el, space=indent)
+        return ElementTree.tostring(el, encoding='unicode')
+
+
+@dataclass
+class _ToXml:
+    item_tag: str
+    none_str: str
+
+    def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
+        element = ElementTree.Element(self.item_tag if tag is None else tag)
+        if value is None:
+            element.text = self.none_str
+        elif isinstance(value, str):
+            element.text = value
+        elif isinstance(value, (bytes, bytearray)):
+            element.text = value.decode(errors='ignore')
+        elif isinstance(value, (bool, int, float)):
+            element.text = str(value)
+        elif isinstance(value, date):
+            element.text = value.isoformat()
+        elif isinstance(value, Mapping):
+            self._mapping_to_xml(element, value)  # pyright: ignore[reportUnknownArgumentType]
+        elif is_dataclass(value) and not isinstance(value, type):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            dc_dict = asdict(value)
+            self._mapping_to_xml(element, dc_dict)
+        elif isinstance(value, BaseModel):
+            if tag is None:
+                element = ElementTree.Element(value.__class__.__name__)
+            self._mapping_to_xml(element, value.model_dump(mode='python'))
+        elif isinstance(value, Iterable):
+            for item in value:  # pyright: ignore[reportUnknownVariableType]
+                item_el = self.to_xml(item, None)
+                element.append(item_el)
+        else:
+            raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
+        return element
+
+    def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
+        for key, value in mapping.items():
+            if isinstance(key, int):
+                key = str(key)
+            elif not isinstance(key, str):
+                raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
+            element.append(self.to_xml(value, key))
+
+
+def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
+    for sub_element in root:
+        if indent is not None:
+            ElementTree.indent(sub_element, space=indent)
+        yield ElementTree.tostring(sub_element, encoding='unicode')
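
Beyond the mapping example in the docstring, the `to_xml` branches above also cover dataclasses (tagged with the class name) and iterables. A small sketch under that reading of the code; the `User` class is hypothetical:

```python
# Sketch: a list of dataclasses; each item is tagged with its class name,
# and the list is wrapped in the default 'examples' root tag.
from dataclasses import dataclass

from pydantic_ai.format_as_xml import format_as_xml


@dataclass
class User:
    name: str
    height: int


print(format_as_xml([User('John', 6), User('Jane', 5)]))
# <examples>
#   <User>
#     <name>John</name>
#     <height>6</height>
#   </User>
#   ...
# </examples>
```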
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/gemini.py
@@ -273,17 +273,26 @@ class GeminiAgentModel(AgentModel):
         contents: list[_GeminiContent] = []
         for m in messages:
             if isinstance(m, ModelRequest):
+                message_parts: list[_GeminiPartUnion] = []
+
                 for part in m.parts:
                     if isinstance(part, SystemPromptPart):
                         sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, UserPromptPart):
-                        contents.append(_content_user_prompt(part))
+                        message_parts.append(_GeminiTextPart(text=part.content))
                     elif isinstance(part, ToolReturnPart):
-                        contents.append(_content_tool_return(part))
+                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                     elif isinstance(part, RetryPromptPart):
-                        contents.append(_content_retry_prompt(part))
+                        if part.tool_name is None:
+                            message_parts.append(_GeminiTextPart(text=part.model_response()))
+                        else:
+                            response = {'call_error': part.model_response()}
+                            message_parts.append(_response_part_from_response(part.tool_name, response))
                     else:
                         assert_never(part)
+
+                if message_parts:
+                    contents.append(_GeminiContent(role='user', parts=message_parts))
             elif isinstance(m, ModelResponse):
                 contents.append(_content_model_response(m))
             else:
@@ -420,31 +429,14 @@ class _GeminiContent(TypedDict):
     parts: list[_GeminiPartUnion]


-def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
-    return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
-
-
-def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
-    f_response = _response_part_from_response(m.tool_name, m.model_response_object())
-    return _GeminiContent(role='user', parts=[f_response])
-
-
-def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
-    if m.tool_name is None:
-        part = _GeminiTextPart(text=m.model_response())
-    else:
-        response = {'call_error': m.model_response()}
-        part = _response_part_from_response(m.tool_name, response)
-    return _GeminiContent(role='user', parts=[part])
-
-
 def _content_model_response(m: ModelResponse) -> _GeminiContent:
     parts: list[_GeminiPartUnion] = []
     for item in m.parts:
         if isinstance(item, ToolCallPart):
             parts.append(_function_call_part_from_call(item))
         elif isinstance(item, TextPart):
-            parts.append(_GeminiTextPart(text=item.content))
+            if item.content:
+                parts.append(_GeminiTextPart(text=item.content))
         else:
             assert_never(item)
     return _GeminiContent(role='model', parts=parts)
@@ -701,7 +693,7 @@ class _GeminiJsonSchema:

     def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
         schema.pop('title', None)
-        schema.pop('default', None)
+        default = schema.pop('default', _utils.UNSET)
         if ref := schema.pop('$ref', None):
             # noinspection PyTypeChecker
             key = re.sub(r'^#/\$defs/', '', ref)
@@ -714,8 +706,14 @@ class _GeminiJsonSchema:
             return

         if any_of := schema.get('anyOf'):
-            for schema in any_of:
-                self._simplify(schema, refs_stack)
+            for item_schema in any_of:
+                self._simplify(item_schema, refs_stack)
+            if len(any_of) == 2 and {'type': 'null'} in any_of and default is None:
+                for item_schema in any_of:
+                    if item_schema != {'type': 'null'}:
+                        schema.clear()
+                        schema.update(item_schema)
+                        return

        type_ = schema.get('type')
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/ollama.py
@@ -71,6 +71,7 @@ class OllamaModel(Model):
         model_name: OllamaModelName,
         *,
         base_url: str | None = 'http://localhost:11434/v1/',
+        api_key: str = 'ollama',
         openai_client: AsyncOpenAI | None = None,
         http_client: AsyncHTTPClient | None = None,
     ):
@@ -83,6 +84,8 @@ class OllamaModel(Model):
             model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
                 You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
             base_url: The base url for the ollama requests. The default value is the ollama default
+            api_key: The API key to use for authentication. Defaults to 'ollama' for local instances,
+                but can be customized for proxy setups that require authentication
             openai_client: An existing
                 [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                 client to use, if provided, `base_url` and `http_client` must be `None`.
@@ -96,7 +99,7 @@ class OllamaModel(Model):
         else:
             # API key is not required for ollama but a value is required to create the client
             http_client_ = http_client or cached_async_http_client()
-            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            oai_client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client_)
         self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)

     async def agent_model(
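
A sketch of the new parameter in use, for an Ollama instance behind an authenticating proxy; the URL, key, and model name are placeholders:

```python
# Sketch: OllamaModel against a proxied server that enforces a real API key.
from pydantic_ai.models.ollama import OllamaModel

model = OllamaModel(
    'llama3.2',  # any model you've pulled / the proxy exposes
    base_url='https://ollama.example.com/v1/',
    api_key='my-proxy-key',  # placeholder; defaults to 'ollama' for local use
)
```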
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/test.py
@@ -16,6 +16,7 @@ from ..messages import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     RetryPromptPart,
     TextPart,
     ToolCallPart,
@@ -177,13 +178,23 @@ class TestAgentModel(AgentModel):
         # check if there are any retry prompts, if so retry them
         new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
         if new_retry_names:
-            return ModelResponse(
-                parts=[
-                    ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
-                    for name, args in self.tool_calls
-                    if name in new_retry_names
-                ]
-            )
+            # Handle retries for both function tools and result tools
+            # Check function tools first
+            retry_parts: list[ModelResponsePart] = [
+                ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
+                for name, args in self.tool_calls
+                if name in new_retry_names
+            ]
+            # Check result tools
+            if self.result_tools:
+                retry_parts.extend(
+                    [
+                        ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
+                        for tool in self.result_tools
+                        if tool.name in new_retry_names
+                    ]
+                )
+            return ModelResponse(parts=retry_parts)

         if response_text := self.result.left:
             if response_text.value is None:
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/models/vertexai.py
@@ -178,7 +178,7 @@ def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
 # pyright: reportUnknownVariableType=false
 # pyright: reportUnknownArgumentType=false
 async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
-    return await run_in_executor(google.auth.default)
+    return await run_in_executor(google.auth.default, scopes=['https://www.googleapis.com/auth/cloud-platform'])


 # default expiry is 3600 seconds
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/result.py
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations

 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -10,16 +11,10 @@ import logfire_api
 from typing_extensions import TypeVar

 from . import _result, _utils, exceptions, messages as _messages, models
-from .settings import UsageLimits
 from .tools import AgentDeps, RunContext
+from .usage import Usage, UsageLimits

-__all__ = (
-    'ResultData',
-    'ResultValidatorFunc',
-    'Usage',
-    'RunResult',
-    'StreamedRunResult',
-)
+__all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult'


 ResultData = TypeVar('ResultData', default=str)
@@ -43,47 +38,6 @@ Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


-@dataclass
-class Usage:
-    """LLM usage associated with a request or run.
-
-    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
-
-    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
-    """
-
-    requests: int = 0
-    """Number of requests made to the LLM API."""
-    request_tokens: int | None = None
-    """Tokens used in processing requests."""
-    response_tokens: int | None = None
-    """Tokens used in generating responses."""
-    total_tokens: int | None = None
-    """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
-    details: dict[str, int] | None = None
-    """Any extra details returned by the model."""
-
-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
-
-        This is provided so it's trivial to sum usage information from multiple requests and runs.
-        """
-        counts: dict[str, int] = {}
-        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
-            self_value = getattr(self, f)
-            other_value = getattr(other, f)
-            if self_value is not None or other_value is not None:
-                counts[f] = (self_value or 0) + (other_value or 0)
-
-        details = self.details.copy() if self.details is not None else None
-        if other.details is not None:
-            details = details or {}
-            for key, value in other.details.items():
-                details[key] = details.get(key, 0) + value
-
-        return Usage(**counts, details=details or None)
-
-
 @dataclass
 class _BaseRunResult(ABC, Generic[ResultData]):
     """Base type for results.
@@ -94,25 +48,70 @@ class _BaseRunResult(ABC, Generic[ResultData]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int

-    def all_messages(self) -> list[_messages.ModelMessage]:
-        """Return the history of _messages."""
+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
         # this is a method to be consistent with the other methods
+        if result_tool_return_content is not None:
+            raise NotImplementedError('Setting result tool return content is not supported for this result type.')
         return self._all_messages

-    def all_messages_json(self) -> bytes:
-        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
+    def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.all_messages(result_tool_return_content=result_tool_return_content)
+        )

-    def new_messages(self) -> list[_messages.ModelMessage]:
+    def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
         """Return new messages associated with this run.

-        System prompts and any messages from older runs are excluded.
+        Messages from older runs are excluded.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of new messages.
         """
-        return self.all_messages()[self._new_message_index :]
+        return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :]
+
+    def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes:
+        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.

-    def new_messages_json(self) -> bytes:
-        """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
-        return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            JSON bytes representing the new messages.
+        """
+        return _messages.ModelMessagesTypeAdapter.dump_json(
+            self.new_messages(result_tool_return_content=result_tool_return_content)
+        )

     @abstractmethod
     def usage(self) -> Usage:
@@ -125,19 +124,50 @@ class RunResult(_BaseRunResult[ResultData]):

     data: ResultData
     """Data from the final response in the run."""
+    _result_tool_name: str | None
     _usage: Usage

     def usage(self) -> Usage:
         """Return the usage of the whole run."""
         return self._usage

+    def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
+        """Return the history of _messages.
+
+        Args:
+            result_tool_return_content: The return content of the tool call to set in the last message.
+                This provides a convenient way to modify the content of the result tool call if you want to continue
+                the conversation and want to set the response to the result tool call. If `None`, the last message will
+                not be modified.
+
+        Returns:
+            List of messages.
+        """
+        if result_tool_return_content is not None:
+            return self._set_result_tool_return(result_tool_return_content)
+        else:
+            return self._all_messages
+
+    def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]:
+        """Set return content for the result tool.
+
+        Useful if you want to continue the conversation and want to set the response to the result tool call.
+        """
+        if not self._result_tool_name:
+            raise ValueError('Cannot set result tool return content when the return type is `str`.')
+        messages = deepcopy(self._all_messages)
+        last_message = messages[-1]
+        for part in last_message.parts:
+            if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name:
+                part.content = return_content
+                return messages
+        raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.')
+

 @dataclass
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""

-    usage_so_far: Usage
-    """Usage of the run up until the last request."""
     _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
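
A sketch of the new keyword in practice, using the built-in `'test'` model (the prompts and return content are placeholders): set a response to the result tool call, then continue the conversation with the modified history:

```python
# Sketch: overriding the result tool's return content before a follow-up run.
from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):
    city: str
    country: str


agent = Agent('test', result_type=CityLocation)
result = agent.run_sync('Where were the 2012 Olympics held?')

# returns a deep copy with the result tool's return content replaced;
# raises ValueError when the result type is plain `str` (no result tool)
history = result.all_messages(result_tool_return_content='Location confirmed.')
follow_up = agent.run_sync('And 2016?', message_history=history)
```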
@@ -306,7 +336,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self.usage_so_far + self._stream_response.usage()
+        return self._run_ctx.usage + self._stream_response.usage()

     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai_slim-0.0.17/pydantic_ai/settings.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from httpx import Timeout
+from typing_extensions import TypedDict
+
+if TYPE_CHECKING:
+    pass
+
+
+class ModelSettings(TypedDict, total=False):
+    """Settings to configure an LLM.
+
+    Here we include only settings which apply to multiple models / model providers.
+    """
+
+    max_tokens: int
+    """The maximum number of tokens to generate before stopping.
+
+    Supported by:
+
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    temperature: float
+    """Amount of randomness injected into the response.
+
+    Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's
+    maximum `temperature` for creative and generative tasks.
+
+    Note that even with `temperature` of `0.0`, the results will not be fully deterministic.
+
+    Supported by:
+
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    top_p: float
+    """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
+
+    So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+
+    You should either alter `temperature` or `top_p`, but not both.
+
+    Supported by:
+
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+    timeout: float | Timeout
+    """Override the client-level default timeout for a request, in seconds.
+
+    Supported by:
+
+    * Gemini
+    * Anthropic
+    * OpenAI
+    * Groq
+    """
+
+
+def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
+    """Merge two sets of model settings, preferring the overrides.
+
+    A common use case is: merge_model_settings(<agent settings>, <run settings>)
+    """
+    # Note: we may want merge recursively if/when we add non-primitive values
+    if base and overrides:
+        return base | overrides
+    else:
+        return base or overrides
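
Since `ModelSettings` is a non-total `TypedDict`, the merge above is a plain dict union; a quick sketch:

```python
# Sketch: run-level settings override agent-level settings key by key.
from pydantic_ai.settings import ModelSettings, merge_model_settings

agent_settings: ModelSettings = {'temperature': 0.2, 'max_tokens': 512}
run_settings: ModelSettings = {'temperature': 0.9}

merged = merge_model_settings(agent_settings, run_settings)
print(merged)  # {'temperature': 0.9, 'max_tokens': 512}
```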
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pydantic_ai/tools.py
@@ -4,15 +4,18 @@ import dataclasses
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast

 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
-from typing_extensions import Concatenate, ParamSpec, TypeAlias
+from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar

 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior

+if TYPE_CHECKING:
+    from .result import Usage
+
 __all__ = (
     'AgentDeps',
     'RunContext',
@@ -27,7 +30,7 @@ __all__ = (
     'ToolDefinition',
 )

-AgentDeps = TypeVar('AgentDeps')
+AgentDeps = TypeVar('AgentDeps', default=None)
 """Type variable for agent dependencies."""


@@ -37,14 +40,20 @@ class RunContext(Generic[AgentDeps]):

     deps: AgentDeps
     """Dependencies for the agent."""
-    retry: int
-    """Number of retries so far."""
-    messages: list[_messages.ModelMessage]
-    """Messages exchanged in the conversation so far."""
-    tool_name: str | None
-    """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""
+    prompt: str
+    """The original user prompt passed to the run."""
+    messages: list[_messages.ModelMessage] = field(default_factory=list)
+    """Messages exchanged in the conversation so far."""
+    tool_name: str | None = None
+    """Name of the tool being called."""
+    retry: int = 0
+    """Number of retries so far."""
+    run_step: int = 0
+    """The current step in the run."""

     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
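
With the reordered fields and the new `usage`, `prompt`, and `run_step`, tools can inspect run state. A hypothetical sketch; the tool name and agent are illustrative, not from the diff:

```python
# Sketch: a tool reading the RunContext fields defined above.
from pydantic_ai import Agent, RunContext

agent = Agent('test')


@agent.tool
def run_state(ctx: RunContext[None]) -> str:
    """Illustrative tool: report the run's progress back to the model."""
    return f'step={ctx.run_step}, requests so far={ctx.usage.requests}, prompt={ctx.prompt!r}'
```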
@@ -58,7 +67,7 @@ class RunContext(Generic[AgentDeps]):
         return dataclasses.replace(self, **kwargs)


-ToolParams = ParamSpec('ToolParams')
+ToolParams = ParamSpec('ToolParams', default=...)
 """Retrieval function param spec."""

 SystemPromptFunc = Union[
@@ -83,7 +92,7 @@ ToolFuncPlain = Callable[ToolParams, Any]
 Usage `ToolPlainFunc[ToolParams]`.
 """
 ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
-"""Either part_kind of tool function.
+"""Either kind of tool function.

 This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
 [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -125,7 +134,7 @@ A = TypeVar('A')
 class Tool(Generic[AgentDeps]):
     """A tool function for an agent."""

-    function: ToolFuncEither[AgentDeps, ...]
+    function: ToolFuncEither[AgentDeps]
     takes_ctx: bool
     max_retries: int | None
     name: str
@@ -141,7 +150,7 @@ class Tool(Generic[AgentDeps]):

     def __init__(
         self,
-        function: ToolFuncEither[AgentDeps, ...],
+        function: ToolFuncEither[AgentDeps],
         *,
         takes_ctx: bool | None = None,
         max_retries: int | None = None,
pydantic_ai_slim-0.0.15/pydantic_ai/settings.py → pydantic_ai_slim-0.0.17/pydantic_ai/usage.py
@@ -1,87 +1,60 @@
-from __future__ import annotations
+from __future__ import annotations as _annotations

+from copy import copy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
-
-from httpx import Timeout
-from typing_extensions import TypedDict

 from .exceptions import UsageLimitExceeded

-if TYPE_CHECKING:
-    from .result import Usage
-
-
-class ModelSettings(TypedDict, total=False):
-    """Settings to configure an LLM.
-
-    Here we include only settings which apply to multiple models / model providers.
-    """
-
-    max_tokens: int
-    """The maximum number of tokens to generate before stopping.
-
-    Supported by:
-
-    * Gemini
-    * Anthropic
-    * OpenAI
-    * Groq
-    """
-
-    temperature: float
-    """Amount of randomness injected into the response.
-
-    Use `temperature` closer to `0.0` for analytical / multiple choice, and closer to a model's
-    maximum `temperature` for creative and generative tasks.
-
-    Note that even with `temperature` of `0.0`, the results will not be fully deterministic.
-
-    Supported by:
+__all__ = 'Usage', 'UsageLimits'

-    * Gemini
-    * Anthropic
-    * OpenAI
-    * Groq
-    """
-
-    top_p: float
-    """An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
-
-    So 0.1 means only the tokens comprising the top 10% probability mass are considered.
-
-    You should either alter `temperature` or `top_p`, but not both.
-
-    Supported by:
-
-    * Gemini
-    * Anthropic
-    * OpenAI
-    * Groq
-    """

-    timeout: float | Timeout
-    """Override the client-level default timeout for a request, in seconds.
+@dataclass
+class Usage:
+    """LLM usage associated with a request or run.

-    Supported by:
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.

-    * Gemini
-    * Anthropic
-    * OpenAI
-    * Groq
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
     """

-
-def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
-    """Merge two sets of model settings, preferring the overrides.
-
-    A common use case is: merge_model_settings(<agent settings>, <run settings>)
-    """
-    # Note: we may want merge recursively if/when we add non-primitive values
-    if base and overrides:
-        return base | overrides
-    else:
-        return base or overrides
+    requests: int = 0
+    """Number of requests made to the LLM API."""
+    request_tokens: int | None = None
+    """Tokens used in processing requests."""
+    response_tokens: int | None = None
+    """Tokens used in generating responses."""
+    total_tokens: int | None = None
+    """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
+    details: dict[str, int] | None = None
+    """Any extra details returned by the model."""
+
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.
+
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
+        """
+        self.requests += requests
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
+            self_value = getattr(self, f)
+            other_value = getattr(incr_usage, f)
+            if self_value is not None or other_value is not None:
+                setattr(self, f, (self_value or 0) + (other_value or 0))
+
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value
+
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.
+
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage


 @dataclass
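
The split gives `Usage` two addition paths: `incr` mutates in place (with an optional extra `requests` bump, as used in `agent.run` above), while `+` copies first. A quick sketch of the arithmetic implied by the code:

```python
# Sketch: in-place increment vs. functional addition of Usage.
from pydantic_ai.usage import Usage

a = Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15)
b = Usage(requests=1, request_tokens=20, response_tokens=8, total_tokens=28)

total = a + b  # new object via copy + incr; a and b unchanged
a.incr(b, requests=1)  # mutates a; requests -> 1 + 1 (extra) + 1 (from b) = 3

print(total.total_tokens)  # 43
print(a.requests)  # 3
```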
@@ -136,6 +109,6 @@ class UsageLimits:
                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
             )

-        total_tokens = request_tokens + response_tokens
+        total_tokens = usage.total_tokens or 0
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
{pydantic_ai_slim-0.0.15 → pydantic_ai_slim-0.0.17}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "pydantic-ai-slim"
-version = "0.0.15"
+version = "0.0.17"
 description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
 authors = [
     { name = "Samuel Colvin", email = "samuel@pydantic.dev" },