pydantic-ai-slim 1.0.0b1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff compares the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_a2a.py +1 -1
- pydantic_ai/_agent_graph.py +65 -49
- pydantic_ai/_parts_manager.py +3 -1
- pydantic_ai/_tool_manager.py +33 -6
- pydantic_ai/ag_ui.py +75 -43
- pydantic_ai/agent/__init__.py +10 -7
- pydantic_ai/durable_exec/dbos/__init__.py +6 -0
- pydantic_ai/durable_exec/dbos/_agent.py +718 -0
- pydantic_ai/durable_exec/dbos/_mcp_server.py +89 -0
- pydantic_ai/durable_exec/dbos/_model.py +137 -0
- pydantic_ai/durable_exec/dbos/_utils.py +10 -0
- pydantic_ai/durable_exec/temporal/_agent.py +71 -10
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/mcp.py +14 -26
- pydantic_ai/messages.py +90 -19
- pydantic_ai/models/__init__.py +9 -0
- pydantic_ai/models/anthropic.py +28 -11
- pydantic_ai/models/bedrock.py +6 -14
- pydantic_ai/models/gemini.py +3 -1
- pydantic_ai/models/google.py +58 -5
- pydantic_ai/models/groq.py +122 -34
- pydantic_ai/models/instrumented.py +29 -11
- pydantic_ai/models/openai.py +84 -29
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/bedrock.py +11 -3
- pydantic_ai/providers/google_vertex.py +2 -1
- pydantic_ai/providers/groq.py +21 -2
- pydantic_ai/providers/litellm.py +134 -0
- pydantic_ai/retries.py +42 -2
- pydantic_ai/tools.py +18 -7
- pydantic_ai/toolsets/combined.py +2 -2
- pydantic_ai/toolsets/function.py +54 -19
- pydantic_ai/usage.py +37 -3
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/METADATA +9 -8
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/RECORD +38 -32
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.0.0b1.dist-info → pydantic_ai_slim-1.0.2.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/groq.py
CHANGED
@@ -7,8 +7,11 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any, Literal, cast, overload

+from pydantic import BaseModel, Json, ValidationError
 from typing_extensions import assert_never

+from pydantic_ai._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
+
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
@@ -48,7 +51,7 @@ from . import (
 )

 try:
-    from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
+    from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
     from groq.types import chat
     from groq.types.chat.chat_completion_content_part_image_param import ImageURL
 except ImportError as _import_error:
@@ -169,9 +172,24 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         check_allow_model_requests()
-        … (3 removed lines not captured in this view)
+        try:
+            response = await self._completions_create(
+                messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters
+            )
+        except ModelHTTPError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+                try:
+                    error = _GroqToolUseFailedError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    tool_call_part = ToolCallPart(
+                        tool_name=error.error.failed_generation.name,
+                        args=error.error.failed_generation.arguments,
+                    )
+                    return ModelResponse(parts=[tool_call_part])
+                except ValidationError:
+                    pass
+            raise
         model_response = self._process_response(response)
         return model_response

@@ -228,6 +246,18 @@ class GroqModel(Model):

         groq_messages = self._map_messages(messages)

+        response_format: chat.completion_create_params.ResponseFormat | None = None
+        if model_request_parameters.output_mode == 'native':
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            response_format = self._map_json_schema(output_object)
+        elif (
+            model_request_parameters.output_mode == 'prompted'
+            and not tools
+            and self.profile.supports_json_object_output
+        ):  # pragma: no branch
+            response_format = {'type': 'json_object'}
+
         try:
             extra_headers = model_settings.get('extra_headers', {})
             extra_headers.setdefault('User-Agent', get_user_agent())
@@ -240,6 +270,7 @@ class GroqModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stop=model_settings.get('stop_sequences', NOT_GIVEN),
                 stream=stream,
+                response_format=response_format or NOT_GIVEN,
                 max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -385,6 +416,19 @@ class GroqModel(Model):
             },
         }

+    def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
+        response_format_param: chat.completion_create_params.ResponseFormatResponseFormatJsonSchema = {
+            'type': 'json_schema',
+            'json_schema': {
+                'name': o.name or DEFAULT_OUTPUT_TOOL_NAME,
+                'schema': o.json_schema,
+                'strict': o.strict,
+            },
+        }
+        if o.description:  # pragma: no branch
+            response_format_param['json_schema']['description'] = o.description
+        return response_format_param
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
@@ -449,36 +493,52 @@ class GroqStreamedResponse(StreamedResponse):
     _provider_name: str

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
-        … (30 removed lines not captured in this view)
+        try:
+            async for chunk in self._response:
+                self._usage += _map_usage(chunk)
+
+                try:
+                    choice = chunk.choices[0]
+                except IndexError:
+                    continue
+
+                # Handle the text part of the response
+                content = choice.delta.content
+                if content is not None:
+                    maybe_event = self._parts_manager.handle_text_delta(
+                        vendor_part_id='content',
+                        content=content,
+                        thinking_tags=self._model_profile.thinking_tags,
+                        ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+
+                # Handle the tool calls
+                for dtc in choice.delta.tool_calls or []:
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=dtc.index,
+                        tool_name=dtc.function and dtc.function.name,
+                        args=dtc.function and dtc.function.arguments,
+                        tool_call_id=dtc.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+        except APIError as e:
+            if isinstance(e.body, dict):  # pragma: no branch
+                # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+                # but we'd rather handle it ourselves so we can tell the model to retry the tool call
+                try:
+                    error = _GroqToolUseFailedInnerError.model_validate(e.body)  # pyright: ignore[reportUnknownMemberType]
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id='tool_use_failed',
+                        tool_name=error.failed_generation.name,
+                        args=error.failed_generation.arguments,
+                    )
+                    return
+                except ValidationError as e:  # pragma: no cover
+                    pass
+            raise  # pragma: no cover

     @property
     def model_name(self) -> GroqModelName:
@@ -510,3 +570,31 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
         input_tokens=response_usage.prompt_tokens,
         output_tokens=response_usage.completion_tokens,
     )
+
+
+class _GroqToolUseFailedGeneration(BaseModel):
+    name: str
+    arguments: dict[str, Any]
+
+
+class _GroqToolUseFailedInnerError(BaseModel):
+    message: str
+    type: Literal['invalid_request_error']
+    code: Literal['tool_use_failed']
+    failed_generation: Json[_GroqToolUseFailedGeneration]
+
+
+class _GroqToolUseFailedError(BaseModel):
+    # The Groq SDK tries to be helpful by raising an exception when generated tool arguments don't match the schema,
+    # but we'd rather handle it ourselves so we can tell the model to retry the tool call.
+    # Example payload from `exception.body`:
+    # {
+    #     'error': {
+    #         'message': "Tool call validation failed: tool call validation failed: parameters for tool get_something_by_name did not match schema: errors: [missing properties: 'name', additionalProperties 'foo' not allowed]",
+    #         'type': 'invalid_request_error',
+    #         'code': 'tool_use_failed',
+    #         'failed_generation': '{"name": "get_something_by_name", "arguments": {\n "foo": "bar"\n}}',
+    #     }
+    # }
+
+    error: _GroqToolUseFailedInnerError
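The `tool_use_failed` recovery above relies on pydantic's `Json[...]` field type to parse the JSON string that Groq embeds in `error['failed_generation']`. A minimal sketch of that parsing behaviour, using the example payload quoted in the diff; the model classes are re-declared here purely for illustration, since the real ones are private to `pydantic_ai.models.groq`:

```python
from typing import Any, Literal

from pydantic import BaseModel, Json


# Illustrative re-declaration of the private models added in this release.
class FailedGeneration(BaseModel):
    name: str
    arguments: dict[str, Any]


class ToolUseFailedInnerError(BaseModel):
    message: str
    type: Literal['invalid_request_error']
    code: Literal['tool_use_failed']
    failed_generation: Json[FailedGeneration]  # parses the embedded JSON string


body = {
    'message': "Tool call validation failed: parameters for tool get_something_by_name did not match schema",
    'type': 'invalid_request_error',
    'code': 'tool_use_failed',
    'failed_generation': '{"name": "get_something_by_name", "arguments": {"foo": "bar"}}',
}

inner = ToolUseFailedInnerError.model_validate(body)
print(inner.failed_generation.name)       # get_something_by_name
print(inner.failed_generation.arguments)  # {'foo': 'bar'}
```

On a non-streaming request the parsed name and arguments are turned into a `ToolCallPart`, so the agent can ask the model to retry the tool call instead of surfacing the HTTP error; the streaming path does the same via `handle_tool_call_part`.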
pydantic_ai/models/instrumented.py
CHANGED

@@ -221,7 +221,10 @@ class InstrumentationSettings:
                     _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts)
                 )
             elif isinstance(message, ModelResponse):  # pragma: no branch
-                … (removed line not captured in this view)
+                otel_message = _otel_messages.OutputMessage(role='assistant', parts=message.otel_message_parts(self))
+                if message.finish_reason is not None:
+                    otel_message['finish_reason'] = message.finish_reason
+                result.append(otel_message)
         return result

     def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
@@ -246,12 +249,10 @@ class InstrumentationSettings:
         else:
             output_messages = self.messages_to_otel_messages([response])
             assert len(output_messages) == 1
-            output_message = …
-            if response.provider_details and 'finish_reason' in response.provider_details:
-                output_message['finish_reason'] = response.provider_details['finish_reason']
+            output_message = output_messages[0]
         instructions = InstrumentedModel._get_instructions(input_messages)  # pyright: ignore [reportPrivateUsage]
         system_instructions_attributes = self.system_instructions_attributes(instructions)
-        attributes = {
+        attributes: dict[str, AttributeValue] = {
             'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
             'gen_ai.output.messages': json.dumps([output_message]),
             **system_instructions_attributes,
@@ -420,12 +421,25 @@ class InstrumentedModel(WrapperModel):
                     return

                self.instrumentation_settings.handle_messages(messages, response, system, span)
-                … (6 removed lines not captured in this view)
+
+                attributes_to_set = {
+                    **response.usage.opentelemetry_attributes(),
+                    'gen_ai.response.model': response_model,
+                }
+                try:
+                    attributes_to_set['operation.cost'] = float(response.cost().total_price)
+                except LookupError:
+                    # The cost of this provider/model is unknown, which is common.
+                    pass
+                except Exception as e:
+                    warnings.warn(
+                        f'Failed to get cost from response: {type(e).__name__}: {e}', CostCalculationFailedWarning
+                    )
+                if response.provider_response_id is not None:
+                    attributes_to_set['gen_ai.response.id'] = response.provider_response_id
+                if response.finish_reason is not None:
+                    attributes_to_set['gen_ai.response.finish_reasons'] = [response.finish_reason]
+                span.set_attributes(attributes_to_set)
                span.update_name(f'{operation} {request_model}')

            yield finish
@@ -473,3 +487,7 @@ class InstrumentedModel(WrapperModel):
             return str(value)
         except Exception as e:
             return f'Unable to serialize: {e}'
+
+
+class CostCalculationFailedWarning(Warning):
+    """Warning raised when cost calculation fails."""
pydantic_ai/models/openai.py
CHANGED
@@ -24,6 +24,7 @@ from ..messages import (
     BuiltinToolCallPart,
     BuiltinToolReturnPart,
     DocumentUrl,
+    FinishReason,
     ImageUrl,
     ModelMessage,
     ModelRequest,
@@ -72,6 +73,7 @@ try:
     )
     from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
     from openai.types.responses.response_input_param import FunctionCallOutput, Message
+    from openai.types.responses.response_status import ResponseStatus
     from openai.types.shared import ReasoningEffort
     from openai.types.shared_params import Reasoning
 except ImportError as _import_error:
@@ -103,6 +105,25 @@ allows this model to be used more easily with other model types (ie, Ollama, Dee
 """

+_CHAT_FINISH_REASON_MAP: dict[
+    Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason
+] = {
+    'stop': 'stop',
+    'length': 'length',
+    'tool_calls': 'tool_call',
+    'content_filter': 'content_filter',
+    'function_call': 'tool_call',
+}
+
+_RESPONSES_FINISH_REASON_MAP: dict[Literal['max_output_tokens', 'content_filter'] | ResponseStatus, FinishReason] = {
+    'max_output_tokens': 'length',
+    'content_filter': 'content_filter',
+    'completed': 'stop',
+    'cancelled': 'error',
+    'failed': 'error',
+}
+
+
 class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""

@@ -225,6 +246,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -252,6 +274,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -278,6 +301,7 @@ class OpenAIChatModel(Model):
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -471,24 +495,22 @@ class OpenAIChatModel(Model):
         if reasoning_content := getattr(choice.message, 'reasoning_content', None):
             items.append(ThinkingPart(content=reasoning_content))

-        vendor_details: dict[str, Any]
+        vendor_details: dict[str, Any] = {}

         # Add logprobs to vendor_details if available
         if choice.logprobs is not None and choice.logprobs.content:
             # Convert logprobs to a serializable format
-            vendor_details = … (13 removed lines, not fully captured in this view)
+            vendor_details['logprobs'] = [
+                {
+                    'token': lp.token,
+                    'bytes': lp.bytes,
+                    'logprob': lp.logprob,
+                    'top_logprobs': [
+                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                    ],
+                }
+                for lp in choice.logprobs.content
+            ]

         if choice.message.content is not None:
             items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
@@ -504,14 +526,21 @@ class OpenAIChatModel(Model):
                 assert_never(c)
             part.tool_call_id = _guard_tool_call_id(part)
             items.append(part)
+
+        finish_reason: FinishReason | None = None
+        if raw_finish_reason := choice.finish_reason:  # pragma: no branch
+            vendor_details['finish_reason'] = raw_finish_reason
+            finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
             model_name=response.model,
             timestamp=timestamp,
-            provider_details=vendor_details,
+            provider_details=vendor_details or None,
             provider_response_id=response.id,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
         )

     async def _process_streamed_response(
@@ -606,7 +635,7 @@ class OpenAIChatModel(Model):
     def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat:
         response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = {  # pyright: ignore[reportPrivateImportUsage]
             'type': 'json_schema',
-            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema…
+            'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema},
         }
         if o.description:
             response_format_param['json_schema']['description'] = o.description
@@ -820,6 +849,14 @@ class OpenAIResponsesModel(Model):
                 items.append(TextPart(content.text))
             elif item.type == 'function_call':
                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
+
+        finish_reason: FinishReason | None = None
+        provider_details: dict[str, Any] | None = None
+        raw_finish_reason = details.reason if (details := response.incomplete_details) else response.status
+        if raw_finish_reason:
+            provider_details = {'finish_reason': raw_finish_reason}
+            finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
+
         return ModelResponse(
             parts=items,
             usage=_map_usage(response),
@@ -827,6 +864,8 @@ class OpenAIResponsesModel(Model):
             provider_response_id=response.id,
             timestamp=timestamp,
             provider_name=self._provider.name,
+            finish_reason=finish_reason,
+            provider_details=provider_details,
         )

     async def _process_streamed_response(
@@ -1166,11 +1205,22 @@ class OpenAIStreamedResponse(StreamedResponse):
         async for chunk in self._response:
             self._usage += _map_usage(chunk)

+            if chunk.id and self.provider_response_id is None:
+                self.provider_response_id = chunk.id
+
             try:
                 choice = chunk.choices[0]
             except IndexError:
                 continue

+            # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
+            if choice.delta is None:  # pyright: ignore[reportUnnecessaryComparison]
+                continue
+
+            if raw_finish_reason := choice.finish_reason:
+                self.provider_details = {'finish_reason': raw_finish_reason}
+                self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
+
             # Handle the text part of the response
             content = choice.delta.content
             if content is not None:
@@ -1230,6 +1280,13 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
             if isinstance(chunk, responses.ResponseCompletedEvent):
                 self._usage += _map_usage(chunk.response)

+                raw_finish_reason = (
+                    details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
+                )
+                if raw_finish_reason:  # pragma: no branch
+                    self.provider_details = {'finish_reason': raw_finish_reason}
+                    self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
+
             elif isinstance(chunk, responses.ResponseContentPartAddedEvent):
                 pass  # there's nothing we need to do here

@@ -1237,7 +1294,8 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here

             elif isinstance(chunk, responses.ResponseCreatedEvent):
-                … (removed line not captured in this view)
+                if chunk.response.id:  # pragma: no branch
+                    self.provider_response_id = chunk.response.id

             elif isinstance(chunk, responses.ResponseFailedEvent):  # pragma: no cover
                 self._usage += _map_usage(chunk.response)
@@ -1270,12 +1328,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     tool_call_id=chunk.item.call_id,
                 )
             elif isinstance(chunk.item, responses.ResponseReasoningItem):
-                … (removed line not captured in this view)
-                yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item.id,
-                    content=content,
-                    signature=chunk.item.id,
-                )
+                pass
             elif isinstance(chunk.item, responses.ResponseOutputMessage):
                 pass
             elif isinstance(chunk.item, responses.ResponseFunctionWebSearch):
@@ -1291,7 +1344,11 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
-                … (removed line not captured in this view)
+                yield self._parts_manager.handle_thinking_delta(
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
+                    content=chunk.part.text,
+                    id=chunk.item_id,
+                )

             elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
                 pass  # there's nothing we need to do here
@@ -1301,9 +1358,9 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):

             elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
                 yield self._parts_manager.handle_thinking_delta(
-                    vendor_part_id=chunk.item_id,
+                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
                     content=chunk.delta,
-                    … (removed line not captured in this view)
+                    id=chunk.item_id,
                 )

             # TODO(Marcelo): We should support annotations in the future.
@@ -1311,9 +1368,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 pass  # there's nothing we need to do here

             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id=chunk.content_index, content=chunk.delta
-                )
+                maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=chunk.item_id, content=chunk.delta)
                 if maybe_event is not None:  # pragma: no branch
                     yield maybe_event
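The two module-level tables normalize provider-specific finish reasons onto pydantic_ai's `FinishReason` values, while the raw value is preserved verbatim in `provider_details['finish_reason']`. A rough sketch of the Chat Completions path; the dict mirrors `_CHAT_FINISH_REASON_MAP`, `raw` is an illustrative value, and the `FinishReason` alias here only lists the values these maps use:

```python
from typing import Literal

FinishReason = Literal['stop', 'length', 'content_filter', 'tool_call', 'error']

# Mirrors _CHAT_FINISH_REASON_MAP: both 'tool_calls' and the legacy 'function_call'
# collapse onto the single normalized 'tool_call' value.
chat_map: dict[str, FinishReason] = {
    'stop': 'stop',
    'length': 'length',
    'tool_calls': 'tool_call',
    'content_filter': 'content_filter',
    'function_call': 'tool_call',
}

raw = 'tool_calls'  # illustrative choice.finish_reason from a chat completion
provider_details = {'finish_reason': raw}  # raw value kept as-is
finish_reason = chat_map.get(raw)          # normalized: 'tool_call'
```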
pydantic_ai/providers/__init__.py
CHANGED

@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .github import GitHubProvider

         return GitHubProvider
+    elif provider == 'litellm':
+        from .litellm import LiteLLMProvider
+
+        return LiteLLMProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
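With the new branch in `infer_provider_class`, the provider string `'litellm'` resolves to `LiteLLMProvider`, and the `Literal` lists in `OpenAIChatModel` accept it alongside the existing names. A hedged usage sketch; the model name is illustrative, and any LiteLLM-specific configuration (API base, key, or an explicit `LiteLLMProvider` instance) is omitted here:

```python
from pydantic_ai.models.openai import OpenAIChatModel

# 'litellm' is now a valid provider string; it is resolved via infer_provider_class.
# The model name is illustrative - LiteLLM uses its own model naming scheme.
model = OpenAIChatModel('gpt-4o', provider='litellm')
```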
pydantic_ai/providers/bedrock.py
CHANGED
@@ -35,11 +35,19 @@ class BedrockModelProfile(ModelProfile):
     ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
     """

-    bedrock_supports_tool_choice: bool = …
+    bedrock_supports_tool_choice: bool = False
     bedrock_tool_result_format: Literal['text', 'json'] = 'text'
     bedrock_send_back_thinking_parts: bool = False


+def bedrock_amazon_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for an Amazon model used via Bedrock."""
+    profile = amazon_model_profile(model_name)
+    if 'nova' in model_name:
+        return BedrockModelProfile(bedrock_supports_tool_choice=True).update(profile)
+    return profile
+
+
 class BedrockProvider(Provider[BaseClient]):
     """Provider for AWS Bedrock."""

@@ -58,13 +66,13 @@ class BedrockProvider(Provider[BaseClient]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
             'anthropic': lambda model_name: BedrockModelProfile(
-                bedrock_supports_tool_choice=…
+                bedrock_supports_tool_choice=True, bedrock_send_back_thinking_parts=True
             ).update(anthropic_model_profile(model_name)),
             'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
                 mistral_model_profile(model_name)
             ),
             'cohere': cohere_model_profile,
-            'amazon': …
+            'amazon': bedrock_amazon_model_profile,
             'meta': meta_model_profile,
             'deepseek': deepseek_model_profile,
         }
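The new `bedrock_amazon_model_profile` only upgrades Amazon models whose name contains `'nova'` to advertise tool-choice support; every other Amazon model keeps the plain `amazon_model_profile` result. A small sketch (the Bedrock model IDs are illustrative):

```python
from pydantic_ai.providers.bedrock import bedrock_amazon_model_profile

# Illustrative Bedrock model IDs.
nova_profile = bedrock_amazon_model_profile('us.amazon.nova-pro-v1:0')
titan_profile = bedrock_amazon_model_profile('amazon.titan-text-express-v1')

# Only the Nova name matches the 'nova' substring check, so only nova_profile is wrapped
# in BedrockModelProfile(bedrock_supports_tool_choice=True); the Titan profile is returned unchanged.
```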
pydantic_ai/providers/google_vertex.py
CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations

 import functools
+from asyncio import Lock
 from collections.abc import AsyncGenerator, Mapping
 from pathlib import Path
 from typing import Literal, overload
@@ -118,7 +119,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
 class _VertexAIAuth(httpx.Auth):
     """Auth class for Vertex AI API."""

-    _refresh_lock: …
+    _refresh_lock: Lock = Lock()

     credentials: BaseCredentials | ServiceAccountCredentials | None
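The class-level `_refresh_lock: Lock = Lock()` gives `_VertexAIAuth` a single shared asyncio lock, so concurrent requests that hit an expired token serialize the credential refresh instead of racing. A generic sketch of that pattern (names and the refresh body are illustrative, not the provider's actual code):

```python
import asyncio


class TokenHolder:
    # Shared, class-level lock: every instance and every concurrent caller coordinates on it.
    _refresh_lock: asyncio.Lock = asyncio.Lock()
    _token: str | None = None

    async def get_token(self) -> str:
        if self._token is None:
            async with self._refresh_lock:
                # Re-check inside the lock so only the first waiter performs the refresh.
                if self._token is None:
                    self._token = await self._fetch_token()
        return self._token

    async def _fetch_token(self) -> str:
        return 'example-token'  # stand-in for the real credential refresh
```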
pydantic_ai/providers/groq.py
CHANGED
@@ -14,6 +14,7 @@ from pydantic_ai.profiles.groq import groq_model_profile
 from pydantic_ai.profiles.meta import meta_model_profile
 from pydantic_ai.profiles.mistral import mistral_model_profile
 from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
+from pydantic_ai.profiles.openai import openai_model_profile
 from pydantic_ai.profiles.qwen import qwen_model_profile
 from pydantic_ai.providers import Provider

@@ -26,6 +27,23 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error


+def groq_moonshotai_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for an MoonshotAI model used with the Groq provider."""
+    return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+        moonshotai_model_profile(model_name)
+    )
+
+
+def meta_groq_model_profile(model_name: str) -> ModelProfile | None:
+    """Get the model profile for a Meta model used with the Groq provider."""
+    if model_name in {'llama-4-maverick-17b-128e-instruct', 'llama-4-scout-17b-16e-instruct'}:
+        return ModelProfile(supports_json_object_output=True, supports_json_schema_output=True).update(
+            meta_model_profile(model_name)
+        )
+    else:
+        return meta_model_profile(model_name)
+
+
 class GroqProvider(Provider[AsyncGroq]):
     """Provider for Groq API."""

@@ -44,13 +62,14 @@ class GroqProvider(Provider[AsyncGroq]):
     def model_profile(self, model_name: str) -> ModelProfile | None:
         prefix_to_profile = {
             'llama': meta_model_profile,
-            'meta-llama/': …
+            'meta-llama/': meta_groq_model_profile,
             'gemma': google_model_profile,
             'qwen': qwen_model_profile,
             'deepseek': deepseek_model_profile,
             'mistral': mistral_model_profile,
-            'moonshotai/': …
+            'moonshotai/': groq_moonshotai_model_profile,
             'compound-': groq_model_profile,
+            'openai/': openai_model_profile,
         }

         for prefix, profile_func in prefix_to_profile.items():