pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (75)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_a2a.py +1 -1
  3. pydantic_ai/_agent_graph.py +323 -156
  4. pydantic_ai/_function_schema.py +5 -5
  5. pydantic_ai/_griffe.py +2 -1
  6. pydantic_ai/_otel_messages.py +2 -2
  7. pydantic_ai/_output.py +31 -35
  8. pydantic_ai/_parts_manager.py +7 -5
  9. pydantic_ai/_run_context.py +3 -1
  10. pydantic_ai/_system_prompt.py +2 -2
  11. pydantic_ai/_tool_manager.py +32 -28
  12. pydantic_ai/_utils.py +14 -26
  13. pydantic_ai/ag_ui.py +82 -51
  14. pydantic_ai/agent/__init__.py +70 -9
  15. pydantic_ai/agent/abstract.py +35 -4
  16. pydantic_ai/agent/wrapper.py +6 -0
  17. pydantic_ai/builtin_tools.py +2 -2
  18. pydantic_ai/common_tools/duckduckgo.py +4 -2
  19. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  20. pydantic_ai/durable_exec/temporal/_agent.py +93 -11
  21. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  22. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  23. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  24. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  25. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  26. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  27. pydantic_ai/exceptions.py +45 -2
  28. pydantic_ai/format_prompt.py +2 -2
  29. pydantic_ai/mcp.py +15 -27
  30. pydantic_ai/messages.py +149 -42
  31. pydantic_ai/models/__init__.py +6 -4
  32. pydantic_ai/models/anthropic.py +9 -16
  33. pydantic_ai/models/bedrock.py +50 -56
  34. pydantic_ai/models/cohere.py +3 -3
  35. pydantic_ai/models/fallback.py +2 -2
  36. pydantic_ai/models/function.py +25 -23
  37. pydantic_ai/models/gemini.py +12 -13
  38. pydantic_ai/models/google.py +18 -4
  39. pydantic_ai/models/groq.py +126 -38
  40. pydantic_ai/models/huggingface.py +4 -4
  41. pydantic_ai/models/instrumented.py +35 -16
  42. pydantic_ai/models/mcp_sampling.py +3 -1
  43. pydantic_ai/models/mistral.py +6 -6
  44. pydantic_ai/models/openai.py +35 -40
  45. pydantic_ai/models/test.py +24 -4
  46. pydantic_ai/output.py +27 -32
  47. pydantic_ai/profiles/__init__.py +3 -3
  48. pydantic_ai/profiles/groq.py +1 -1
  49. pydantic_ai/profiles/openai.py +25 -4
  50. pydantic_ai/providers/__init__.py +4 -0
  51. pydantic_ai/providers/anthropic.py +2 -3
  52. pydantic_ai/providers/bedrock.py +3 -2
  53. pydantic_ai/providers/google_vertex.py +2 -1
  54. pydantic_ai/providers/groq.py +21 -2
  55. pydantic_ai/providers/litellm.py +134 -0
  56. pydantic_ai/result.py +144 -41
  57. pydantic_ai/retries.py +52 -31
  58. pydantic_ai/run.py +12 -5
  59. pydantic_ai/tools.py +127 -23
  60. pydantic_ai/toolsets/__init__.py +4 -1
  61. pydantic_ai/toolsets/_dynamic.py +4 -4
  62. pydantic_ai/toolsets/abstract.py +18 -2
  63. pydantic_ai/toolsets/approval_required.py +32 -0
  64. pydantic_ai/toolsets/combined.py +7 -12
  65. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  66. pydantic_ai/toolsets/filtered.py +1 -1
  67. pydantic_ai/toolsets/function.py +58 -21
  68. pydantic_ai/toolsets/wrapper.py +2 -1
  69. pydantic_ai/usage.py +44 -8
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/METADATA +8 -9
  71. pydantic_ai_slim-1.0.0.dist-info/RECORD +121 -0
  72. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  73. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/WHEEL +0 -0
  74. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -5,10 +5,13 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import Any, Literal, Union, cast, overload
+ from typing import Any, Literal, cast, overload

+ from pydantic import BaseModel, Json, ValidationError
  from typing_extensions import assert_never

+ from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
  from .._run_context import RunContext
  from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
  )

  try:
- from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+ from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
  from groq.types import chat
  from groq.types.chat.chat_completion_content_part_image_param import ImageURL
  except ImportError as _import_error:
@@ -88,7 +91,7 @@ PreviewGroqModelNames = Literal[
  ]
  """Preview Groq models from <https://console.groq.com/docs/models#preview-models>."""

- GroqModelName = Union[str, ProductionGroqModelNames, PreviewGroqModelNames]
+ GroqModelName = str | ProductionGroqModelNames | PreviewGroqModelNames
  """Possible Groq model names.

  Since Groq supports a variety of models and the list changes frequencly, we explicitly list the named models as of 2025-03-31
@@ -169,9 +172,24 @@ class GroqModel(Model):
  model_request_parameters: ModelRequestParameters,
  ) -> ModelResponse:
  check_allow_model_requests()
- response = await self._completions_create(
- messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
- )
+ try:
+ response = await self._completions_create(
+ messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+ )
+ except ModelHTTPError as e:
+ if isinstance(e.body, dict): # pragma: no branch
+ # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+ # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+ try:
+ error = _GroqToolUseFailedError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType]
+ tool_call_part = ToolCallPart(
+ tool_name=error.error.failed_generation.name,
+ args=error.error.failed_generation.arguments,
+ )
+ return ModelResponse(parts=[tool_call_part])
+ except ValidationError:
+ pass
+ raise
  model_response = self._process_response(response)
  return model_response

@@ -228,6 +246,18 @@ class GroqModel(Model):

  groq_messages = self._map_messages(messages)

+ response_format: chat.completion_create_params.ResponseFormat | None = None
+ if model_request_parameters.output_mode == 'native':
+ output_object = model_request_parameters.output_object
+ assert output_object is not None
+ response_format = self._map_json_schema(output_object)
+ elif (
+ model_request_parameters.output_mode == 'prompted'
+ and not tools
+ and self.profile.supports_json_object_output
+ ): # pragma: no branch
+ response_format = {'type': 'json_object'}
+
  try:
  extra_headers = model_settings.get('extra_headers', {})
  extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@ class GroqModel(Model):
  tool_choice=tool_choice or NOT_GIVEN,
  stop=model_settings.get('stop_sequences', NOT_GIVEN),
  stream=stream,
+ response_format=response_format or NOT_GIVEN,
  max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
  temperature=model_settings.get('temperature', NOT_GIVEN),
  top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -285,7 +316,7 @@ class GroqModel(Model):
  for c in choice.message.tool_calls:
  items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
  return ModelResponse(
- items,
+ parts=items,
  usage=_map_usage(response),
  model_name=response.model,
  timestamp=timestamp,
@@ -347,7 +378,7 @@ class GroqModel(Model):
  elif isinstance(item, ThinkingPart):
  # Skip thinking parts when mapping to Groq messages
  continue
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
  # This is currently never returned from groq
  pass
  else:
@@ -385,6 +416,19 @@ class GroqModel(Model):
  },
  }

+ def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+ response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+ 'type': 'json_schema',
+ 'json_schema': {
+ 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+ 'schema': o.json_schema,
+ 'strict': o.strict,
+ },
+ }
+ if o.description: # pragma: no branch
+ response_format_param['json_schema']['description'] = o.description
+ return response_format_param
+
  @classmethod
  def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
  for part in message.parts:
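
The hunks above wire structured output into the Groq request: `_map_json_schema` builds a `json_schema` response format for native output, and prompted output falls back to `{'type': 'json_object'}` when no tools are registered. A minimal sketch of how this surfaces through the agent API, assuming a configured Groq API key; the model name and `CityInfo` schema are illustrative only, not taken from this diff:

    from pydantic import BaseModel

    from pydantic_ai import Agent
    from pydantic_ai.output import NativeOutput


    class CityInfo(BaseModel):
        city: str
        country: str


    # NativeOutput asks Groq for JSON matching the schema via the new
    # response_format={'type': 'json_schema', ...} plumbing added above.
    agent = Agent('groq:llama-3.3-70b-versatile', output_type=NativeOutput(CityInfo))
    result = agent.run_sync('What is the largest city in the UK?')
    print(result.output)  # e.g. CityInfo(city='London', country='United Kingdom')
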
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
  _provider_name: str

  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
- async for chunk in self._response:
- self._usage += _map_usage(chunk)
-
- try:
- choice = chunk.choices[0]
- except IndexError:
- continue
-
- # Handle the text part of the response
- content = choice.delta.content
- if content is not None:
- maybe_event = self._parts_manager.handle_text_delta(
- vendor_part_id='content',
- content=content,
- thinking_tags=self._model_profile.thinking_tags,
- ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
- )
- if maybe_event is not None: # pragma: no branch
- yield maybe_event
-
- # Handle the tool calls
- for dtc in choice.delta.tool_calls or []:
- maybe_event = self._parts_manager.handle_tool_call_delta(
- vendor_part_id=dtc.index,
- tool_name=dtc.function and dtc.function.name,
- args=dtc.function and dtc.function.arguments,
- tool_call_id=dtc.id,
- )
- if maybe_event is not None:
- yield maybe_event
+ try:
+ async for chunk in self._response:
+ self._usage += _map_usage(chunk)
+
+ try:
+ choice = chunk.choices[0]
+ except IndexError:
+ continue
+
+ # Handle the text part of the response
+ content = choice.delta.content
+ if content is not None:
+ maybe_event = self._parts_manager.handle_text_delta(
+ vendor_part_id='content',
+ content=content,
+ thinking_tags=self._model_profile.thinking_tags,
+ ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+ )
+ if maybe_event is not None: # pragma: no branch
+ yield maybe_event
+
+ # Handle the tool calls
+ for dtc in choice.delta.tool_calls or []:
+ maybe_event = self._parts_manager.handle_tool_call_delta(
+ vendor_part_id=dtc.index,
+ tool_name=dtc.function and dtc.function.name,
+ args=dtc.function and dtc.function.arguments,
+ tool_call_id=dtc.id,
+ )
+ if maybe_event is not None:
+ yield maybe_event
+ except APIError as e:
+ if isinstance(e.body, dict): # pragma: no branch
+ # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+ # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+ try:
+ error = _GroqToolUseFailedInnerError.model_validate(e.body) # pyright: ignore[reportUnknownMemberType]
+ yield self._parts_manager.handle_tool_call_part(
+ vendor_part_id='tool_use_failed',
+ tool_name=error.failed_generation.name,
+ args=error.failed_generation.arguments,
+ )
+ return
+ except ValidationError as e: # pragma: no cover
+ pass
+ raise # pragma: no cover

  @property
  def model_name(self) -> GroqModelName:
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
  input_tokens=response_usage.prompt_tokens,
  output_tokens=response_usage.completion_tokens,
  )
+
+
+ class _GroqToolUseFailedGeneration(BaseModel):
+ name: str
+ arguments: dict[str, Any]
+
+
+ class _GroqToolUseFailedInnerError(BaseModel):
+ message: str
+ type: Literal['invalid_request_error']
+ code: Literal['tool_use_failed']
+ failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+ class _GroqToolUseFailedError(BaseModel):
+ # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+ # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+ # Example payload from `exception.body`:
+ # {
+ # 'error': {
+ # 'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+ # 'type': 'invalid_request_error',
+ # 'code': 'tool_use_failed',
+ # 'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+ # }
+ # }
+
+ error: _GroqToolUseFailedInnerError
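
To make the retry path above concrete, here is a standalone sketch of the validation these private models perform on the `tool_use_failed` error body; the class names below are local stand-ins, and only the payload shape comes from the comment in the diff:

    from typing import Any, Literal

    from pydantic import BaseModel, Json


    class FailedGeneration(BaseModel):
        name: str
        arguments: dict[str, Any]


    class ToolUseFailedInnerError(BaseModel):
        message: str
        type: Literal['invalid_request_error']
        code: Literal['tool_use_failed']
        # Json[...] decodes the JSON string Groq puts in 'failed_generation'.
        failed_generation: Json[FailedGeneration]


    class ToolUseFailedError(BaseModel):
        error: ToolUseFailedInnerError


    body = {
        'error': {
            'message': 'Tool call validation failed: ...',
            'type': 'invalid_request_error',
            'code': 'tool_use_failed',
            'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
        }
    }

    parsed = ToolUseFailedError.model_validate(body)
    # The recovered call is re-emitted as a ToolCallPart so the model can be told to retry.
    print(parsed.error.failed_generation.name)       # get_something_by_name
    print(parsed.error.failed_generation.arguments)  # {'foo': 'bar'}
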
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
- from typing import Any, Literal, Union, cast, overload
+ from typing import Any, Literal, cast, overload

  from typing_extensions import assert_never

@@ -88,7 +88,7 @@ LatestHuggingFaceModelNames = Literal[
  """Latest Hugging Face models."""


- HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
+ HuggingFaceModelName = str | LatestHuggingFaceModelNames
  """Possible Hugging Face model names.

  You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
@@ -267,7 +267,7 @@ class HuggingFaceModel(Model):
  for c in tool_calls:
  items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
  return ModelResponse(
- items,
+ parts=items,
  usage=_map_usage(response),
  model_name=response.model,
  timestamp=timestamp,
@@ -320,7 +320,7 @@ class HuggingFaceModel(Model):
  # please open an issue. The below code is the code to send thinking to the provider.
  # texts.append(f'<think>\n{item.content}\n</think>')
  pass
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
  # This is currently never returned from huggingface
  pass
  else:
@@ -2,10 +2,11 @@ from __future__ import annotations

  import itertools
  import json
- from collections.abc import AsyncIterator, Iterator, Mapping
+ import warnings
+ from collections.abc import AsyncIterator, Callable, Iterator, Mapping
  from contextlib import asynccontextmanager, contextmanager
  from dataclasses import dataclass, field
- from typing import Any, Callable, Literal, cast
+ from typing import Any, Literal, cast
  from urllib.parse import urlparse

  from opentelemetry._events import (
@@ -93,36 +94,41 @@ class InstrumentationSettings:
  def __init__(
  self,
  *,
- event_mode: Literal['attributes', 'logs'] = 'attributes',
  tracer_provider: TracerProvider | None = None,
  meter_provider: MeterProvider | None = None,
- event_logger_provider: EventLoggerProvider | None = None,
  include_binary_content: bool = True,
  include_content: bool = True,
- version: Literal[1, 2] = 1,
+ version: Literal[1, 2] = 2,
+ event_mode: Literal['attributes', 'logs'] = 'attributes',
+ event_logger_provider: EventLoggerProvider | None = None,
  ):
  """Create instrumentation options.

  Args:
- event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
- If `'logs'`, events are emitted as OpenTelemetry log-based events.
  tracer_provider: The OpenTelemetry tracer provider to use.
  If not provided, the global tracer provider is used.
  Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
  meter_provider: The OpenTelemetry meter provider to use.
  If not provided, the global meter provider is used.
  Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
- event_logger_provider: The OpenTelemetry event logger provider to use.
- If not provided, the global event logger provider is used.
- Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
- This is only used if `event_mode='logs'`.
  include_binary_content: Whether to include binary content in the instrumentation events.
  include_content: Whether to include prompts, completions, and tool call arguments and responses
  in the instrumentation events.
- version: Version of the data format.
- Version 1 is based on the legacy event-based OpenTelemetry GenAI spec.
- Version 2 stores messages in the attributes `gen_ai.input.messages` and `gen_ai.output.messages`.
- Version 2 is still WIP and experimental, but will become the default in Pydantic AI v1.
+ version: Version of the data format. This is unrelated to the Pydantic AI package version.
+ Version 1 is based on the legacy event-based OpenTelemetry GenAI spec
+ and will be removed in a future release.
+ The parameters `event_mode` and `event_logger_provider` are only relevant for version 1.
+ Version 2 uses the newer OpenTelemetry GenAI spec and stores messages in the following attributes:
+ - `gen_ai.system_instructions` for instructions passed to the agent.
+ - `gen_ai.input.messages` and `gen_ai.output.messages` on model request spans.
+ - `pydantic_ai.all_messages` on agent run spans.
+ event_mode: The mode for emitting events in version 1.
+ If `'attributes'`, events are attached to the span as attributes.
+ If `'logs'`, events are emitted as OpenTelemetry log-based events.
+ event_logger_provider: The OpenTelemetry event logger provider to use.
+ If not provided, the global event logger provider is used.
+ Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
+ This is only used if `event_mode='logs'` and `version=1`.
  """
  from pydantic_ai import __version__

@@ -136,6 +142,14 @@ class InstrumentationSettings:
  self.event_mode = event_mode
  self.include_binary_content = include_binary_content
  self.include_content = include_content
+
+ if event_mode == 'logs' and version != 1:
+ warnings.warn(
+ 'event_mode is only relevant for version=1 which is deprecated and will be removed in a future release.',
+ stacklevel=2,
+ )
+ version = 1
+
  self.version = version

  # As specified in the OpenTelemetry GenAI metrics spec:
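
Since `version` now defaults to 2 and `event_mode='logs'` warns and falls back to the deprecated version 1, a brief usage sketch may help; it assumes an OpenTelemetry or Logfire setup is already configured and the chosen model is available:

    from pydantic_ai import Agent
    from pydantic_ai.models.instrumented import InstrumentationSettings

    # The default data format is now version 2: messages are stored in the
    # gen_ai.input.messages / gen_ai.output.messages span attributes.
    settings = InstrumentationSettings(include_binary_content=False)
    agent = Agent('openai:gpt-4o', instrument=settings)

    # Explicitly opting into the legacy event-based format still works; any other
    # version combined with event_mode='logs' warns and is pinned back to version 1.
    legacy_settings = InstrumentationSettings(version=1, event_mode='logs')
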
@@ -366,7 +380,7 @@ class InstrumentedModel(WrapperModel):

  if model_settings:
  for key in MODEL_SETTING_ATTRIBUTES:
- if isinstance(value := model_settings.get(key), (float, int)):
+ if isinstance(value := model_settings.get(key), float | int):
  attributes[f'gen_ai.request.{key}'] = value

  record_metrics: Callable[[], None] | None = None
@@ -406,10 +420,15 @@ class InstrumentedModel(WrapperModel):
  return

  self.instrumentation_settings.handle_messages(messages, response, system, span)
+ try:
+ cost_attributes = {'operation.cost': float(response.cost().total_price)}
+ except LookupError:
+ cost_attributes = {}
  span.set_attributes(
  {
  **response.usage.opentelemetry_attributes(),
  'gen_ai.response.model': response_model,
+ **cost_attributes,
  }
  )
  span.update_name(f'{operation} {request_model}')
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations

  from collections.abc import AsyncIterator
  from contextlib import asynccontextmanager
- from dataclasses import dataclass
+ from dataclasses import KW_ONLY, dataclass
  from typing import TYPE_CHECKING, Any, cast

  from .. import _mcp, exceptions
@@ -36,6 +36,8 @@ class MCPSamplingModel(Model):
  session: ServerSession
  """The MCP server session to use for sampling."""

+ _: KW_ONLY
+
  default_max_tokens: int = 16_384
  """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].

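The `_: KW_ONLY` sentinel is the standard-library dataclass marker: every field declared after it becomes keyword-only, so `default_max_tokens` can no longer be passed positionally to `MCPSamplingModel`. A tiny illustration of that behaviour with a throwaway dataclass (not the real class):

    from dataclasses import KW_ONLY, dataclass


    @dataclass
    class Sketch:
        session: str
        _: KW_ONLY
        default_max_tokens: int = 16_384


    Sketch('my-session', default_max_tokens=1024)  # OK: keyword-only after KW_ONLY
    # Sketch('my-session', 1024)  # TypeError: too many positional arguments
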
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import Any, Literal, Union, cast
+ from typing import Any, Literal, cast

  import pydantic_core
  from httpx import Timeout
@@ -90,7 +90,7 @@ LatestMistralModelNames = Literal[
  ]
  """Latest Mistral models."""

- MistralModelName = Union[str, LatestMistralModelNames]
+ MistralModelName = str | LatestMistralModelNames
  """Possible Mistral model names.

  Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
@@ -117,7 +117,7 @@ class MistralModel(Model):
  """

  client: Mistral = field(repr=False)
- json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""
+ json_mode_schema_prompt: str

  _model_name: MistralModelName = field(repr=False)
  _provider: Provider[Mistral] = field(repr=False)
@@ -348,7 +348,7 @@ class MistralModel(Model):
  parts.append(tool)

  return ModelResponse(
- parts,
+ parts=parts,
  usage=_map_usage(response),
  model_name=response.model,
  timestamp=timestamp,
@@ -515,7 +515,7 @@ class MistralModel(Model):
  pass
  elif isinstance(part, ToolCallPart):
  tool_calls.append(self._map_tool_call(part))
- elif isinstance(part, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
+ elif isinstance(part, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
  # This is currently never returned from mistral
  pass
  else:
@@ -576,7 +576,7 @@ class MistralModel(Model):
  return MistralUserMessage(content=content)


- MistralToolCallId = Union[str, None]
+ MistralToolCallId = str | None


  @dataclass
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Sequence
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import Any, Literal, Union, cast, overload
+ from typing import Any, Literal, cast, overload

  from pydantic import ValidationError
  from typing_extensions import assert_never, deprecated
@@ -90,7 +90,7 @@ __all__ = (
  'OpenAIModelName',
  )

- OpenAIModelName = Union[str, AllModels]
+ OpenAIModelName = str | AllModels
  """
  Possible OpenAI model names.

@@ -225,6 +225,7 @@ class OpenAIChatModel(Model):
  'openrouter',
  'together',
  'vercel',
+ 'litellm',
  ]
  | Provider[AsyncOpenAI] = 'openai',
  profile: ModelProfileSpec | None = None,
@@ -252,6 +253,7 @@ class OpenAIChatModel(Model):
  'openrouter',
  'together',
  'vercel',
+ 'litellm',
  ]
  | Provider[AsyncOpenAI] = 'openai',
  profile: ModelProfileSpec | None = None,
@@ -278,6 +280,7 @@ class OpenAIChatModel(Model):
  'openrouter',
  'together',
  'vercel',
+ 'litellm',
  ]
  | Provider[AsyncOpenAI] = 'openai',
  profile: ModelProfileSpec | None = None,
@@ -409,13 +412,6 @@ class OpenAIChatModel(Model):
  for setting in unsupported_model_settings:
  model_settings.pop(setting, None)

- # TODO(Marcelo): Deprecate this in favor of `openai_unsupported_model_settings`.
- sampling_settings = (
- model_settings
- if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
- else OpenAIChatModelSettings()
- )
-
  try:
  extra_headers = model_settings.get('extra_headers', {})
  extra_headers.setdefault('User-Agent', get_user_agent())
@@ -437,13 +433,13 @@ class OpenAIChatModel(Model):
  web_search_options=web_search_options or NOT_GIVEN,
  service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
  prediction=model_settings.get('openai_prediction', NOT_GIVEN),
- temperature=sampling_settings.get('temperature', NOT_GIVEN),
- top_p=sampling_settings.get('top_p', NOT_GIVEN),
- presence_penalty=sampling_settings.get('presence_penalty', NOT_GIVEN),
- frequency_penalty=sampling_settings.get('frequency_penalty', NOT_GIVEN),
- logit_bias=sampling_settings.get('logit_bias', NOT_GIVEN),
- logprobs=sampling_settings.get('openai_logprobs', NOT_GIVEN),
- top_logprobs=sampling_settings.get('openai_top_logprobs', NOT_GIVEN),
+ temperature=model_settings.get('temperature', NOT_GIVEN),
+ top_p=model_settings.get('top_p', NOT_GIVEN),
+ presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+ frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+ logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
+ logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
+ top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
  extra_headers=extra_headers,
  extra_body=model_settings.get('extra_body'),
  )
@@ -512,7 +508,7 @@ class OpenAIChatModel(Model):
  part.tool_call_id = _guard_tool_call_id(part)
  items.append(part)
  return ModelResponse(
- items,
+ parts=items,
  usage=_map_usage(response),
  model_name=response.model,
  timestamp=timestamp,
@@ -582,7 +578,7 @@ class OpenAIChatModel(Model):
  elif isinstance(item, ToolCallPart):
  tool_calls.append(self._map_tool_call(item))
  # OpenAI doesn't return built-in tool calls
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
  pass
  else:
  assert_never(item)
@@ -613,7 +609,7 @@ class OpenAIChatModel(Model):
  def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
  response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage]
  'type': 'json_schema',
- 'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True},
+ 'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
  }
  if o.description:
  response_format_param['json_schema']['description'] = o.description
@@ -828,7 +824,7 @@ class OpenAIResponsesModel(Model):
  elif item.type == 'function_call':
  items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
  return ModelResponse(
- items,
+ parts=items,
  usage=_map_usage(response),
  model_name=response.model,
  provider_response_id=response.id,
@@ -918,11 +914,9 @@ class OpenAIResponsesModel(Model):
  text = text or {}
  text['verbosity'] = verbosity

- sampling_settings = (
- model_settings
- if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings
- else OpenAIResponsesModelSettings()
- )
+ unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings
+ for setting in unsupported_model_settings:
+ model_settings.pop(setting, None)

  try:
  extra_headers = model_settings.get('extra_headers', {})
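
Both OpenAI model classes now consult `openai_unsupported_model_settings` on the profile instead of the old `openai_supports_sampling_settings` flag, popping the listed keys from the request settings. A hedged sketch of a custom profile using that field; the model name and provider below are placeholders, not part of this diff:

    from pydantic_ai.models.openai import OpenAIChatModel
    from pydantic_ai.profiles.openai import OpenAIModelProfile

    # Any ModelSettings keys named here are dropped before the request is built,
    # e.g. for OpenAI-compatible backends that reject sampling parameters.
    profile = OpenAIModelProfile(
        openai_unsupported_model_settings=('temperature', 'top_p'),
    )

    model = OpenAIChatModel('my-compatible-model', provider='openrouter', profile=profile)
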
@@ -936,8 +930,8 @@ class OpenAIResponsesModel(Model):
  tool_choice=tool_choice or NOT_GIVEN,
  max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
  stream=stream,
- temperature=sampling_settings.get('temperature', NOT_GIVEN),
- top_p=sampling_settings.get('top_p', NOT_GIVEN),
+ temperature=model_settings.get('temperature', NOT_GIVEN),
+ top_p=model_settings.get('top_p', NOT_GIVEN),
  truncation=model_settings.get('openai_truncation', NOT_GIVEN),
  timeout=model_settings.get('timeout', NOT_GIVEN),
  service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
@@ -1049,7 +1043,7 @@ class OpenAIResponsesModel(Model):
  elif isinstance(item, ToolCallPart):
  openai_messages.append(self._map_tool_call(item))
  # OpenAI doesn't return built-in tool calls
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
  pass
  elif isinstance(item, ThinkingPart):
  # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
@@ -1180,6 +1174,10 @@ class OpenAIStreamedResponse(StreamedResponse):
  except IndexError:
  continue

+ # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
+ if choice.delta is None: # pyright: ignore[reportUnnecessaryComparison]
+ continue
+
  # Handle the text part of the response
  content = choice.delta.content
  if content is not None:
@@ -1279,12 +1277,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
  tool_call_id=chunk.item.call_id,
  )
  elif isinstance(chunk.item, responses.ResponseReasoningItem):
- content = chunk.item.summary[0].text if chunk.item.summary else ''
- yield self._parts_manager.handle_thinking_delta(
- vendor_part_id=chunk.item.id,
- content=content,
- signature=chunk.item.id,
- )
+ pass
  elif isinstance(chunk.item, responses.ResponseOutputMessage):
  pass
  elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):
@@ -1300,7 +1293,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
  pass

  elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
- pass # there's nothing we need to do here
+ yield self._parts_manager.handle_thinking_delta(
+ vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
+ content=chunk.part.text,
+ id=chunk.item_id,
+ )

  elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
  pass # there's nothing we need to do here
@@ -1310,9 +1307,9 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):

  elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
  yield self._parts_manager.handle_thinking_delta(
- vendor_part_id=chunk.item_id,
+ vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
  content=chunk.delta,
- signature=chunk.item_id,
+ id=chunk.item_id,
  )

  # TODO(Marcelo): We should support annotations in the future.
@@ -1320,9 +1317,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
  pass # there's nothing we need to do here

  elif isinstance(chunk, responses.ResponseTextDeltaEvent):
- maybe_event = self._parts_manager.handle_text_delta(
- vendor_part_id=chunk.content_index, content=chunk.delta
- )
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=chunk.item_id, content=chunk.delta)
  if maybe_event is not None: # pragma: no branch
  yield maybe_event