pydantic-ai-slim 0.7.2__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. pydantic_ai/_agent_graph.py +2 -2
  2. pydantic_ai/_cli.py +18 -3
  3. pydantic_ai/_run_context.py +2 -2
  4. pydantic_ai/ag_ui.py +4 -4
  5. pydantic_ai/agent/__init__.py +7 -9
  6. pydantic_ai/agent/abstract.py +16 -18
  7. pydantic_ai/agent/wrapper.py +4 -6
  8. pydantic_ai/direct.py +4 -4
  9. pydantic_ai/durable_exec/temporal/_agent.py +13 -15
  10. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  11. pydantic_ai/messages.py +16 -6
  12. pydantic_ai/models/__init__.py +5 -5
  13. pydantic_ai/models/anthropic.py +27 -26
  14. pydantic_ai/models/bedrock.py +24 -26
  15. pydantic_ai/models/cohere.py +20 -25
  16. pydantic_ai/models/fallback.py +15 -15
  17. pydantic_ai/models/function.py +7 -9
  18. pydantic_ai/models/gemini.py +43 -39
  19. pydantic_ai/models/google.py +59 -40
  20. pydantic_ai/models/groq.py +22 -19
  21. pydantic_ai/models/huggingface.py +18 -21
  22. pydantic_ai/models/instrumented.py +4 -4
  23. pydantic_ai/models/mcp_sampling.py +1 -2
  24. pydantic_ai/models/mistral.py +24 -22
  25. pydantic_ai/models/openai.py +98 -44
  26. pydantic_ai/models/test.py +4 -5
  27. pydantic_ai/profiles/openai.py +13 -3
  28. pydantic_ai/providers/openai.py +1 -1
  29. pydantic_ai/result.py +5 -5
  30. pydantic_ai/run.py +4 -11
  31. pydantic_ai/usage.py +229 -67
  32. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/METADATA +10 -4
  33. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/RECORD +36 -36
  34. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/WHEEL +0 -0
  35. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/entry_points.txt +0 -0
  36. {pydantic_ai_slim-0.7.2.dist-info → pydantic_ai_slim-0.7.3.dist-info}/licenses/LICENSE +0 -0

pydantic_ai/models/google.py

@@ -144,7 +144,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)

     def __init__(
         self,
@@ -168,9 +167,7 @@ class GoogleModel(Model):

         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -179,6 +176,16 @@ class GoogleModel(Model):
     def base_url(self) -> str:
         return self._provider.base_url

+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -195,7 +202,7 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
-    ) -> usage.Usage:
+    ) -> usage.RequestUsage:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         contents, generation_config = await self._build_content_and_config(
@@ -209,7 +216,7 @@ class GoogleModel(Model):
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.system != 'google-gla':
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
             config.update(
                 system_instruction=generation_config.get('system_instruction'),
@@ -238,9 +245,8 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior(  # pragma: no cover
                 'Total tokens missing from Gemini response', str(response)
             )
-        return usage.Usage(
-            request_tokens=response.total_tokens,
-            total_tokens=response.total_tokens,
+        return usage.RequestUsage(
+            input_tokens=response.total_tokens,
         )

     @asynccontextmanager
@@ -256,16 +262,6 @@ class GoogleModel(Model):
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
@@ -392,9 +388,12 @@ class GoogleModel(Model):
         if finish_reason:  # pragma: no branch
             vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
-        usage.requests = 1
         return _process_response_from_parts(
-            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            vendor_details=vendor_details,
         )

     async def _process_streamed_response(
@@ -590,7 +589,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
-    usage: usage.Usage,
+    usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
 ) -> ModelResponse:
@@ -627,7 +626,7 @@ def _process_response_from_parts(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
     )


@@ -647,31 +646,51 @@ def _tool_config(function_names: list[str]) -> ToolConfigDict:
     return ToolConfigDict(function_calling_config=function_calling_config)


-def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
+def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     metadata = response.usage_metadata
     if metadata is None:
-        return usage.Usage()  # pragma: no cover
-    metadata = metadata.model_dump(exclude_defaults=True)
-
+        return usage.RequestUsage()
     details: dict[str, int] = {}
-    if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+    if cached_content_token_count := metadata.cached_content_token_count:
+        details['cached_content_tokens'] = cached_content_token_count

-    if thoughts_token_count := metadata.get('thoughts_token_count'):
+    if thoughts_token_count := metadata.thoughts_token_count:
         details['thoughts_tokens'] = thoughts_token_count

-    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+    if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
         details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

-    for key, metadata_details in metadata.items():
-        if key.endswith('_details') and metadata_details:
-            suffix = key.removesuffix('_details')
-            for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
-
-    return usage.Usage(
-        request_tokens=metadata.get('prompt_token_count', 0),
-        response_tokens=metadata.get('candidates_token_count', 0),
-        total_tokens=metadata.get('total_token_count', 0),
+    input_audio_tokens = 0
+    output_audio_tokens = 0
+    cache_audio_read_tokens = 0
+    for prefix, metadata_details in [
+        ('prompt', metadata.prompt_tokens_details),
+        ('cache', metadata.cache_tokens_details),
+        ('candidates', metadata.candidates_tokens_details),
+        ('tool_use_prompt', metadata.tool_use_prompt_tokens_details),
+    ]:
+        assert getattr(metadata, f'{prefix}_tokens_details') is metadata_details
+        if not metadata_details:
+            continue
+        for detail in metadata_details:
+            if not detail.modality or not detail.token_count:  # pragma: no cover
+                continue
+            details[f'{detail.modality.lower()}_{prefix}_tokens'] = detail.token_count
+            if detail.modality != 'AUDIO':
+                continue
+            if metadata_details is metadata.prompt_tokens_details:
+                input_audio_tokens = detail.token_count
+            elif metadata_details is metadata.candidates_tokens_details:
+                output_audio_tokens = detail.token_count
+            elif metadata_details is metadata.cache_tokens_details:  # pragma: no branch
+                cache_audio_read_tokens = detail.token_count
+
+    return usage.RequestUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        cache_read_tokens=cached_content_token_count or 0,
+        input_audio_tokens=input_audio_tokens,
+        output_audio_tokens=output_audio_tokens,
+        cache_audio_read_tokens=cache_audio_read_tokens,
         details=details,
     )
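
The google.py hunks above replace the mutable `usage.Usage` accounting (`request_tokens`, `response_tokens`, `total_tokens`, `requests`) with `usage.RequestUsage`, reading token counts from the typed `usage_metadata` attributes instead of a `model_dump()` dict. A minimal sketch of the resulting object, not part of the diff; the field names come from the hunks above, while the numeric values are invented for illustration:

# Sketch only -- not part of the package diff; token values are invented.
from pydantic_ai import usage

request_usage = usage.RequestUsage(
    input_tokens=1200,         # 0.7.2 called this request_tokens
    output_tokens=350,         # 0.7.2 called this response_tokens
    cache_read_tokens=800,     # taken from cached_content_token_count
    input_audio_tokens=0,
    output_audio_tokens=0,
    cache_audio_read_tokens=0,
    details={'cached_content_tokens': 800, 'thoughts_tokens': 64},
)
# Note: the old total_tokens argument is no longer passed at these call sites,
# and the per-request `usage.requests = 1` bookkeeping from 0.7.2 is dropped.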

pydantic_ai/models/groq.py

@@ -118,7 +118,7 @@ class GroqModel(Model):
     client: AsyncGroq = field(repr=False)

     _model_name: GroqModelName = field(repr=False)
-    _system: str = field(default='groq', repr=False)
+    _provider: Provider[AsyncGroq] = field(repr=False)

     def __init__(
         self,
@@ -143,6 +143,7 @@ class GroqModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)
@@ -151,6 +152,16 @@ class GroqModel(Model):
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> GroqModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -162,7 +173,6 @@ class GroqModel(Model):
             messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -180,16 +190,6 @@ class GroqModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> GroqModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -285,7 +285,11 @@ class GroqModel(Model):
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -484,7 +488,7 @@ class GroqStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
+def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.RequestUsage:
     response_usage = None
     if isinstance(completion, chat.ChatCompletion):
         response_usage = completion.usage
@@ -492,10 +496,9 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
         response_usage = completion.x_groq.usage

     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()

-    return usage.Usage(
-        request_tokens=response_usage.prompt_tokens,
-        response_tokens=response_usage.completion_tokens,
-        total_tokens=response_usage.total_tokens,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )

pydantic_ai/models/huggingface.py

@@ -114,7 +114,7 @@ class HuggingFaceModel(Model):
     client: AsyncInferenceClient = field(repr=False)

     _model_name: str = field(repr=False)
-    _system: str = field(default='huggingface', repr=False)
+    _provider: Provider[AsyncInferenceClient] = field(repr=False)

     def __init__(
         self,
@@ -134,13 +134,23 @@ class HuggingFaceModel(Model):
             settings: Model-specific settings that will be used as defaults for this model.
         """
         self._model_name = model_name
-        self._provider = provider
         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -152,7 +162,6 @@ class HuggingFaceModel(Model):
             messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -169,16 +178,6 @@ class HuggingFaceModel(Model):
         )
         yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> HuggingFaceModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _completions_create(
         self,
@@ -272,7 +271,7 @@ class HuggingFaceModel(Model):
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            vendor_id=response.id,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -481,14 +480,12 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         return self._timestamp


-def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
+def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.RequestUsage:
     response_usage = response.usage
     if response_usage is None:
-        return usage.Usage()
+        return usage.RequestUsage()

-    return usage.Usage(
-        request_tokens=response_usage.prompt_tokens,
-        response_tokens=response_usage.completion_tokens,
-        total_tokens=response_usage.total_tokens,
-        details=None,
+    return usage.RequestUsage(
+        input_tokens=response_usage.prompt_tokens,
+        output_tokens=response_usage.completion_tokens,
     )

pydantic_ai/models/instrumented.py

@@ -280,14 +280,14 @@ class InstrumentedModel(WrapperModel):
             'gen_ai.request.model': request_model,
             'gen_ai.response.model': response_model,
         }
-        if response.usage.request_tokens:  # pragma: no branch
+        if response.usage.input_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.request_tokens,
+                response.usage.input_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'input'},
             )
-        if response.usage.response_tokens:  # pragma: no branch
+        if response.usage.output_tokens:  # pragma: no branch
             self.instrumentation_settings.tokens_histogram.record(
-                response.usage.response_tokens,
+                response.usage.output_tokens,
                 {**metric_attributes, 'gen_ai.token.type': 'output'},
             )

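
The InstrumentedModel hunk only renames the usage fields being recorded; the recording pattern itself is unchanged. A standalone sketch of that pattern against the OpenTelemetry metrics API, not part of the diff; the meter and histogram names are illustrative:

# Sketch only -- not part of the package diff; meter/histogram names are illustrative.
from opentelemetry import metrics

tokens_histogram = metrics.get_meter('example-meter').create_histogram('gen_ai.client.token.usage')

def record_token_usage(request_usage, metric_attributes: dict) -> None:
    # Mirrors the InstrumentedModel change: one histogram point per token type,
    # now read from input_tokens/output_tokens instead of request_tokens/response_tokens.
    if request_usage.input_tokens:
        tokens_histogram.record(request_usage.input_tokens, {**metric_attributes, 'gen_ai.token.type': 'input'})
    if request_usage.output_tokens:
        tokens_histogram.record(request_usage.output_tokens, {**metric_attributes, 'gen_ai.token.type': 'output'})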

pydantic_ai/models/mcp_sampling.py

@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, cast

-from .. import _mcp, exceptions, usage
+from .. import _mcp, exceptions
 from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
@@ -63,7 +63,6 @@ class MCPSamplingModel(Model):
         if result.role == 'assistant':
             return ModelResponse(
                 parts=[_mcp.map_from_sampling_content(result.content)],
-                usage=usage.Usage(requests=1),
                 model_name=result.model,
             )
         else:

pydantic_ai/models/mistral.py

@@ -40,7 +40,7 @@ from ..profiles import ModelProfileSpec
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from ..usage import Usage
+from ..usage import RequestUsage
 from . import (
     Model,
     ModelRequestParameters,
@@ -120,7 +120,7 @@ class MistralModel(Model):
     json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

     _model_name: MistralModelName = field(repr=False)
-    _system: str = field(default='mistral_ai', repr=False)
+    _provider: Provider[Mistral] = field(repr=False)

     def __init__(
         self,
@@ -147,13 +147,24 @@ class MistralModel(Model):

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

     @property
     def base_url(self) -> str:
-        return self.client.sdk_configuration.get_server_details()[0]
+        return self._provider.base_url
+
+    @property
+    def model_name(self) -> MistralModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name

     async def request(
         self,
@@ -167,7 +178,6 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         model_response = self._process_response(response)
-        model_response.usage.requests = 1
         return model_response

     @asynccontextmanager
@@ -186,16 +196,6 @@ class MistralModel(Model):
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> MistralModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -348,7 +348,11 @@ class MistralModel(Model):
                 parts.append(tool)

         return ModelResponse(
-            parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+            parts,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            provider_request_id=response.id,
         )

     async def _process_streamed_response(
@@ -699,17 +703,15 @@ SIMPLE_JSON_TYPE_MAPPING = {
 }


-def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
+def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> RequestUsage:
     """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
     if response.usage:
-        return Usage(
-            request_tokens=response.usage.prompt_tokens,
-            response_tokens=response.usage.completion_tokens,
-            total_tokens=response.usage.total_tokens,
-            details=None,
+        return RequestUsage(
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
         )
     else:
-        return Usage()  # pragma: no cover
+        return RequestUsage()  # pragma: no cover


 def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
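
A pattern repeated across the adapters above is the rename of the ModelResponse keyword arguments `vendor_id` and `vendor_details` to `provider_request_id` and `provider_details`. A minimal sketch of constructing a response with the new names, not part of the diff; the part content and literal values are invented:

# Sketch only -- not part of the package diff; values are invented.
from pydantic_ai import usage
from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(
    parts=[TextPart(content='hello world')],
    model_name='example-model',
    usage=usage.RequestUsage(input_tokens=12, output_tokens=3),
    provider_request_id='resp_123',              # 0.7.2: vendor_id
    provider_details={'finish_reason': 'stop'},  # 0.7.2: vendor_details
)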