pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.


Files changed (38)
  1. pydantic_ai/_a2a.py +1 -1
  2. pydantic_ai/_agent_graph.py +65 -49
  3. pydantic_ai/_parts_manager.py +3 -1
  4. pydantic_ai/_tool_manager.py +33 -6
  5. pydantic_ai/ag_ui.py +75 -43
  6. pydantic_ai/agent/__init__.py +10 -7
  7. pydantic_ai/durable_exec/dbos/__init__.py +6 -0
  8. pydantic_ai/durable_exec/dbos/_agent.py +718 -0
  9. pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
  10. pydantic_ai/durable_exec/dbos/_model.py +137 -0
  11. pydantic_ai/durable_exec/dbos/_utils.py +10 -0
  12. pydantic_ai/durable_exec/temporal/_agent.py +71 -10
  13. pydantic_ai/exceptions.py +2 -2
  14. pydantic_ai/mcp.py +14 -26
  15. pydantic_ai/messages.py +90 -19
  16. pydantic_ai/models/__init__.py +9 -0
  17. pydantic_ai/models/anthropic.py +28 -11
  18. pydantic_ai/models/bedrock.py +6 -14
  19. pydantic_ai/models/gemini.py +3 -1
  20. pydantic_ai/models/google.py +58 -5
  21. pydantic_ai/models/groq.py +122 -34
  22. pydantic_ai/models/instrumented.py +29 -11
  23. pydantic_ai/models/openai.py +84 -29
  24. pydantic_ai/providers/__init__.py +4 -0
  25. pydantic_ai/providers/bedrock.py +11 -3
  26. pydantic_ai/providers/google_vertex.py +2 -1
  27. pydantic_ai/providers/groq.py +21 -2
  28. pydantic_ai/providers/litellm.py +134 -0
  29. pydantic_ai/retries.py +42 -2
  30. pydantic_ai/tools.py +18 -7
  31. pydantic_ai/toolsets/combined.py +2 -2
  32. pydantic_ai/toolsets/function.py +54 -19
  33. pydantic_ai/usage.py +37 -3
  34. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/METADATA +9 -8
  35. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/RECORD +38 -32
  36. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/WHEEL +0 -0
  37. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/entry_points.txt +0 -0
  38. {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/licenses/LICENSE +0 -0

pydantic_ai/models/groq.py

@@ -7,8 +7,11 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never
 
+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
 )
 
 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-        response = await self._completions_create(
-            messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
-        )
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response
 
@@ -228,6 +246,18 @@ class GroqModel(Model):
 
         groq_messages = self._map_messages(messages)
 
+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@ class GroqModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -385,6 +416,19 @@ class GroqModel(Model):
             },
         }
 
+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
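
For a concrete sense of the payload, if _map_json_schema is given an OutputObjectDefinition named 'weather_report' with a description and strict=True (values invented purely for illustration, not taken from the diff), the response_format sent to Groq in 'native' output mode would look like:

    {
        'type': 'json_schema',
        'json_schema': {
            'name': 'weather_report',
            'schema': {'type': 'object', 'properties': {'temp_c': {'type': 'number'}}},
            'strict': True,
            'description': 'Structured weather output',
        },
    }
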
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        async for chunk in self._response:
-            self._usage += _map_usage(chunk)
-
-            try:
-                choice = chunk.choices[0]
-            except IndexError:
-                continue
-
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    yield maybe_event
-
-            # Handle the tool calls
-            for dtc in choice.delta.tool_calls or []:
-                maybe_event = self._parts_manager.handle_tool_call_delta(
-                    vendor_part_id=dtc.index,
-                    tool_name=dtc.function and dtc.function.name,
-                    args=dtc.function and dtc.function.arguments,
-                    tool_call_id=dtc.id,
-                )
-                if maybe_event is not None:
-                    yield maybe_event
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover
 
     @property
     def model_name(self) -> GroqModelName:
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n    "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
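
As a standalone sketch of what these models do (not part of the diff; the Literal fields are trimmed for brevity), pydantic's Json[...] annotation is what decodes the failed_generation JSON string back into a tool name and arguments that can be replayed as a ToolCallPart:

    from typing import Any
    from pydantic import BaseModel, Json

    class FailedGeneration(BaseModel):  # trimmed stand-in for _GroqToolUseFailedGeneration
        name: str
        arguments: dict[str, Any]

    class InnerError(BaseModel):  # trimmed stand-in for _GroqToolUseFailedInnerError
        message: str
        failed_generation: Json[FailedGeneration]

    body = {
        'message': 'Tool call validation failed: ...',
        'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
    }
    error = InnerError.model_validate(body)
    # The embedded JSON string is parsed during validation:
    assert error.failed_generation.name == 'get_something_by_name'
    assert error.failed_generation.arguments == {'foo': 'bar'}
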

pydantic_ai/models/instrumented.py

@@ -221,7 +221,10 @@ class InstrumentationSettings:
                     _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
                 )
             elif isinstance(message, ModelResponse):  # pragma: no branch
-                result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self)))
+                otel_message = _otel_messages.OutputMessage(role='assistant', parts=message.otel_message_parts(self))
+                if message.finish_reason is not None:
+                    otel_message['finish_reason'] = message.finish_reason
+                result.append(otel_message)
         return result
 
     def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
@@ -246,12 +249,10 @@ class InstrumentationSettings:
         else:
             output_messages = self.messages_to_otel_messages([response])
             assert len(output_messages) == 1
-            output_message = cast(_otel_messages.OutputMessage, output_messages[0])
-            if response.provider_details and 'finish_reason' in response.provider_details:
-                output_message['finish_reason'] = response.provider_details['finish_reason']
+            output_message = output_messages[0]
         instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
         system_instructions_attributes = self.system_instructions_attributes(instructions)
-        attributes = {
+        attributes: dict[str, AttributeValue] = {
             'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
             'gen_ai.output.messages': json.dumps([output_message]),
             **system_instructions_attributes,
@@ -420,12 +421,25 @@ class InstrumentedModel(WrapperModel):
                    return
 
                self.instrumentation_settings.handle_messages(messages, response, system, span)
-                span.set_attributes(
-                    {
-                        **response.usage.opentelemetry_attributes(),
-                        'gen_ai.response.model': response_model,
-                    }
-                )
+
+                attributes_to_set = {
+                    **response.usage.opentelemetry_attributes(),
+                    'gen_ai.response.model': response_model,
+                }
+                try:
+                    attributes_to_set['operation.cost'] = float(response.cost().total_price)
+                except LookupError:
+                    # The cost of this provider/model is unknown, which is common.
+                    pass
+                except Exception as e:
+                    warnings.warn(
+                        f'Failed to get cost from response: {type(e).__name__}: {e}', CostCalculationFailedWarning
+                    )
+                if response.provider_response_id is not None:
+                    attributes_to_set['gen_ai.response.id'] = response.provider_response_id
+                if response.finish_reason is not None:
+                    attributes_to_set['gen_ai.response.finish_reasons'] = [response.finish_reason]
+                span.set_attributes(attributes_to_set)
                 span.update_name(f'{operation} {request_model}')
 
             yield finish
@@ -473,3 +487,7 @@ class InstrumentedModel(WrapperModel):
             return str(value)
         except Exception as e:
             return f'Unable to serialize: {e}'
+
+
+class CostCalculationFailedWarning(Warning):
+    """Warning raised when cost calculation fails."""
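
Since CostCalculationFailedWarning is a plain Warning subclass, it can be filtered with the standard warnings machinery. A minimal sketch, assuming the class is importable from pydantic_ai.models.instrumented (where the diff defines it):

    import warnings

    from pydantic_ai.models.instrumented import CostCalculationFailedWarning

    # Suppress only failed cost calculations; other warnings are unaffected.
    warnings.simplefilter('ignore', category=CostCalculationFailedWarning)
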

pydantic_ai/models/openai.py

@@ -24,6 +24,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -72,6 +73,7 @@ try:
     )
     from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
     from openai.types.responses.response_input_param import FunctionCallOutput, Message
+    from openai.types.responses.response_status import ResponseStatus
     from openai.types.shared import ReasoningEffort
     from openai.types.shared_params import Reasoning
 except ImportError as _import_error:
@@ -103,6 +105,25 @@ allows this model to be used more easily with other model types (ie, Ollama, Dee
 """
 
 
+_CHAT_FINISH_REASON_MAP: dict[
+    Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason
+] = {
+    'stop': 'stop',
+    'length': 'length',
+    'tool_calls': 'tool_call',
+    'content_filter': 'content_filter',
+    'function_call': 'tool_call',
+}
+
+_RESPONSES_FINISH_REASON_MAP: dict[Literal['max_output_tokens', 'content_filter'] | ResponseStatus, FinishReason] = {
+    'max_output_tokens': 'length',
+    'content_filter': 'content_filter',
+    'completed': 'stop',
+    'cancelled': 'error',
+    'failed': 'error',
+}
+
+
 class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""
 
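
The two maps normalize provider-specific finish reasons onto pydantic-ai's FinishReason values; raw values missing from a map simply yield None via .get(), while the raw string is still recorded in provider_details (see the hunks below). A quick illustration, not part of the diff:

    _CHAT_FINISH_REASON_MAP.get('tool_calls')              # 'tool_call'
    _CHAT_FINISH_REASON_MAP.get('function_call')           # 'tool_call' (legacy function-calling reason)
    _RESPONSES_FINISH_REASON_MAP.get('max_output_tokens')  # 'length'
    _RESPONSES_FINISH_REASON_MAP.get('incomplete')         # None -> ModelResponse.finish_reason stays unset
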
@@ -225,6 +246,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -252,6 +274,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -278,6 +301,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -471,24 +495,22 @@ class OpenAIChatModel(Model):
         if reasoning_content := getattr(choice.message, 'reasoning_content', None):
             items.append(ThinkingPart(content=reasoning_content))
 
-        vendor_details: dict[str, Any] | None = None
+        vendor_details: dict[str, Any] = {}
 
         # Add logprobs to vendor_details if available
         if choice.logprobs is not None and choice.logprobs.content:
             # Convert logprobs to a serializable format
-            vendor_details = {
-                'logprobs': [
-                    {
-                        'token': lp.token,
-                        'bytes': lp.bytes,
-                        'logprob': lp.logprob,
-                        'top_logprobs': [
-                            {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
-                        ],
-                    }
-                    for lp in choice.logprobs.content
-                ],
-            }
+            vendor_details['logprobs'] = [
+                {
+                    'token': lp.token,
+                    'bytes': lp.bytes,
+                    'logprob': lp.logprob,
+                    'top_logprobs': [
+                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                    ],
+                }
+                for lp in choice.logprobs.content
+            ]
 
         if choice.message.content is not None:
             items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
@@ -504,14 +526,21 @@ class OpenAIChatModel(Model):
                     assert_never(c)
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
+
+        finish_reason: FinishReason | None = None
+        if raw_finish_reason := choice.finish_reason:  # pragma: no branch
+            vendor_details['finish_reason'] = raw_finish_reason
+            finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            provider_details=vendor_details,
+            provider_details=vendor_details or None,
            provider_response_id=response.id,
            provider_name=self._provider.name,
+            finish_reason=finish_reason,
         )
 
     async def _process_streamed_response(
@@ -606,7 +635,7 @@ class OpenAIChatModel(Model):
     def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
         response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = {  # pyright: ignore[reportPrivateImportUsage]
             'type': 'json_schema',
-            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True},
+            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
         }
         if o.description:
             response_format_param['json_schema']['description'] = o.description
@@ -820,6 +849,14 @@ class OpenAIResponsesModel(Model):
                         items.append(TextPart(content.text))
             elif item.type == 'function_call':
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
+
+        finish_reason: FinishReason | None = None
+        provider_details: dict[str, Any] | None = None
+        raw_finish_reason = details.reason if (details := response.incomplete_details) else response.status
+        if raw_finish_reason:
+            provider_details = {'finish_reason': raw_finish_reason}
+            finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
@@ -827,6 +864,8 @@ class OpenAIResponsesModel(Model):
             provider_response_id=response.id,
             timestamp=timestamp,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )
 
     async def _process_streamed_response(
@@ -1166,11 +1205,22 @@ class OpenAIStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk)
 
+            if chunk.id and self.provider_response_id is None:
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue
 
+            # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
+            if choice.delta is None:  # pyright: ignore[reportUnnecessaryComparison]
+                continue
+
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
@@ -1230,6 +1280,13 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
             if isinstance(chunk, responses.ResponseCompletedEvent):
                 self._usage += _map_usage(chunk.response)
 
+                raw_finish_reason = (
+                    details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
+                )
+                if raw_finish_reason:  # pragma: no branch
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
+
             elif isinstance(chunk, responses.ResponseContentPartAddedEvent):
                 pass  # there's nothing we need to do here
 
@@ -1237,7 +1294,8 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here
 
             elif isinstance(chunk, responses.ResponseCreatedEvent):
-                pass  # there's nothing we need to do here
+                if chunk.response.id:  # pragma: no branch
+                    self.provider_response_id = chunk.response.id
 
             elif isinstance(chunk, responses.ResponseFailedEvent):  # pragma: no cover
                 self._usage += _map_usage(chunk.response)
@@ -1270,12 +1328,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                         tool_call_id=chunk.item.call_id,
                     )
                 elif isinstance(chunk.item, responses.ResponseReasoningItem):
-                    content = chunk.item.summary[0].text if chunk.item.summary else ''
-                    yield self._parts_manager.handle_thinking_delta(
-                        vendor_part_id=chunk.item.id,
-                        content=content,
-                        signature=chunk.item.id,
-                    )
+                    pass
                 elif isinstance(chunk.item, responses.ResponseOutputMessage):
                     pass
                 elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):
@@ -1291,7 +1344,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                    pass
 
             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
-                pass  # there's nothing we need to do here
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
+                    content=chunk.part.text,
+                    id=chunk.item_id,
+                )
 
             elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
                 pass  # there's nothing we need to do here
@@ -1301,9 +1358,9 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
 
             elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
                 yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item_id,
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
                     content=chunk.delta,
-                    signature=chunk.item_id,
+                    id=chunk.item_id,
                 )
 
             # TODO(Marcelo): We should support annotations in the future.
@@ -1311,9 +1368,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here
 
             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id=chunk.content_index, content=chunk.delta
-                )
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=chunk.item_id, content=chunk.delta)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
 

pydantic_ai/providers/__init__.py

@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .github import GitHubProvider
 
         return GitHubProvider
+    elif provider == 'litellm':
+        from .litellm import LiteLLMProvider
+
+        return LiteLLMProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
 
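
With the provider registered here, the string 'litellm' resolves through infer_provider_class, and it is also accepted by the provider argument of OpenAIChatModel (per the Literal additions above). A rough usage sketch; the model name and any LiteLLM credentials or proxy configuration are illustrative assumptions, not taken from the diff:

    from pydantic_ai.models.openai import OpenAIChatModel
    from pydantic_ai.providers import infer_provider_class

    infer_provider_class('litellm')                        # -> LiteLLMProvider
    model = OpenAIChatModel('gpt-4o', provider='litellm')  # model name is illustrative
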

pydantic_ai/providers/bedrock.py

@@ -35,11 +35,19 @@ class BedrockModelProfile(ModelProfile):
     ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """
 
-    bedrock_supports_tool_choice: bool = True
+    bedrock_supports_tool_choice: bool = False
     bedrock_tool_result_format: Literal['text', 'json'] = 'text'
     bedrock_send_back_thinking_parts: bool = False
 
 
+def bedrock_amazon_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for an Amazon model used via Bedrock."""
+    profile = amazon_model_profile(model_name)
+    if 'nova' in model_name:
+        return BedrockModelProfile(bedrock_supports_tool_choice=True).update(profile)
+    return profile
+
+
 class BedrockProvider(Provider[BaseClient]):
     """Provider for AWS Bedrock."""
 
@@ -58,13 +66,13 @@ class BedrockProvider(Provider[BaseClient]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
             'anthropic': lambda model_name: BedrockModelProfile(
-                bedrock_supports_tool_choice=False, bedrock_send_back_thinking_parts=True
+                bedrock_supports_tool_choice=True, bedrock_send_back_thinking_parts=True
             ).update(anthropic_model_profile(model_name)),
             'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
                 mistral_model_profile(model_name)
             ),
             'cohere': cohere_model_profile,
-            'amazon': amazon_model_profile,
+            'amazon': bedrock_amazon_model_profile,
             'meta': meta_model_profile,
             'deepseek': deepseek_model_profile,
         }
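
Taken together, these changes flip the default to "no tool_choice support" and opt specific families back in: Anthropic models on Bedrock now advertise tool_choice support, and Amazon models get it only when the model id contains 'nova'. A hedged illustration (the model ids below are examples, not taken from the diff):

    bedrock_amazon_model_profile('amazon.nova-lite-v1:0')
    # 'nova' in the id -> bedrock_supports_tool_choice=True merged over the base Amazon profile
    bedrock_amazon_model_profile('amazon.titan-text-express-v1')
    # no 'nova' -> plain amazon_model_profile(...) result, tool_choice stays off
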

pydantic_ai/providers/google_vertex.py

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import functools
+from asyncio import Lock
 from collections.abc import AsyncGenerator, Mapping
 from pathlib import Path
 from typing import Literal, overload
@@ -118,7 +119,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
 class _VertexAIAuth(httpx.Auth):
     """Auth class for Vertex AI API."""
 
-    _refresh_lock: anyio.Lock = anyio.Lock()
+    _refresh_lock: Lock = Lock()
 
     credentials: BaseCredentials | ServiceAccountCredentials | None
 

pydantic_ai/providers/groq.py

@@ -14,6 +14,7 @@ from pydantic_ai.profiles.groq import groq_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider
 
@@ -26,6 +27,23 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error
 
 
+def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for an MoonshotAI model used with the Groq provider."""
+    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+        moonshotai_model_profile(model_name)
+    )
+
+
+def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a Meta model used with the Groq provider."""
+    if model_name in {'llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct'}:
+        return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+            meta_model_profile(model_name)
+        )
+    else:
+        return meta_model_profile(model_name)
+
+
 class GroqProvider(Provider[AsyncGroq]):
     """Provider for Groq API."""
 
@@ -44,13 +62,14 @@ class GroqProvider(Provider[AsyncGroq]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         prefix_to_profile = {
             'llama': meta_model_profile,
-            'meta-llama/': meta_model_profile,
+            'meta-llama/': meta_groq_model_profile,
             'gemma': google_model_profile,
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
-            'moonshotai/': moonshotai_model_profile,
+            'moonshotai/': groq_moonshotai_model_profile,
             'compound-': groq_model_profile,
+            'openai/': openai_model_profile,
         }
 
         for prefix, profile_func in prefix_to_profile.items():
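
The listing stops at the prefix loop, but the intent is clear from the table above: model names are matched by prefix and dispatched to the per-family profile function. A hedged sketch of how the new entries would resolve; the matching behavior and model names are illustrative assumptions based on the visible prefix_to_profile dict, not on code shown in this diff:

    provider = GroqProvider()  # assumes GROQ_API_KEY (or an explicit client) is configured
    provider.model_profile('openai/gpt-oss-120b')
    # matched by the new 'openai/' prefix -> openai_model_profile(...)
    provider.model_profile('meta-llama/llama-4-scout-17b-16e-instruct')
    # matched by 'meta-llama/' -> meta_groq_model_profile(...), which enables JSON object/schema output support
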