pydantic-ai-slim 0.1.11__tar.gz → 0.2.0__tar.gz

This diff compares the contents of two publicly released versions of the package as published to their public registry, and is provided for informational purposes only.


Files changed (53)
  1. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/PKG-INFO +3 -3
  2. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_agent_graph.py +6 -8
  3. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_parts_manager.py +3 -1
  4. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/agent.py +21 -0
  5. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/messages.py +7 -0
  6. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/__init__.py +6 -7
  7. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/_json_schema.py +8 -2
  8. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/anthropic.py +23 -26
  9. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/bedrock.py +36 -12
  10. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/cohere.py +5 -3
  11. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/fallback.py +3 -4
  12. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/function.py +9 -4
  13. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/gemini.py +13 -5
  14. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/groq.py +5 -3
  15. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/instrumented.py +8 -9
  16. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/mistral.py +5 -3
  17. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/openai.py +9 -6
  18. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/test.py +4 -3
  19. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/models/wrapper.py +1 -2
  20. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/usage.py +5 -3
  21. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/.gitignore +0 -0
  22. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/README.md +0 -0
  23. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/__init__.py +0 -0
  24. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/__main__.py +0 -0
  25. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_cli.py +0 -0
  26. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_griffe.py +0 -0
  27. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_output.py +0 -0
  28. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_pydantic.py +0 -0
  29. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_system_prompt.py +0 -0
  30. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/_utils.py +0 -0
  31. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/common_tools/__init__.py +0 -0
  32. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  33. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/common_tools/tavily.py +0 -0
  34. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/exceptions.py +0 -0
  35. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/format_as_xml.py +0 -0
  36. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/format_prompt.py +0 -0
  37. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/mcp.py +0 -0
  38. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/__init__.py +0 -0
  39. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/anthropic.py +0 -0
  40. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/azure.py +0 -0
  41. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/bedrock.py +0 -0
  42. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/cohere.py +0 -0
  43. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/deepseek.py +0 -0
  44. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/google_gla.py +0 -0
  45. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/google_vertex.py +0 -0
  46. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/groq.py +0 -0
  47. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/mistral.py +0 -0
  48. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/providers/openai.py +0 -0
  49. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/py.typed +0 -0
  50. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/result.py +0 -0
  51. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/settings.py +0 -0
  52. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pydantic_ai/tools.py +0 -0
  53. {pydantic_ai_slim-0.1.11 → pydantic_ai_slim-0.2.0}/pyproject.toml +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.1.11
+ Version: 0.2.0
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: opentelemetry-api>=1.28.0
- Requires-Dist: pydantic-graph==0.1.11
+ Requires-Dist: pydantic-graph==0.2.0
  Requires-Dist: pydantic>=2.10
  Requires-Dist: typing-inspection>=0.4.0
  Provides-Extra: anthropic
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
  Provides-Extra: duckduckgo
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
  Provides-Extra: evals
- Requires-Dist: pydantic-evals==0.1.11; extra == 'evals'
+ Requires-Dist: pydantic-evals==0.2.0; extra == 'evals'
  Provides-Extra: groq
  Requires-Dist: groq>=0.15.0; extra == 'groq'
  Provides-Extra: logfire
pydantic_ai/_agent_graph.py

@@ -301,16 +301,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
  ctx.state.message_history, model_settings, model_request_parameters
  ) as streamed_response:
  self._did_stream = True
- ctx.state.usage.incr(_usage.Usage(), requests=1)
+ ctx.state.usage.requests += 1
  yield streamed_response
  # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
  # otherwise usage won't be properly counted:
  async for _ in streamed_response:
  pass
  model_response = streamed_response.get()
- request_usage = streamed_response.usage()

- self._finish_handling(ctx, model_response, request_usage)
+ self._finish_handling(ctx, model_response)
  assert self._result is not None # this should be set by the previous line

  async def _make_request(
@@ -321,12 +320,12 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

  model_settings, model_request_parameters = await self._prepare_request(ctx)
  model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
- model_response, request_usage = await ctx.deps.model.request(
+ model_response = await ctx.deps.model.request(
  ctx.state.message_history, model_settings, model_request_parameters
  )
- ctx.state.usage.incr(_usage.Usage(), requests=1)
+ ctx.state.usage.incr(_usage.Usage())

- return self._finish_handling(ctx, model_response, request_usage)
+ return self._finish_handling(ctx, model_response)

  async def _prepare_request(
  self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -348,10 +347,9 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
  self,
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
  response: _messages.ModelResponse,
- usage: _usage.Usage,
  ) -> CallToolsNode[DepsT, NodeRunEndT]:
  # Update usage
- ctx.state.usage.incr(usage, requests=0)
+ ctx.state.usage.incr(response.usage)
  if ctx.deps.usage_limits:
  ctx.deps.usage_limits.check_tokens(ctx.state.usage)

pydantic_ai/_parts_manager.py

@@ -14,7 +14,7 @@ event-emitting logic.
  from __future__ import annotations as _annotations

  from collections.abc import Hashable
- from dataclasses import dataclass, field
+ from dataclasses import dataclass, field, replace
  from typing import Any, Union

  from pydantic_ai.exceptions import UnexpectedModelBehavior
@@ -198,6 +198,8 @@ class ModelResponsePartsManager:
  return PartStartEvent(index=part_index, part=updated_part)
  else:
  # We updated an existing part, so emit a PartDeltaEvent
+ if updated_part.tool_call_id and not delta.tool_call_id:
+ delta = replace(delta, tool_call_id=updated_part.tool_call_id)
  return PartDeltaEvent(index=part_index, delta=delta)

  def handle_tool_call_part(
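
The delta emitted for an updated part now inherits the part's tool_call_id when the incoming delta lacks one. A minimal sketch of the `dataclasses.replace` pattern used above, with a stand-in dataclass rather than pydantic-ai's real delta type:

    from __future__ import annotations

    from dataclasses import dataclass, replace

    @dataclass
    class Delta:  # stand-in for the real tool-call delta class
        args_delta: str | None = None
        tool_call_id: str | None = None

    existing_tool_call_id = 'call_123'         # id already stored on the part being updated
    delta = Delta(args_delta='{"city": "Par')  # streamed chunk arrives without an id
    if existing_tool_call_id and not delta.tool_call_id:
        # copy the id from the existing part onto the outgoing delta event
        delta = replace(delta, tool_call_id=existing_tool_call_id)
    assert delta.tool_call_id == 'call_123'
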
pydantic_ai/agent.py

@@ -551,6 +551,13 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
  CallToolsNode(
  model_response=ModelResponse(
  parts=[TextPart(content='Paris', part_kind='text')],
+ usage=Usage(
+ requests=1,
+ request_tokens=56,
+ response_tokens=1,
+ total_tokens=57,
+ details=None,
+ ),
  model_name='gpt-4o',
  timestamp=datetime.datetime(...),
  kind='response',
@@ -1715,6 +1722,13 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
  CallToolsNode(
  model_response=ModelResponse(
  parts=[TextPart(content='Paris', part_kind='text')],
+ usage=Usage(
+ requests=1,
+ request_tokens=56,
+ response_tokens=1,
+ total_tokens=57,
+ details=None,
+ ),
  model_name='gpt-4o',
  timestamp=datetime.datetime(...),
  kind='response',
@@ -1853,6 +1867,13 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
  CallToolsNode(
  model_response=ModelResponse(
  parts=[TextPart(content='Paris', part_kind='text')],
+ usage=Usage(
+ requests=1,
+ request_tokens=56,
+ response_tokens=1,
+ total_tokens=57,
+ details=None,
+ ),
  model_name='gpt-4o',
  timestamp=datetime.datetime(...),
  kind='response',
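
These docstring examples reflect the new `usage` field on `ModelResponse` (added in the messages.py hunks below). A short sketch of building such a response by hand; the token counts are illustrative:

    from pydantic_ai.messages import ModelResponse, TextPart
    from pydantic_ai.usage import Usage

    response = ModelResponse(
        parts=[TextPart(content='Paris')],
        usage=Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57),
        model_name='gpt-4o',
    )
    # `usage` has a default (an empty Usage()), so messages saved by older versions still load
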
pydantic_ai/messages.py

@@ -14,6 +14,7 @@ from typing_extensions import TypeAlias

  from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
  from .exceptions import UnexpectedModelBehavior
+ from .usage import Usage

  AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
  ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
@@ -554,6 +555,12 @@ class ModelResponse:
  parts: list[ModelResponsePart]
  """The parts of the model message."""

+ usage: Usage = field(default_factory=Usage)
+ """Usage information for the request.
+
+ This has a default to make tests easier, and to support loading old messages where usage will be missing.
+ """
+
  model_name: str | None = None
  """The name of the model that generated the response."""

pydantic_ai/models/__init__.py

@@ -12,7 +12,6 @@ from contextlib import asynccontextmanager, contextmanager
  from dataclasses import dataclass, field
  from datetime import datetime
  from functools import cache
- from typing import TYPE_CHECKING

  import httpx
  from typing_extensions import Literal, TypeAliasType
@@ -21,12 +20,9 @@ from .._parts_manager import ModelResponsePartsManager
  from ..exceptions import UserError
  from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
  from ..settings import ModelSettings
+ from ..tools import ToolDefinition
  from ..usage import Usage

- if TYPE_CHECKING:
- from ..tools import ToolDefinition
-
-
  KnownModelName = TypeAliasType(
  'KnownModelName',
  Literal[
@@ -278,7 +274,7 @@ class Model(ABC):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, Usage]:
+ ) -> ModelResponse:
  """Make a request to the model."""
  raise NotImplementedError()

@@ -365,7 +361,10 @@ class StreamedResponse(ABC):
  def get(self) -> ModelResponse:
  """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
  return ModelResponse(
- parts=self._parts_manager.get_parts(), model_name=self.model_name, timestamp=self.timestamp
+ parts=self._parts_manager.get_parts(),
+ model_name=self.model_name,
+ timestamp=self.timestamp,
+ usage=self.usage(),
  )

  def usage(self) -> Usage:
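
This signature change is the core of the release: `Model.request` now returns only a `ModelResponse`, with usage attached to the response rather than returned alongside it. A hedged migration sketch for code that calls a model directly, assuming the arguments are prepared exactly as on 0.1.11:

    from pydantic_ai.messages import ModelResponse

    async def call_model_directly(model, messages, model_settings, model_request_parameters) -> ModelResponse:
        # 0.1.11: model_response, request_usage = await model.request(...)
        model_response = await model.request(messages, model_settings, model_request_parameters)
        # token counts and the request count now travel on the response itself
        print(model_response.usage.request_tokens, model_response.usage.response_tokens)
        return model_response
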
pydantic_ai/models/_json_schema.py

@@ -25,7 +25,7 @@ class WalkJsonSchema(ABC):
  self.simplify_nullable_unions = simplify_nullable_unions

  self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
- self.refs_stack = tuple[str, ...]()
+ self.refs_stack: list[str] = []
  self.recursive_refs = set[str]()

  @abstractmethod
@@ -62,13 +62,16 @@ class WalkJsonSchema(ABC):
  return handled

  def _handle(self, schema: JsonSchema) -> JsonSchema:
+ nested_refs = 0
  if self.prefer_inlined_defs:
  while ref := schema.get('$ref'):
  key = re.sub(r'^#/\$defs/', '', ref)
  if key in self.refs_stack:
  self.recursive_refs.add(key)
  break # recursive ref can't be unpacked
- self.refs_stack += (key,)
+ self.refs_stack.append(key)
+ nested_refs += 1
+
  def_schema = self.defs.get(key)
  if def_schema is None: # pragma: no cover
  raise UserError(f'Could not find $ref definition for {key}')
@@ -87,6 +90,9 @@ class WalkJsonSchema(ABC):
  # Apply the base transform
  schema = self.transform(schema)

+ if nested_refs > 0:
+ self.refs_stack = self.refs_stack[:-nested_refs]
+
  return schema

  def _handle_object(self, schema: JsonSchema) -> JsonSchema:
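
The refs stack now follows a push/pop discipline: one entry is pushed per inlined `$ref`, and the same number is popped after the transform so refs from one branch cannot leak into sibling schemas. A toy sketch of that discipline, independent of `WalkJsonSchema`:

    refs_stack: list[str] = []

    def handle(ref_chain: list[str]) -> None:
        nested_refs = 0
        for key in ref_chain:
            if key in refs_stack:
                break  # recursive ref: stop unpacking
            refs_stack.append(key)
            nested_refs += 1
        # ... resolve and transform the referenced schema here ...
        if nested_refs > 0:
            del refs_stack[-nested_refs:]  # pop exactly what this call pushed

    handle(['Address', 'Country'])
    assert refs_stack == []  # balanced again after the call
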
pydantic_ai/models/anthropic.py

@@ -145,12 +145,14 @@ class AnthropicModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  check_allow_model_requests()
  response = await self._messages_create(
  messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
  )
- return self._process_response(response), _map_usage(response)
+ model_response = self._process_response(response)
+ model_response.usage.requests = 1
+ return model_response

  @asynccontextmanager
  async def request_stream(
@@ -260,7 +262,7 @@
  )
  )

- return ModelResponse(items, model_name=response.model)
+ return ModelResponse(items, usage=_map_usage(response), model_name=response.model)

  async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
  peekable_response = _utils.PeekableAsyncStream(response)
@@ -391,36 +393,31 @@
  def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
  if isinstance(message, AnthropicMessage):
  response_usage = message.usage
+ elif isinstance(message, RawMessageStartEvent):
+ response_usage = message.message.usage
+ elif isinstance(message, RawMessageDeltaEvent):
+ response_usage = message.usage
  else:
- if isinstance(message, RawMessageStartEvent):
- response_usage = message.message.usage
- elif isinstance(message, RawMessageDeltaEvent):
- response_usage = message.usage
- else:
- # No usage information provided in:
- # - RawMessageStopEvent
- # - RawContentBlockStartEvent
- # - RawContentBlockDeltaEvent
- # - RawContentBlockStopEvent
- response_usage = None
-
- if response_usage is None:
+ # No usage information provided in:
+ # - RawMessageStopEvent
+ # - RawContentBlockStartEvent
+ # - RawContentBlockDeltaEvent
+ # - RawContentBlockStopEvent
  return usage.Usage()

- # Store all integer-typed usage values in the details dict
- response_usage_dict = response_usage.model_dump()
- details: dict[str, int] = {}
- for key, value in response_usage_dict.items():
- if isinstance(value, int):
- details[key] = value
+ # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
+ # `response_tokens`
+ details: dict[str, int] = {
+ key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
+ }

- # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence the getattr call
+ # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
  # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
  # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
  request_tokens = (
- getattr(response_usage, 'input_tokens', 0)
- + (getattr(response_usage, 'cache_creation_input_tokens', 0) or 0) # These can be missing, None, or int
- + (getattr(response_usage, 'cache_read_input_tokens', 0) or 0)
+ details.get('input_tokens', 0)
+ + details.get('cache_creation_input_tokens', 0)
+ + details.get('cache_read_input_tokens', 0)
  )

  return usage.Usage(
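
A rough standalone sketch of the reworked Anthropic token accounting, using a plain dict in place of the SDK's usage object; the numbers are invented:

    response_usage_dump = {'input_tokens': 100, 'cache_read_input_tokens': 40, 'output_tokens': 12}

    # every integer field of the provider usage object is kept as a detail ...
    details = {key: value for key, value in response_usage_dump.items() if isinstance(value, int)}
    # ... and request_tokens sums direct, cache-creation and cache-read input tokens;
    # `get` covers fields that a RawMessageDeltaEvent does not carry
    request_tokens = (
        details.get('input_tokens', 0)
        + details.get('cache_creation_input_tokens', 0)
        + details.get('cache_read_input_tokens', 0)
    )
    assert request_tokens == 140
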
pydantic_ai/models/bedrock.py

@@ -232,10 +232,12 @@ class BedrockConverseModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  settings = cast(BedrockModelSettings, model_settings or {})
  response = await self._messages_create(messages, False, settings, model_request_parameters)
- return await self._process_response(response)
+ model_response = await self._process_response(response)
+ model_response.usage.requests = 1
+ return model_response

  @asynccontextmanager
  async def request_stream(
@@ -248,7 +250,7 @@
  response = await self._messages_create(messages, True, settings, model_request_parameters)
  yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)

- async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, usage.Usage]:
+ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
  items: list[ModelResponsePart] = []
  if message := response['output'].get('message'):
  for item in message['content']:
@@ -269,7 +271,7 @@
  response_tokens=response['usage']['outputTokens'],
  total_tokens=response['usage']['totalTokens'],
  )
- return ModelResponse(items, model_name=self.model_name), u
+ return ModelResponse(items, usage=u, model_name=self.model_name)

  @overload
  async def _messages_create(
@@ -367,13 +369,16 @@
  async def _map_messages(
  self, messages: list[ModelMessage]
  ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
- """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
+ """Maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`.
+
+ Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models.
+ """
  system_prompt: list[SystemContentBlockTypeDef] = []
  bedrock_messages: list[MessageUnionTypeDef] = []
  document_count: Iterator[int] = count(1)
- for m in messages:
- if isinstance(m, ModelRequest):
- for part in m.parts:
+ for message in messages:
+ if isinstance(message, ModelRequest):
+ for part in message.parts:
  if isinstance(part, SystemPromptPart):
  system_prompt.append({'text': part.content})
  elif isinstance(part, UserPromptPart):
@@ -414,9 +419,9 @@
  ],
  }
  )
- elif isinstance(m, ModelResponse):
+ elif isinstance(message, ModelResponse):
  content: list[ContentBlockOutputTypeDef] = []
- for item in m.parts:
+ for item in message.parts:
  if isinstance(item, TextPart):
  content.append({'text': item.content})
  else:
@@ -424,12 +429,31 @@
  content.append(self._map_tool_call(item))
  bedrock_messages.append({'role': 'assistant', 'content': content})
  else:
- assert_never(m)
+ assert_never(message)
+
+ # Merge together sequential user messages.
+ processed_messages: list[MessageUnionTypeDef] = []
+ last_message: dict[str, Any] | None = None
+ for current_message in bedrock_messages:
+ if (
+ last_message is not None
+ and current_message['role'] == last_message['role']
+ and current_message['role'] == 'user'
+ ):
+ # Add the new user content onto the existing user message.
+ last_content = list(last_message['content'])
+ last_content.extend(current_message['content'])
+ last_message['content'] = last_content
+ continue
+
+ # Add the entire message to the list of messages.
+ processed_messages.append(current_message)
+ last_message = cast(dict[str, Any], current_message)

  if instructions := self._get_instructions(messages):
  system_prompt.insert(0, {'text': instructions})

- return system_prompt, bedrock_messages
+ return system_prompt, processed_messages

  @staticmethod
  async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:
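
A minimal standalone sketch of the new merge step, run on plain dicts shaped like Bedrock Converse messages; the message contents are invented:

    from typing import Any, Optional

    bedrock_messages: list[dict[str, Any]] = [
        {'role': 'user', 'content': [{'text': 'What is the weather?'}]},
        {'role': 'user', 'content': [{'toolResult': {'content': [{'text': '21 degrees'}]}}]},
        {'role': 'assistant', 'content': [{'text': 'It is 21 degrees.'}]},
    ]

    processed: list[dict[str, Any]] = []
    last: Optional[dict[str, Any]] = None
    for current in bedrock_messages:
        if last is not None and current['role'] == last['role'] == 'user':
            # fold consecutive user messages into one, since Bedrock expects alternating roles
            last['content'] = list(last['content']) + list(current['content'])
            continue
        processed.append(current)
        last = current

    assert len(processed) == 2 and len(processed[0]['content']) == 2
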
pydantic_ai/models/cohere.py

@@ -133,10 +133,12 @@ class CohereModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  check_allow_model_requests()
  response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
- return self._process_response(response), _map_usage(response)
+ model_response = self._process_response(response)
+ model_response.usage.requests = 1
+ return model_response

  @property
  def model_name(self) -> CohereModelName:
@@ -191,7 +193,7 @@
  tool_call_id=c.id or _generate_tool_call_id(),
  )
  )
- return ModelResponse(parts=parts, model_name=self._model_name)
+ return ModelResponse(parts=parts, usage=_map_usage(response), model_name=self._model_name)

  def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
  """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
pydantic_ai/models/fallback.py

@@ -15,7 +15,6 @@ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, i
  if TYPE_CHECKING:
  from ..messages import ModelMessage, ModelResponse
  from ..settings import ModelSettings
- from ..usage import Usage


  @dataclass(init=False)
@@ -55,7 +54,7 @@ class FallbackModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, Usage]:
+ ) -> ModelResponse:
  """Try each model in sequence until one succeeds.

  In case of failure, raise a FallbackExceptionGroup with all exceptions.
@@ -65,7 +64,7 @@
  for model in self.models:
  customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
  try:
- response, usage = await model.request(messages, model_settings, customized_model_request_parameters)
+ response = await model.request(messages, model_settings, customized_model_request_parameters)
  except Exception as exc:
  if self._fallback_on(exc):
  exceptions.append(exc)
@@ -73,7 +72,7 @@
  raise exc

  self._set_span_attributes(model)
- return response, usage
+ return response

  raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

pydantic_ai/models/function.py

@@ -88,7 +88,7 @@ class FunctionModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  agent_info = AgentInfo(
  model_request_parameters.function_tools,
  model_request_parameters.allow_text_output,
@@ -105,8 +105,11 @@
  assert isinstance(response_, ModelResponse), response_
  response = response_
  response.model_name = self._model_name
- # TODO is `messages` right here? Should it just be new messages?
- return response, _estimate_usage(chain(messages, [response]))
+ # Add usage data if not already present
+ if not response.usage.has_values():
+ response.usage = _estimate_usage(chain(messages, [response]))
+ response.usage.requests = 1
+ return response

  @asynccontextmanager
  async def request_stream(
@@ -273,7 +276,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
  else:
  assert_never(message)
  return usage.Usage(
- request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
+ request_tokens=request_tokens,
+ response_tokens=response_tokens,
+ total_tokens=request_tokens + response_tokens,
  )


pydantic_ai/models/gemini.py

@@ -145,14 +145,14 @@ class GeminiModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  check_allow_model_requests()
  async with self._make_request(
  messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
  ) as http_response:
  data = await http_response.aread()
  response = _gemini_response_ta.validate_json(data)
- return self._process_response(response), _metadata_as_usage(response)
+ return self._process_response(response)

  @asynccontextmanager
  async def request_stream(
@@ -269,7 +269,9 @@ class GeminiModel(Model):
  else:
  raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
  parts = response['candidates'][0]['content']['parts']
- return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))
+ usage = _metadata_as_usage(response)
+ usage.requests = 1
+ return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)

  async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
  """Process a streamed response, and prepare a streaming response to return."""
@@ -591,7 +593,7 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart


  def _process_response_from_parts(
- parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, timestamp: datetime | None = None
+ parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
  ) -> ModelResponse:
  items: list[ModelResponsePart] = []
  for part in parts:
@@ -603,7 +605,7 @@
  raise UnexpectedModelBehavior(
  f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
  )
- return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
+ return ModelResponse(parts=items, usage=usage, model_name=model_name)


  class _GeminiFunctionCall(TypedDict):
@@ -831,6 +833,12 @@ class _GeminiJsonSchema(WalkJsonSchema):
  schema.pop('exclusiveMaximum', None)
  schema.pop('exclusiveMinimum', None)

+ # Gemini only supports string enums, so we need to convert any enum values to strings.
+ # Pydantic will take care of transforming the transformed string values to the correct type.
+ if enum := schema.get('enum'):
+ schema['type'] = 'string'
+ schema['enum'] = [str(val) for val in enum]
+
  type_ = schema.get('type')
  if 'oneOf' in schema and 'type' not in schema: # pragma: no cover
  # This gets hit when we have a discriminated union
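
A small illustration of the enum rewrite, applied to a hand-written schema fragment rather than through `_GeminiJsonSchema`:

    schema = {'enum': [1, 2, 3]}  # e.g. generated for a Literal[1, 2, 3] field
    if enum := schema.get('enum'):
        schema['type'] = 'string'
        schema['enum'] = [str(val) for val in enum]

    assert schema == {'type': 'string', 'enum': ['1', '2', '3']}
    # Pydantic coerces the string values back to the declared type when validating the tool arguments
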
pydantic_ai/models/groq.py

@@ -130,12 +130,14 @@ class GroqModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  check_allow_model_requests()
  response = await self._completions_create(
  messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
  )
- return self._process_response(response), _map_usage(response)
+ model_response = self._process_response(response)
+ model_response.usage.requests = 1
+ return model_response

  @asynccontextmanager
  async def request_stream(
@@ -237,7 +239,7 @@
  if choice.message.tool_calls is not None:
  for c in choice.message.tool_calls:
  items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
- return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+ return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

  async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
  """Process a streamed response, and prepare a streaming response to return."""
pydantic_ai/models/instrumented.py

@@ -23,7 +23,6 @@ from ..messages import (
  ModelResponse,
  )
  from ..settings import ModelSettings
- from ..usage import Usage
  from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
  from .wrapper import WrapperModel

@@ -122,11 +121,11 @@ class InstrumentedModel(WrapperModel):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, Usage]:
+ ) -> ModelResponse:
  with self._instrument(messages, model_settings, model_request_parameters) as finish:
- response, usage = await super().request(messages, model_settings, model_request_parameters)
- finish(response, usage)
- return response, usage
+ response = await super().request(messages, model_settings, model_request_parameters)
+ finish(response)
+ return response

  @asynccontextmanager
  async def request_stream(
@@ -144,7 +143,7 @@ class InstrumentedModel(WrapperModel):
  yield response_stream
  finally:
  if response_stream:
- finish(response_stream.get(), response_stream.usage())
+ finish(response_stream.get())

  @contextmanager
  def _instrument(
@@ -152,7 +151,7 @@ class InstrumentedModel(WrapperModel):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
+ ) -> Iterator[Callable[[ModelResponse], None]]:
  operation = 'chat'
  span_name = f'{operation} {self.model_name}'
  # TODO Missing attributes:
@@ -177,7 +176,7 @@ class InstrumentedModel(WrapperModel):

  with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:

- def finish(response: ModelResponse, usage: Usage):
+ def finish(response: ModelResponse):
  if not span.is_recording():
  return

@@ -193,7 +192,7 @@ class InstrumentedModel(WrapperModel):
  },
  )
  )
- new_attributes: dict[str, AttributeValue] = usage.opentelemetry_attributes() # type: ignore
+ new_attributes: dict[str, AttributeValue] = response.usage.opentelemetry_attributes() # pyright: ignore[reportAssignmentType]
  attributes.update(getattr(span, 'attributes', {}))
  request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
  new_attributes['gen_ai.response.model'] = response.model_name or request_model
pydantic_ai/models/mistral.py

@@ -147,13 +147,15 @@ class MistralModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, Usage]:
+ ) -> ModelResponse:
  """Make a non-streaming request to the model from Pydantic AI call."""
  check_allow_model_requests()
  response = await self._completions_create(
  messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
  )
- return self._process_response(response), _map_usage(response)
+ model_response = self._process_response(response)
+ model_response.usage.requests = 1
+ return model_response

  @asynccontextmanager
  async def request_stream(
@@ -323,7 +325,7 @@ class MistralModel(Model):
  tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
  parts.append(tool)

- return ModelResponse(parts, model_name=response.model, timestamp=timestamp)
+ return ModelResponse(parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

  async def _process_streamed_response(
  self,
pydantic_ai/models/openai.py

@@ -192,12 +192,14 @@ class OpenAIModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  check_allow_model_requests()
  response = await self._completions_create(
  messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
  )
- return self._process_response(response), _map_usage(response)
+ model_response = self._process_response(response)
+ model_response.usage.requests = 1
+ return model_response

  @asynccontextmanager
  async def request_stream(
@@ -304,7 +306,7 @@ class OpenAIModel(Model):
  if choice.message.tool_calls is not None:
  for c in choice.message.tool_calls:
  items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
- return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+ return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

  async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
  """Process a streamed response, and prepare a streaming response to return."""
@@ -522,12 +524,12 @@ class OpenAIResponsesModel(Model):
  messages: list[ModelRequest | ModelResponse],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, usage.Usage]:
+ ) -> ModelResponse:
  check_allow_model_requests()
  response = await self._responses_create(
  messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
  )
- return self._process_response(response), _map_usage(response)
+ return self._process_response(response)

  @asynccontextmanager
  async def request_stream(
@@ -554,7 +556,7 @@ class OpenAIResponsesModel(Model):
  for item in response.output:
  if item.type == 'function_call':
  items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
- return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+ return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)

  async def _process_streamed_response(
  self, response: AsyncStream[responses.ResponseStreamEvent]
@@ -935,6 +937,7 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
  if response_usage.prompt_tokens_details is not None:
  details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
  return usage.Usage(
+ requests=1,
  request_tokens=response_usage.prompt_tokens,
  response_tokens=response_usage.completion_tokens,
  total_tokens=response_usage.total_tokens,
pydantic_ai/models/test.py

@@ -86,11 +86,12 @@ class TestModel(Model):
  messages: list[ModelMessage],
  model_settings: ModelSettings | None,
  model_request_parameters: ModelRequestParameters,
- ) -> tuple[ModelResponse, Usage]:
+ ) -> ModelResponse:
  self.last_model_request_parameters = model_request_parameters
  model_response = self._request(messages, model_settings, model_request_parameters)
- usage = _estimate_usage([*messages, model_response])
- return model_response, usage
+ model_response.usage = _estimate_usage([*messages, model_response])
+ model_response.usage.requests = 1
+ return model_response

  @asynccontextmanager
  async def request_stream(
pydantic_ai/models/wrapper.py

@@ -7,7 +7,6 @@ from typing import Any

  from ..messages import ModelMessage, ModelResponse
  from ..settings import ModelSettings
- from ..usage import Usage
  from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model


@@ -24,7 +23,7 @@ class WrapperModel(Model):
  def __init__(self, wrapped: Model | KnownModelName):
  self.wrapped = infer_model(wrapped)

- async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
+ async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
  return await self.wrapped.request(*args, **kwargs)

  @asynccontextmanager
pydantic_ai/usage.py

@@ -28,14 +28,12 @@ class Usage:
  details: dict[str, int] | None = None
  """Any extra details returned by the model."""

- def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+ def incr(self, incr_usage: Usage) -> None:
  """Increment the usage in place.

  Args:
  incr_usage: The usage to increment by.
- requests: The number of requests to increment by in addition to `incr_usage.requests`.
  """
- self.requests += requests
  for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
  self_value = getattr(self, f)
  other_value = getattr(incr_usage, f)
@@ -66,6 +64,10 @@ class Usage:
  result[f'gen_ai.usage.details.{key}'] = value
  return {k: v for k, v in result.items() if v}

+ def has_values(self) -> bool:
+ """Whether any values are set and non-zero."""
+ return bool(self.requests or self.request_tokens or self.response_tokens or self.details)
+

  @dataclass
  class UsageLimits:
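
A short sketch of the updated accumulation API; the token counts are made up:

    from pydantic_ai.usage import Usage

    run_usage = Usage()
    response_usage = Usage(requests=1, request_tokens=56, response_tokens=1, total_tokens=57)

    assert response_usage.has_values()
    # 0.1.11: run_usage.incr(response_usage, requests=1)
    # 0.2.0: the request count is carried by the Usage being added
    run_usage.incr(response_usage)
    assert run_usage.requests == 1 and run_usage.total_tokens == 57
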