pydantic-ai-slim 0.7.3__py3-none-any.whl → 0.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of pydantic-ai-slim has been flagged as potentially problematic.

@@ -13,7 +13,7 @@ from typing_extensions import assert_never
  from .. import UnexpectedModelBehavior, _utils, usage
  from .._output import OutputObjectDefinition
  from .._run_context import RunContext
- from ..builtin_tools import CodeExecutionTool, WebSearchTool
+ from ..builtin_tools import CodeExecutionTool, UrlContextTool, WebSearchTool
  from ..exceptions import UserError
  from ..messages import (
  BinaryContent,
@@ -72,6 +72,7 @@ try:
  ToolConfigDict,
  ToolDict,
  ToolListUnionDict,
+ UrlContextDict,
  )

  from ..providers.google import GoogleProvider
@@ -218,7 +219,7 @@ class GoogleModel(Model):
  )
  if self._provider.name != 'google-gla':
  # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
- config.update(
+ config.update( # pragma: lax no cover
  system_instruction=generation_config.get('system_instruction'),
  tools=cast(list[ToolDict], generation_config.get('tools')),
  # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
@@ -270,6 +271,8 @@ class GoogleModel(Model):
  for tool in model_request_parameters.builtin_tools:
  if isinstance(tool, WebSearchTool):
  tools.append(ToolDict(google_search=GoogleSearchDict()))
+ elif isinstance(tool, UrlContextTool):
+ tools.append(ToolDict(url_context=UrlContextDict()))
  elif isinstance(tool, CodeExecutionTool): # pragma: no branch
  tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
  else: # pragma: no cover
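The hunk above wires the new `UrlContextTool` built-in tool into the Gemini tool mapping. A minimal sketch of enabling it from user code, assuming the usual `Agent(..., builtin_tools=[...])` pattern (the model name here is illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import UrlContextTool, WebSearchTool

# UrlContextTool is the new built-in in 0.7.5; WebSearchTool was already supported.
agent = Agent(
    'google-gla:gemini-2.0-flash',  # illustrative model name
    builtin_tools=[WebSearchTool(), UrlContextTool()],
)
```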
@@ -374,23 +377,25 @@ class GoogleModel(Model):
  def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
  if not response.candidates or len(response.candidates) != 1:
  raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover
- if response.candidates[0].content is None or response.candidates[0].content.parts is None:
- if response.candidates[0].finish_reason == 'SAFETY':
+ candidate = response.candidates[0]
+ if candidate.content is None or candidate.content.parts is None:
+ if candidate.finish_reason == 'SAFETY':
  raise UnexpectedModelBehavior('Safety settings triggered', str(response))
  else:
  raise UnexpectedModelBehavior(
  'Content field missing from Gemini response', str(response)
  ) # pragma: no cover
- parts = response.candidates[0].content.parts or []
+ parts = candidate.content.parts or []
  vendor_id = response.response_id or None
  vendor_details: dict[str, Any] | None = None
- finish_reason = response.candidates[0].finish_reason
+ finish_reason = candidate.finish_reason
  if finish_reason: # pragma: no branch
  vendor_details = {'finish_reason': finish_reason.value}
  usage = _metadata_as_usage(response)
  return _process_response_from_parts(
  parts,
  response.model_version or self._model_name,
+ self._provider.name,
  usage,
  vendor_id=vendor_id,
  vendor_details=vendor_details,
@@ -410,6 +415,7 @@ class GoogleModel(Model):
  _model_name=self._model_name,
  _response=peekable_response,
  _timestamp=first_chunk.create_time or _utils.now_utc(),
+ _provider_name=self._provider.name,
  )

  async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
@@ -519,6 +525,7 @@ class GeminiStreamedResponse(StreamedResponse):
  _model_name: GoogleModelName
  _response: AsyncIterator[GenerateContentResponse]
  _timestamp: datetime
+ _provider_name: str

  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
  async for chunk in self._response:
@@ -526,10 +533,16 @@ class GeminiStreamedResponse(StreamedResponse):

  assert chunk.candidates is not None
  candidate = chunk.candidates[0]
- if candidate.content is None:
- raise UnexpectedModelBehavior('Streamed response has no content field') # pragma: no cover
- assert candidate.content.parts is not None
- for part in candidate.content.parts:
+ if candidate.content is None or candidate.content.parts is None:
+ if candidate.finish_reason == 'STOP': # pragma: no cover
+ # Normal completion - skip this chunk
+ continue
+ elif candidate.finish_reason == 'SAFETY': # pragma: no cover
+ raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
+ else: # pragma: no cover
+ raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+ parts = candidate.content.parts or []
+ for part in parts:
  if part.text is not None:
  if part.thought:
  yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
@@ -554,6 +567,11 @@ class GeminiStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -589,6 +607,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
  def _process_response_from_parts(
  parts: list[Part],
  model_name: GoogleModelName,
+ provider_name: str,
  usage: usage.RequestUsage,
  vendor_id: str | None,
  vendor_details: dict[str, Any] | None = None,
@@ -626,7 +645,12 @@ def _process_response_from_parts(
  f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
  )
  return ModelResponse(
- parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
+ parts=items,
+ model_name=model_name,
+ usage=usage,
+ provider_request_id=vendor_id,
+ provider_details=vendor_details,
+ provider_name=provider_name,
  )


@@ -654,7 +678,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
  if cached_content_token_count := metadata.cached_content_token_count:
  details['cached_content_tokens'] = cached_content_token_count

- if thoughts_token_count := metadata.thoughts_token_count:
+ if thoughts_token_count := (metadata.thoughts_token_count or 0):
  details['thoughts_tokens'] = thoughts_token_count

  if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
@@ -687,7 +711,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:

  return usage.RequestUsage(
  input_tokens=metadata.prompt_token_count or 0,
- output_tokens=metadata.candidates_token_count or 0,
+ output_tokens=(metadata.candidates_token_count or 0) + thoughts_token_count,
  cache_read_tokens=cached_content_token_count or 0,
  input_audio_tokens=input_audio_tokens,
  output_audio_tokens=output_audio_tokens,
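The usage change above means Gemini "thinking" tokens are now counted in `output_tokens` as well as being reported separately in `details`. A small worked example with made-up token counts:

```python
# Hypothetical Gemini usage metadata, to show the new accounting:
candidates_token_count = 120
thoughts_token_count = 40

# 0.7.3 reported output_tokens == 120; 0.7.5 includes the thinking tokens.
output_tokens = (candidates_token_count or 0) + thoughts_token_count  # 160
details = {'thoughts_tokens': thoughts_token_count}  # still broken out separately
```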
@@ -290,6 +290,7 @@ class GroqModel(Model):
  model_name=response.model,
  timestamp=timestamp,
  provider_request_id=response.id,
+ provider_name=self._provider.name,
  )

  async def _process_streamed_response(
@@ -309,6 +310,7 @@ class GroqModel(Model):
  _model_name=self._model_name,
  _model_profile=self.profile,
  _timestamp=number_to_datetime(first_chunk.created),
+ _provider_name=self._provider.name,
  )

  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -444,6 +446,7 @@ class GroqStreamedResponse(StreamedResponse):
  _model_profile: ModelProfile
  _response: AsyncIterable[chat.ChatCompletionChunk]
  _timestamp: datetime
+ _provider_name: str

  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
  async for chunk in self._response:
@@ -482,6 +485,11 @@ class GroqStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -272,6 +272,7 @@ class HuggingFaceModel(Model):
  model_name=response.model,
  timestamp=timestamp,
  provider_request_id=response.id,
+ provider_name=self._provider.name,
  )

  async def _process_streamed_response(
@@ -291,6 +292,7 @@ class HuggingFaceModel(Model):
  _model_profile=self.profile,
  _response=peekable_response,
  _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+ _provider_name=self._provider.name,
  )

  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
@@ -437,6 +439,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
  _model_profile: ModelProfile
  _response: AsyncIterable[ChatCompletionStreamOutput]
  _timestamp: datetime
+ _provider_name: str

  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
  async for chunk in self._response:
@@ -474,6 +477,11 @@ class HuggingFaceStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -1,10 +1,11 @@
  from __future__ import annotations

+ import itertools
  import json
  from collections.abc import AsyncIterator, Iterator, Mapping
  from contextlib import asynccontextmanager, contextmanager
  from dataclasses import dataclass, field
- from typing import Any, Callable, Literal
+ from typing import Any, Callable, Literal, cast
  from urllib.parse import urlparse

  from opentelemetry._events import (
@@ -18,8 +19,14 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provide
  from opentelemetry.util.types import AttributeValue
  from pydantic import TypeAdapter

+ from .. import _otel_messages
  from .._run_context import RunContext
- from ..messages import ModelMessage, ModelRequest, ModelResponse
+ from ..messages import (
+ ModelMessage,
+ ModelRequest,
+ ModelResponse,
+ SystemPromptPart,
+ )
  from ..settings import ModelSettings
  from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
  from .wrapper import WrapperModel
@@ -80,6 +87,8 @@ class InstrumentationSettings:
  event_logger: EventLogger = field(repr=False)
  event_mode: Literal['attributes', 'logs'] = 'attributes'
  include_binary_content: bool = True
+ include_content: bool = True
+ version: Literal[1, 2] = 1

  def __init__(
  self,
@@ -90,6 +99,7 @@ class InstrumentationSettings:
  event_logger_provider: EventLoggerProvider | None = None,
  include_binary_content: bool = True,
  include_content: bool = True,
+ version: Literal[1, 2] = 1,
  ):
  """Create instrumentation options.

@@ -109,6 +119,10 @@ class InstrumentationSettings:
  include_binary_content: Whether to include binary content in the instrumentation events.
  include_content: Whether to include prompts, completions, and tool call arguments and responses
  in the instrumentation events.
+ version: Version of the data format.
+ Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
+ Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
+ Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
  """
  from pydantic_ai import __version__

@@ -122,6 +136,7 @@ class InstrumentationSettings:
  self.event_mode = event_mode
  self.include_binary_content = include_binary_content
  self.include_content = include_content
+ self.version = version

  # As specified in the OpenTelemetry GenAI metrics spec:
  # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
@@ -179,6 +194,90 @@ class InstrumentationSettings:
  event.body = InstrumentedModel.serialize_any(event.body)
  return events

+ def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_messages.ChatMessage]:
+ result: list[_otel_messages.ChatMessage] = []
+ for message in messages:
+ if isinstance(message, ModelRequest):
+ for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)):
+ message_parts: list[_otel_messages.MessagePart] = []
+ for part in group:
+ if hasattr(part, 'otel_message_parts'):
+ message_parts.extend(part.otel_message_parts(self))
+ result.append(
+ _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
+ )
+ elif isinstance(message, ModelResponse): # pragma: no branch
+ result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self)))
+ return result
+
+ def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
+ if self.version == 1:
+ events = self.messages_to_otel_events(input_messages)
+ for event in self.messages_to_otel_events([response]):
+ events.append(
+ Event(
+ 'gen_ai.choice',
+ body={
+ 'index': 0,
+ 'message': event.body,
+ },
+ )
+ )
+ for event in events:
+ event.attributes = {
+ GEN_AI_SYSTEM_ATTRIBUTE: system,
+ **(event.attributes or {}),
+ }
+ self._emit_events(span, events)
+ else:
+ output_messages = self.messages_to_otel_messages([response])
+ assert len(output_messages) == 1
+ output_message = cast(_otel_messages.OutputMessage, output_messages[0])
+ if response.provider_details and 'finish_reason' in response.provider_details:
+ output_message['finish_reason'] = response.provider_details['finish_reason']
+ instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage]
+ attributes = {
+ 'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
+ 'gen_ai.output.messages': json.dumps([output_message]),
+ 'logfire.json_schema': json.dumps(
+ {
+ 'type': 'object',
+ 'properties': {
+ 'gen_ai.input.messages': {'type': 'array'},
+ 'gen_ai.output.messages': {'type': 'array'},
+ **({'gen_ai.system_instructions': {'type': 'array'}} if instructions else {}),
+ 'model_request_parameters': {'type': 'object'},
+ },
+ }
+ ),
+ }
+ if instructions is not None:
+ attributes['gen_ai.system_instructions'] = json.dumps(
+ [_otel_messages.TextPart(type='text', content=instructions)]
+ )
+ span.set_attributes(attributes)
+
+ def _emit_events(self, span: Span, events: list[Event]) -> None:
+ if self.event_mode == 'logs':
+ for event in events:
+ self.event_logger.emit(event)
+ else:
+ attr_name = 'events'
+ span.set_attributes(
+ {
+ attr_name: json.dumps([InstrumentedModel.event_to_dict(event) for event in events]),
+ 'logfire.json_schema': json.dumps(
+ {
+ 'type': 'object',
+ 'properties': {
+ attr_name: {'type': 'array'},
+ 'model_request_parameters': {'type': 'object'},
+ },
+ }
+ ),
+ }
+ )
+

  GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
  GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
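The new `handle_messages` above is what switches between the legacy event-based format and the new attribute-based one. A hedged sketch of opting into version 2 (the docstring earlier in this diff marks it as experimental; the import path follows the module being patched here, and the model name is illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

# version=2 writes chat history to the gen_ai.input.messages / gen_ai.output.messages
# span attributes instead of emitting per-message events.
settings = InstrumentationSettings(version=2)
agent = Agent('openai:gpt-4o', instrument=settings)  # illustrative model name
```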
@@ -269,7 +368,7 @@ class InstrumentedModel(WrapperModel):
  # FallbackModel updates these span attributes.
  attributes.update(getattr(span, 'attributes', {}))
  request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
- system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
+ system = cast(str, attributes[GEN_AI_SYSTEM_ATTRIBUTE])

  response_model = response.model_name or request_model

@@ -297,18 +396,7 @@ class InstrumentedModel(WrapperModel):
  if not span.is_recording():
  return

- events = self.instrumentation_settings.messages_to_otel_events(messages)
- for event in self.instrumentation_settings.messages_to_otel_events([response]):
- events.append(
- Event(
- 'gen_ai.choice',
- body={
- # TODO finish_reason
- 'index': 0,
- 'message': event.body,
- },
- )
- )
+ self.instrumentation_settings.handle_messages(messages, response, system, span)
  span.set_attributes(
  {
  **response.usage.opentelemetry_attributes(),
@@ -316,12 +404,6 @@ class InstrumentedModel(WrapperModel):
  }
  )
  span.update_name(f'{operation} {request_model}')
- for event in events:
- event.attributes = {
- GEN_AI_SYSTEM_ATTRIBUTE: system,
- **(event.attributes or {}),
- }
- self._emit_events(span, events)

  yield finish
  finally:
@@ -330,27 +412,6 @@ class InstrumentedModel(WrapperModel):
  # to prevent them from being redundantly recorded in the span itself by logfire.
  record_metrics()

- def _emit_events(self, span: Span, events: list[Event]) -> None:
- if self.instrumentation_settings.event_mode == 'logs':
- for event in events:
- self.instrumentation_settings.event_logger.emit(event)
- else:
- attr_name = 'events'
- span.set_attributes(
- {
- attr_name: json.dumps([self.event_to_dict(event) for event in events]),
- 'logfire.json_schema': json.dumps(
- {
- 'type': 'object',
- 'properties': {
- attr_name: {'type': 'array'},
- 'model_request_parameters': {'type': 'object'},
- },
- }
- ),
- }
- )
-
  @staticmethod
  def model_attributes(model: Model):
  attributes: dict[str, AttributeValue] = {
@@ -353,6 +353,7 @@ class MistralModel(Model):
  model_name=response.model,
  timestamp=timestamp,
  provider_request_id=response.id,
+ provider_name=self._provider.name,
  )

  async def _process_streamed_response(
@@ -378,6 +379,7 @@ class MistralModel(Model):
  _response=peekable_response,
  _model_name=self._model_name,
  _timestamp=timestamp,
+ _provider_name=self._provider.name,
  )

  @staticmethod
@@ -584,6 +586,7 @@ class MistralStreamedResponse(StreamedResponse):
  _model_name: MistralModelName
  _response: AsyncIterable[MistralCompletionEvent]
  _timestamp: datetime
+ _provider_name: str

  _delta_content: str = field(default='', init=False)

@@ -631,6 +634,11 @@ class MistralStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -500,6 +500,7 @@ class OpenAIModel(Model):
  timestamp=timestamp,
  provider_details=vendor_details,
  provider_request_id=response.id,
+ provider_name=self._provider.name,
  )

  async def _process_streamed_response(
@@ -519,6 +520,7 @@ class OpenAIModel(Model):
  _model_profile=self.profile,
  _response=peekable_response,
  _timestamp=number_to_datetime(first_chunk.created),
+ _provider_name=self._provider.name,
  )

  def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -571,6 +573,8 @@ class OpenAIModel(Model):
  # Note: model responses from this model should only have one text item, so the following
  # shouldn't merge multiple texts into one unless you switch models between runs:
  message_param['content'] = '\n\n'.join(texts)
+ else:
+ message_param['content'] = None
  if tool_calls:
  message_param['tool_calls'] = tool_calls
  openai_messages.append(message_param)
@@ -803,6 +807,7 @@ class OpenAIResponsesModel(Model):
  model_name=response.model,
  provider_request_id=response.id,
  timestamp=timestamp,
+ provider_name=self._provider.name,
  )

  async def _process_streamed_response(
@@ -822,6 +827,7 @@ class OpenAIResponsesModel(Model):
  _model_name=self._model_name,
  _response=peekable_response,
  _timestamp=number_to_datetime(first_chunk.response.created_at),
+ _provider_name=self._provider.name,
  )

  @overload
@@ -1137,6 +1143,7 @@ class OpenAIStreamedResponse(StreamedResponse):
  _model_profile: ModelProfile
  _response: AsyncIterable[ChatCompletionChunk]
  _timestamp: datetime
+ _provider_name: str

  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
  async for chunk in self._response:
@@ -1180,6 +1187,11 @@ class OpenAIStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -1193,6 +1205,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
  _model_name: OpenAIModelName
  _response: AsyncIterable[responses.ResponseStreamEvent]
  _timestamp: datetime
+ _provider_name: str

  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
  async for chunk in self._response:
@@ -1313,6 +1326,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -131,6 +131,7 @@ class TestModel(Model):
  _model_name=self._model_name,
  _structured_response=model_response,
  _messages=messages,
+ _provider_name=self._system,
  )

  @property
@@ -263,6 +264,7 @@ class TestStreamedResponse(StreamedResponse):
  _model_name: str
  _structured_response: ModelResponse
  _messages: InitVar[Iterable[ModelMessage]]
+ _provider_name: str
  _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

  def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -305,6 +307,11 @@ class TestStreamedResponse(StreamedResponse):
  """Get the model name of the response."""
  return self._model_name

+ @property
+ def provider_name(self) -> str:
+ """Get the provider name."""
+ return self._provider_name
+
  @property
  def timestamp(self) -> datetime:
  """Get the timestamp of the response."""
@@ -1,9 +1,10 @@
  from __future__ import annotations as _annotations

  import os
- from typing import overload
+ from typing import Union, overload

  import httpx
+ from typing_extensions import TypeAlias

  from pydantic_ai.exceptions import UserError
  from pydantic_ai.models import cached_async_http_client
@@ -12,15 +13,18 @@ from pydantic_ai.profiles.anthropic import anthropic_model_profile
  from pydantic_ai.providers import Provider

  try:
- from anthropic import AsyncAnthropic
- except ImportError as _import_error: # pragma: no cover
+ from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
+ except ImportError as _import_error:
  raise ImportError(
  'Please install the `anthropic` package to use the Anthropic provider, '
  'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
  ) from _import_error


- class AnthropicProvider(Provider[AsyncAnthropic]):
+ AsyncAnthropicClient: TypeAlias = Union[AsyncAnthropic, AsyncAnthropicBedrock]
+
+
+ class AnthropicProvider(Provider[AsyncAnthropicClient]):
  """Provider for Anthropic API."""

  @property
@@ -32,14 +36,14 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
  return str(self._client.base_url)

  @property
- def client(self) -> AsyncAnthropic:
+ def client(self) -> AsyncAnthropicClient:
  return self._client

  def model_profile(self, model_name: str) -> ModelProfile | None:
  return anthropic_model_profile(model_name)

  @overload
- def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...
+ def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...

  @overload
  def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
@@ -48,7 +52,7 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
  self,
  *,
  api_key: str | None = None,
- anthropic_client: AsyncAnthropic | None = None,
+ anthropic_client: AsyncAnthropicClient | None = None,
  http_client: httpx.AsyncClient | None = None,
  ) -> None:
  """Create a new Anthropic provider.
@@ -71,7 +75,6 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
  'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
  'to use the Anthropic provider.'
  )
-
  if http_client is not None:
  self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
  else:
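The provider changes above widen `AnthropicProvider` to accept an `AsyncAnthropicBedrock` client in addition to `AsyncAnthropic`. A hedged sketch of using it (the AWS region and Bedrock model ID are placeholders):

```python
from anthropic import AsyncAnthropicBedrock

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

# The Bedrock client reads AWS credentials from the environment;
# aws_region and the model ID below are placeholders.
bedrock_client = AsyncAnthropicBedrock(aws_region='us-east-1')
provider = AnthropicProvider(anthropic_client=bedrock_client)
model = AnthropicModel('us.anthropic.claude-sonnet-4-20250514-v1:0', provider=provider)
agent = Agent(model)
```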