pydantic-ai-slim 0.7.4 → 0.7.6 (py3-none-any.whl)

This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (45)
  1. pydantic_ai/_otel_messages.py +67 -0
  2. pydantic_ai/agent/__init__.py +11 -4
  3. pydantic_ai/builtin_tools.py +1 -0
  4. pydantic_ai/durable_exec/temporal/_model.py +4 -0
  5. pydantic_ai/messages.py +109 -18
  6. pydantic_ai/models/__init__.py +27 -9
  7. pydantic_ai/models/anthropic.py +20 -8
  8. pydantic_ai/models/bedrock.py +16 -10
  9. pydantic_ai/models/cohere.py +3 -1
  10. pydantic_ai/models/function.py +5 -0
  11. pydantic_ai/models/gemini.py +8 -1
  12. pydantic_ai/models/google.py +21 -4
  13. pydantic_ai/models/groq.py +8 -0
  14. pydantic_ai/models/huggingface.py +8 -0
  15. pydantic_ai/models/instrumented.py +103 -42
  16. pydantic_ai/models/mistral.py +8 -0
  17. pydantic_ai/models/openai.py +80 -36
  18. pydantic_ai/models/test.py +7 -0
  19. pydantic_ai/profiles/__init__.py +1 -1
  20. pydantic_ai/profiles/harmony.py +13 -0
  21. pydantic_ai/profiles/openai.py +6 -1
  22. pydantic_ai/profiles/qwen.py +8 -0
  23. pydantic_ai/providers/__init__.py +5 -1
  24. pydantic_ai/providers/anthropic.py +11 -8
  25. pydantic_ai/providers/azure.py +1 -1
  26. pydantic_ai/providers/cerebras.py +96 -0
  27. pydantic_ai/providers/cohere.py +2 -2
  28. pydantic_ai/providers/deepseek.py +4 -4
  29. pydantic_ai/providers/fireworks.py +3 -3
  30. pydantic_ai/providers/github.py +4 -4
  31. pydantic_ai/providers/grok.py +3 -3
  32. pydantic_ai/providers/groq.py +3 -3
  33. pydantic_ai/providers/heroku.py +3 -3
  34. pydantic_ai/providers/mistral.py +3 -3
  35. pydantic_ai/providers/moonshotai.py +3 -6
  36. pydantic_ai/providers/ollama.py +1 -1
  37. pydantic_ai/providers/openrouter.py +4 -4
  38. pydantic_ai/providers/together.py +3 -3
  39. pydantic_ai/providers/vercel.py +4 -4
  40. pydantic_ai/retries.py +154 -42
  41. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/METADATA +4 -4
  42. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/RECORD +45 -42
  43. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/WHEEL +0 -0
  44. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/entry_points.txt +0 -0
  45. {pydantic_ai_slim-0.7.4.dist-info → pydantic_ai_slim-0.7.6.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/google.py

@@ -395,6 +395,7 @@ class GoogleModel(Model):
         return _process_response_from_parts(
             parts,
             response.model_version or self._model_name,
+            self._provider.name,
             usage,
             vendor_id=vendor_id,
             vendor_details=vendor_details,
@@ -414,6 +415,7 @@ class GoogleModel(Model):
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
+            _provider_name=self._provider.name,
         )
 
     async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
@@ -523,6 +525,7 @@ class GeminiStreamedResponse(StreamedResponse):
     _model_name: GoogleModelName
     _response: AsyncIterator[GenerateContentResponse]
     _timestamp: datetime
+    _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
@@ -531,7 +534,10 @@ class GeminiStreamedResponse(StreamedResponse):
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
             if candidate.content is None or candidate.content.parts is None:
-                if candidate.finish_reason == 'SAFETY':  # pragma: no cover
+                if candidate.finish_reason == 'STOP':  # pragma: no cover
+                    # Normal completion - skip this chunk
+                    continue
+                elif candidate.finish_reason == 'SAFETY':  # pragma: no cover
                     raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
                 else:  # pragma: no cover
                     raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
@@ -561,6 +567,11 @@ class GeminiStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
@@ -596,6 +607,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
 def _process_response_from_parts(
     parts: list[Part],
     model_name: GoogleModelName,
+    provider_name: str,
     usage: usage.RequestUsage,
     vendor_id: str | None,
     vendor_details: dict[str, Any] | None = None,
@@ -633,7 +645,12 @@ def _process_response_from_parts(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
+        parts=items,
+        model_name=model_name,
+        usage=usage,
+        provider_request_id=vendor_id,
+        provider_details=vendor_details,
+        provider_name=provider_name,
     )
 
 
@@ -661,7 +678,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
     if cached_content_token_count := metadata.cached_content_token_count:
         details['cached_content_tokens'] = cached_content_token_count
 
-    if thoughts_token_count := metadata.thoughts_token_count:
+    if thoughts_token_count := (metadata.thoughts_token_count or 0):
         details['thoughts_tokens'] = thoughts_token_count
 
     if tool_use_prompt_token_count := metadata.tool_use_prompt_token_count:
@@ -694,7 +711,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.RequestUsage:
 
     return usage.RequestUsage(
         input_tokens=metadata.prompt_token_count or 0,
-        output_tokens=metadata.candidates_token_count or 0,
+        output_tokens=(metadata.candidates_token_count or 0) + thoughts_token_count,
         cache_read_tokens=cached_content_token_count or 0,
         input_audio_tokens=input_audio_tokens,
         output_audio_tokens=output_audio_tokens,
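Net effect of the two usage hunks above: thoughts_token_count now falls back to 0 instead of staying None, and Gemini "thinking" tokens are counted in output_tokens. A minimal standalone sketch of the new arithmetic (hypothetical helper name, not part of the diff):

# Hypothetical sketch of the updated accounting in _metadata_as_usage.
def count_output_tokens(candidates_token_count: int | None, thoughts_token_count: int | None) -> int:
    thoughts = thoughts_token_count or 0  # previously left unset when the model reported no thinking tokens
    return (candidates_token_count or 0) + thoughts

assert count_output_tokens(100, 40) == 140  # thinking tokens now included in output_tokens
assert count_output_tokens(100, None) == 100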
pydantic_ai/models/groq.py

@@ -290,6 +290,7 @@ class GroqModel(Model):
             model_name=response.model,
             timestamp=timestamp,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )
 
     async def _process_streamed_response(
@@ -309,6 +310,7 @@ class GroqModel(Model):
             _model_name=self._model_name,
             _model_profile=self.profile,
             _timestamp=number_to_datetime(first_chunk.created),
+            _provider_name=self._provider.name,
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -444,6 +446,7 @@ class GroqStreamedResponse(StreamedResponse):
     _model_profile: ModelProfile
     _response: AsyncIterable[chat.ChatCompletionChunk]
     _timestamp: datetime
+    _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
@@ -482,6 +485,11 @@ class GroqStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai/models/huggingface.py

@@ -272,6 +272,7 @@ class HuggingFaceModel(Model):
             model_name=response.model,
             timestamp=timestamp,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )
 
     async def _process_streamed_response(
@@ -291,6 +292,7 @@ class HuggingFaceModel(Model):
             _model_profile=self.profile,
             _response=peekable_response,
             _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+            _provider_name=self._provider.name,
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
@@ -437,6 +439,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
     _model_profile: ModelProfile
     _response: AsyncIterable[ChatCompletionStreamOutput]
     _timestamp: datetime
+    _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
@@ -474,6 +477,11 @@ class HuggingFaceStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
pydantic_ai/models/instrumented.py

@@ -1,10 +1,11 @@
 from __future__ import annotations
 
+import itertools
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, cast
 from urllib.parse import urlparse
 
 from opentelemetry._events import (
@@ -18,8 +19,14 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter
 
+from .. import _otel_messages
 from .._run_context import RunContext
-from ..messages import ModelMessage, ModelRequest, ModelResponse
+from ..messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    SystemPromptPart,
+)
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel
@@ -80,6 +87,8 @@ class InstrumentationSettings:
     event_logger: EventLogger = field(repr=False)
     event_mode: Literal['attributes', 'logs'] = 'attributes'
     include_binary_content: bool = True
+    include_content: bool = True
+    version: Literal[1, 2] = 1
 
     def __init__(
         self,
@@ -90,6 +99,7 @@
         event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
+        version: Literal[1, 2] = 1,
     ):
         """Create instrumentation options.
 
@@ -109,6 +119,10 @@
             include_binary_content: Whether to include binary content in the instrumentation events.
             include_content: Whether to include prompts, completions, and tool call arguments and responses
                 in the instrumentation events.
+            version: Version of the data format.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
+                Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
+                Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
         """
         from pydantic_ai import __version__
 
@@ -122,6 +136,7 @@
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
+        self.version = version
 
         # As specified in the OpenTelemetry GenAI metrics spec:
         # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
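A usage sketch (not part of the diff) of the new version option; it assumes the existing public InstrumentationSettings and Agent.instrument API is unchanged:

# Illustrative only: opts in to the experimental attribute-based format described in the docstring above.
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

settings = InstrumentationSettings(version=2)
agent = Agent('openai:gpt-4o', instrument=settings)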
@@ -179,6 +194,90 @@ class InstrumentationSettings:
             event.body = InstrumentedModel.serialize_any(event.body)
         return events
 
+    def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_messages.ChatMessage]:
+        result: list[_otel_messages.ChatMessage] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)):
+                    message_parts: list[_otel_messages.MessagePart] = []
+                    for part in group:
+                        if hasattr(part, 'otel_message_parts'):
+                            message_parts.extend(part.otel_message_parts(self))
+                    result.append(
+                        _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
+                    )
+            elif isinstance(message, ModelResponse):  # pragma: no branch
+                result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self)))
+        return result
+
+    def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
+        if self.version == 1:
+            events = self.messages_to_otel_events(input_messages)
+            for event in self.messages_to_otel_events([response]):
+                events.append(
+                    Event(
+                        'gen_ai.choice',
+                        body={
+                            'index': 0,
+                            'message': event.body,
+                        },
+                    )
+                )
+            for event in events:
+                event.attributes = {
+                    GEN_AI_SYSTEM_ATTRIBUTE: system,
+                    **(event.attributes or {}),
+                }
+            self._emit_events(span, events)
+        else:
+            output_messages = self.messages_to_otel_messages([response])
+            assert len(output_messages) == 1
+            output_message = cast(_otel_messages.OutputMessage, output_messages[0])
+            if response.provider_details and 'finish_reason' in response.provider_details:
+                output_message['finish_reason'] = response.provider_details['finish_reason']
+            instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+            attributes = {
+                'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
+                'gen_ai.output.messages': json.dumps([output_message]),
+                'logfire.json_schema': json.dumps(
+                    {
+                        'type': 'object',
+                        'properties': {
+                            'gen_ai.input.messages': {'type': 'array'},
+                            'gen_ai.output.messages': {'type': 'array'},
+                            **({'gen_ai.system_instructions': {'type': 'array'}} if instructions else {}),
+                            'model_request_parameters': {'type': 'object'},
+                        },
+                    }
+                ),
+            }
+            if instructions is not None:
+                attributes['gen_ai.system_instructions'] = json.dumps(
+                    [_otel_messages.TextPart(type='text', content=instructions)]
+                )
+            span.set_attributes(attributes)
+
+    def _emit_events(self, span: Span, events: list[Event]) -> None:
+        if self.event_mode == 'logs':
+            for event in events:
+                self.event_logger.emit(event)
+        else:
+            attr_name = 'events'
+            span.set_attributes(
+                {
+                    attr_name: json.dumps([InstrumentedModel.event_to_dict(event) for event in events]),
+                    'logfire.json_schema': json.dumps(
+                        {
+                            'type': 'object',
+                            'properties': {
+                                attr_name: {'type': 'array'},
+                                'model_request_parameters': {'type': 'object'},
+                            },
+                        }
+                    ),
+                }
+            )
+
 
 GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system'
 GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
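For orientation, the version-2 branch of handle_messages above sets JSON-encoded message attributes on the span; the shape below is illustrative, with invented values and keys taken from the ChatMessage/TextPart structures used in that code:

# Illustrative attribute shape for version=2 (values invented).
span_attributes = {
    'gen_ai.input.messages': '[{"role": "user", "parts": [{"type": "text", "content": "What is the capital of France?"}]}]',
    'gen_ai.output.messages': '[{"role": "assistant", "parts": [{"type": "text", "content": "Paris."}], "finish_reason": "stop"}]',
}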
@@ -269,7 +368,7 @@ class InstrumentedModel(WrapperModel):
             # FallbackModel updates these span attributes.
             attributes.update(getattr(span, 'attributes', {}))
             request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
-            system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
+            system = cast(str, attributes[GEN_AI_SYSTEM_ATTRIBUTE])
 
             response_model = response.model_name or request_model
 
@@ -297,18 +396,7 @@ class InstrumentedModel(WrapperModel):
             if not span.is_recording():
                 return
 
-            events = self.instrumentation_settings.messages_to_otel_events(messages)
-            for event in self.instrumentation_settings.messages_to_otel_events([response]):
-                events.append(
-                    Event(
-                        'gen_ai.choice',
-                        body={
-                            # TODO finish_reason
-                            'index': 0,
-                            'message': event.body,
-                        },
-                    )
-                )
+            self.instrumentation_settings.handle_messages(messages, response, system, span)
             span.set_attributes(
                 {
                     **response.usage.opentelemetry_attributes(),
@@ -316,12 +404,6 @@
                 }
             )
             span.update_name(f'{operation} {request_model}')
-            for event in events:
-                event.attributes = {
-                    GEN_AI_SYSTEM_ATTRIBUTE: system,
-                    **(event.attributes or {}),
-                }
-            self._emit_events(span, events)
 
             yield finish
         finally:
@@ -330,27 +412,6 @@
             # to prevent them from being redundantly recorded in the span itself by logfire.
             record_metrics()
 
-    def _emit_events(self, span: Span, events: list[Event]) -> None:
-        if self.instrumentation_settings.event_mode == 'logs':
-            for event in events:
-                self.instrumentation_settings.event_logger.emit(event)
-        else:
-            attr_name = 'events'
-            span.set_attributes(
-                {
-                    attr_name: json.dumps([self.event_to_dict(event) for event in events]),
-                    'logfire.json_schema': json.dumps(
-                        {
-                            'type': 'object',
-                            'properties': {
-                                attr_name: {'type': 'array'},
-                                'model_request_parameters': {'type': 'object'},
-                            },
-                        }
-                    ),
-                }
-            )
-
     @staticmethod
     def model_attributes(model: Model):
         attributes: dict[str, AttributeValue] = {
pydantic_ai/models/mistral.py

@@ -353,6 +353,7 @@ class MistralModel(Model):
             model_name=response.model,
             timestamp=timestamp,
             provider_request_id=response.id,
+            provider_name=self._provider.name,
         )
 
     async def _process_streamed_response(
@@ -378,6 +379,7 @@ class MistralModel(Model):
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
+            _provider_name=self._provider.name,
         )
 
     @staticmethod
@@ -584,6 +586,7 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
+    _provider_name: str
 
     _delta_content: str = field(default='', init=False)
 
@@ -631,6 +634,11 @@ class MistralStreamedResponse(StreamedResponse):
         """Get the model name of the response."""
         return self._model_name
 
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
     @property
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""