pydantic-ai-slim 0.8.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This version of pydantic-ai-slim was flagged as potentially problematic.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +84 -17
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +70 -17
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +6 -3
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +156 -44
  31. pydantic_ai/models/__init__.py +20 -7
  32. pydantic_ai/models/anthropic.py +10 -17
  33. pydantic_ai/models/bedrock.py +55 -57
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +13 -14
  38. pydantic_ai/models/google.py +19 -5
  39. pydantic_ai/models/groq.py +127 -39
  40. pydantic_ai/models/huggingface.py +5 -5
  41. pydantic_ai/models/instrumented.py +49 -21
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +8 -8
  44. pydantic_ai/models/openai.py +37 -42
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +173 -52
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.0.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.0.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py

@@ -5,10 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast, overload
+from typing import Any, Literal, cast, overload
 
+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never
 
+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -88,7 +91,7 @@ PreviewGroqModelNames = Literal[
 ]
 """Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""
 
-GroqModelName = Union[str, ProductionGroqModelNames, PreviewGroqModelNames]
+GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
 """Possible Groq model names.
 
 Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
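Note: this `Union`-to-`|` rewrite recurs throughout the release (the huggingface.py and mistral.py hunks below make the same change). The two spellings are interchangeable at runtime; a minimal sketch:

    from typing import Union

    OldAlias = Union[str, int]  # pre-1.0 spelling
    NewAlias = str | int  # 1.0 spelling (PEP 604; needs Python 3.10+ when evaluated at runtime)

    assert OldAlias == NewAlias  # the two forms compare equal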
@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-        response = await self._completions_create(
-            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
-        )
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response
 
@@ -228,6 +246,18 @@
 
         groq_messages = self._map_messages(messages)
 
+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
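Together, the two hunks above wire pydantic-ai's structured output modes into Groq requests: `native` mode sends a `json_schema` response format (built by `_map_json_schema` below), while `prompted` mode with no tools sends `{'type': 'json_object'}`. A hedged end-to-end sketch; the `Agent`/`NativeOutput` usage is pydantic-ai's public API rather than something shown in this diff:

    from pydantic import BaseModel

    from pydantic_ai import Agent, NativeOutput


    class CityLocation(BaseModel):
        city: str
        country: str


    # NativeOutput selects output_mode='native', which the hunk above maps to a
    # Groq response_format of {'type': 'json_schema', ...}.
    agent = Agent('groq:llama-3.3-70b-versatile', output_type=NativeOutput(CityLocation))
    result = agent.run_sync('Where were the 2012 Summer Olympics held?')
    print(result.output)
    #> city='London' country='United Kingdom'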
@@ -285,11 +316,11 @@
             for c in choice.message.tool_calls:
                 items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            provider_request_id=response.id,
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -347,7 +378,7 @@
             elif isinstance(item, ThinkingPart):
                 # Skip thinking parts when mapping to Groq messages
                 continue
-            elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from groq
                 pass
             else:
@@ -385,6 +416,19 @@
             },
         }
 
+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
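For a concrete sense of the payload `_map_json_schema` builds, a sketch with an invented output definition (the field values are hypothetical; `OutputObjectDefinition` is the class imported at the top of this diff):

    from pydantic_ai._output import OutputObjectDefinition

    output_object = OutputObjectDefinition(
        name='CityLocation',
        description='A city and its country.',
        json_schema={
            'type': 'object',
            'properties': {'city': {'type': 'string'}, 'country': {'type': 'string'}},
            'required': ['city', 'country'],
        },
    )

    # _map_json_schema(output_object) then returns roughly:
    # {
    #     'type': 'json_schema',
    #     'json_schema': {
    #         'name': 'CityLocation',
    #         'schema': {...},  # the json_schema above, passed through unchanged
    #         'strict': None,   # o.strict, left unset here
    #         'description': 'A city and its country.',
    #     },
    # }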
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        async for chunk in self._response:
-            self._usage += _map_usage(chunk)
-
-            try:
-                choice = chunk.choices[0]
-            except IndexError:
-                continue
-
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    yield maybe_event
-
-            # Handle the tool calls
-            for dtc in choice.delta.tool_calls or []:
-                maybe_event = self._parts_manager.handle_tool_call_delta(
-                    vendor_part_id=dtc.index,
-                    tool_name=dtc.function and dtc.function.name,
-                    args=dtc.function and dtc.function.arguments,
-                    tool_call_id=dtc.id,
-                )
-                if maybe_event is not None:
-                    yield maybe_event
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover
 
     @property
     def model_name(self) -> GroqModelName:
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
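The example payload in the comment round-trips through these models as-is: the `Json[...]` annotation is what parses the embedded `failed_generation` JSON string into structured data. A minimal check, with the body abridged from the comment above:

    body = {
        'error': {
            'message': 'Tool call validation failed: ...',  # abridged
            'type': 'invalid_request_error',
            'code': 'tool_use_failed',
            'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
        }
    }

    error = _GroqToolUseFailedError.model_validate(body)
    assert error.error.failed_generation.name == 'get_something_by_name'
    assert error.error.failed_generation.arguments == {'foo': 'bar'}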
pydantic_ai/models/huggingface.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Literal, Union, cast, overload
+from typing import Any, Literal, cast, overload
 
 from typing_extensions import assert_never
 
@@ -88,7 +88,7 @@ LatestHuggingFaceModelNames = Literal[
 """Latest Hugging Face models."""
 
 
-HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
+HuggingFaceModelName = str | LatestHuggingFaceModelNames
 """Possible Hugging Face model names.
 
 You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
@@ -267,11 +267,11 @@ class HuggingFaceModel(Model):
             for c in tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
         return ModelResponse(
-            items,
+            parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            provider_request_id=response.id,
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -320,7 +320,7 @@
                 # please open an issue. The below code is the code to send thinking to the provider.
                 # texts.append(f'<think>\n{item.content}\n</think>')
                 pass
-            elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from huggingface
                 pass
             else:
pydantic_ai/models/instrumented.py

@@ -2,10 +2,11 @@ from __future__ import annotations
 
 import itertools
 import json
-from collections.abc import AsyncIterator, Iterator, Mapping
+import warnings
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal, cast
+from typing import Any, Literal, cast
 from urllib.parse import urlparse
 
 from opentelemetry._events import (
@@ -93,36 +94,41 @@ class InstrumentationSettings:
     def __init__(
         self,
         *,
-        event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
         meter_provider: MeterProvider | None = None,
-        event_logger_provider: EventLoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
-        version: Literal[1, 2] = 1,
+        version: Literal[1, 2] = 2,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
+        event_logger_provider: EventLoggerProvider | None = None,
     ):
         """Create instrumentation options.
 
         Args:
-            event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
-                If `'logs'`, events are emitted as OpenTelemetry log-based events.
             tracer_provider: The OpenTelemetry tracer provider to use.
                 If not provided, the global tracer provider is used.
                 Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
             meter_provider: The OpenTelemetry meter provider to use.
                 If not provided, the global meter provider is used.
                 Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
-            event_logger_provider: The OpenTelemetry event logger provider to use.
-                If not provided, the global event logger provider is used.
-                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
-                This is only used if `event_mode='logs'`.
             include_binary_content: Whether to include binary content in the instrumentation events.
             include_content: Whether to include prompts, completions, and tool call arguments and responses
                 in the instrumentation events.
-            version: Version of the data format.
-                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
-                Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
-                Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
+            version: Version of the data format. This is unrelated to the Pydantic AI package version.
+                Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
+                and will be removed in a future release.
+                The parameters `event_mode` and `event_logger_provider` are only relevant for version 1.
+                Version 2 uses the newer OpenTelemetry GenAI spec and stores messages in the following attributes:
+                - `gen_ai.system_instructions` for instructions passed to the agent.
+                - `gen_ai.input.messages` and `gen_ai.output.messages` on model request spans.
+                - `pydantic_ai.all_messages` on agent run spans.
+            event_mode: The mode for emitting events in version 1.
+                If `'attributes'`, events are attached to the span as attributes.
+                If `'logs'`, events are emitted as OpenTelemetry log-based events.
+            event_logger_provider: The OpenTelemetry event logger provider to use.
+                If not provided, the global event logger provider is used.
+                Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
+                This is only used if `event_mode='logs'` and `version=1`.
         """
         from pydantic_ai import __version__
 
@@ -136,6 +142,14 @@
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
+
+        if event_mode == 'logs' and version != 1:
+            warnings.warn(
+                'event_mode is only relevant for version=1 which is deprecated and will be removed in a future release.',
+                stacklevel=2,
+            )
+            version = 1
+
         self.version = version
 
         # As specified in the OpenTelemetry GenAI metrics spec:
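A sketch of the resulting behavior, assuming only what the two hunks above show:

    from pydantic_ai.models.instrumented import InstrumentationSettings

    settings = InstrumentationSettings()
    assert settings.version == 2  # the new default per the signature change

    # Requesting log-based events while leaving version=2 triggers the warning
    # above and falls back to the legacy format:
    legacy = InstrumentationSettings(event_mode='logs')
    assert legacy.version == 1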
@@ -236,27 +250,36 @@
         if response.provider_details and 'finish_reason' in response.provider_details:
             output_message['finish_reason'] = response.provider_details['finish_reason']
         instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
+        system_instructions_attributes = self.system_instructions_attributes(instructions)
         attributes = {
             'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
             'gen_ai.output.messages': json.dumps([output_message]),
+            **system_instructions_attributes,
             'logfire.json_schema': json.dumps(
                 {
                     'type': 'object',
                     'properties': {
                         'gen_ai.input.messages': {'type': 'array'},
                         'gen_ai.output.messages': {'type': 'array'},
-                        **({'gen_ai.system_instructions': {'type': 'array'}} if instructions else {}),
+                        **(
+                            {'gen_ai.system_instructions': {'type': 'array'}}
+                            if system_instructions_attributes
+                            else {}
+                        ),
                         'model_request_parameters': {'type': 'object'},
                     },
                 }
             ),
         }
-        if instructions is not None:
-            attributes['gen_ai.system_instructions'] = json.dumps(
-                [_otel_messages.TextPart(type='text', content=instructions)]
-            )
         span.set_attributes(attributes)
 
+    def system_instructions_attributes(self, instructions: str | None) -> dict[str, str]:
+        if instructions and self.include_content:
+            return {
+                'gen_ai.system_instructions': json.dumps([_otel_messages.TextPart(type='text', content=instructions)]),
+            }
+        return {}
+
     def _emit_events(self, span: Span, events: list[Event]) -> None:
         if self.event_mode == 'logs':
             for event in events:
@@ -357,7 +380,7 @@ class InstrumentedModel(WrapperModel):
 
         if model_settings:
             for key in MODEL_SETTING_ATTRIBUTES:
-                if isinstance(value := model_settings.get(key), (float, int)):
+                if isinstance(value := model_settings.get(key), float | int):
                     attributes[f'gen_ai.request.{key}'] = value
 
         record_metrics: Callable[[], None] | None = None
@@ -397,10 +420,15 @@
                 return
 
             self.instrumentation_settings.handle_messages(messages, response, system, span)
+            try:
+                cost_attributes = {'operation.cost': float(response.cost().total_price)}
+            except LookupError:
+                cost_attributes = {}
             span.set_attributes(
                 {
                     **response.usage.opentelemetry_attributes(),
                     'gen_ai.response.model': response_model,
+                    **cost_attributes,
                 }
             )
             span.update_name(f'{operation} {request_model}')
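The `LookupError` guard reflects that `response.cost()` depends on pricing data being available for the response's model; when the lookup fails, the span simply omits the cost attribute rather than erroring. An annotated restatement (`response` is the response handled by this method; that `total_price` is not already a float is an inference from the `float()` cast):

    try:
        # Convert the computed total price into a plain float for the OTel attribute.
        cost_attributes = {'operation.cost': float(response.cost().total_price)}
    except LookupError:
        # No pricing data for this model: record no cost instead of failing the span.
        cost_attributes = {}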
pydantic_ai/models/mcp_sampling.py

@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
 
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import KW_ONLY, dataclass
 from typing import TYPE_CHECKING, Any, cast
 
 from .. import _mcp, exceptions
@@ -36,6 +36,8 @@ class MCPSamplingModel(Model):
     session: ServerSession
     """The MCP server session to use for sampling."""
 
+    _: KW_ONLY
+
     default_max_tokens: int = 16_384
     """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].
 
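`KW_ONLY` is a `dataclasses` sentinel (Python 3.10+): every field declared after the `_: KW_ONLY` pseudo-field becomes keyword-only, so `default_max_tokens` can no longer be passed positionally. A minimal sketch with a hypothetical dataclass:

    from dataclasses import KW_ONLY, dataclass


    @dataclass
    class Sampler:
        session: str  # still accepted positionally
        _: KW_ONLY  # everything below must be passed by keyword
        default_max_tokens: int = 16_384


    Sampler('sess', default_max_tokens=1024)  # ok
    # Sampler('sess', 1024)  # TypeError: takes 2 positional arguments but 3 were given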
pydantic_ai/models/mistral.py

@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast
+from typing import Any, Literal, cast
 
 import pydantic_core
 from httpx import Timeout
@@ -79,7 +79,7 @@ try:
     from mistralai.models.usermessage import UserMessage as MistralUserMessage
     from mistralai.types.basemodel import Unset as MistralUnset
     from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
-except ImportError as e:  # pragma: no cover
+except ImportError as e:
     raise ImportError(
         'Please install `mistral` to use the Mistral model, '
         'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -90,7 +90,7 @@ LatestMistralModelNames = Literal[
 ]
 """Latest Mistral models."""
 
-MistralModelName = Union[str, LatestMistralModelNames]
+MistralModelName = str | LatestMistralModelNames
 """Possible Mistral model names.
 
 Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -117,7 +117,7 @@ class MistralModel(Model):
     """
 
     client: Mistral = field(repr=False)
-    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+    json_mode_schema_prompt: str
 
     _model_name: MistralModelName = field(repr=False)
     _provider: Provider[Mistral] = field(repr=False)
@@ -348,11 +348,11 @@
             parts.append(tool)
 
         return ModelResponse(
-            parts,
+            parts=parts,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            provider_request_id=response.id,
+            provider_response_id=response.id,
             provider_name=self._provider.name,
         )
 
@@ -515,7 +515,7 @@
                 pass
             elif isinstance(part, ToolCallPart):
                 tool_calls.append(self._map_tool_call(part))
-            elif isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                 # This is currently never returned from mistral
                 pass
             else:
@@ -576,7 +576,7 @@
         return MistralUserMessage(content=content)
 
 
-MistralToolCallId = Union[str, None]
+MistralToolCallId = str | None
 
 
 @dataclass