pydantic-ai-slim 0.0.46__py3-none-any.whl → 0.0.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of pydantic-ai-slim might be problematic.
- pydantic_ai/__main__.py +6 -0
- pydantic_ai/_agent_graph.py +19 -13
- pydantic_ai/_cli.py +120 -77
- pydantic_ai/_result.py +11 -4
- pydantic_ai/_utils.py +1 -1
- pydantic_ai/agent.py +30 -30
- pydantic_ai/messages.py +1 -1
- pydantic_ai/models/__init__.py +206 -193
- pydantic_ai/models/anthropic.py +4 -1
- pydantic_ai/models/bedrock.py +7 -0
- pydantic_ai/models/cohere.py +4 -1
- pydantic_ai/models/gemini.py +4 -1
- pydantic_ai/models/groq.py +32 -15
- pydantic_ai/models/instrumented.py +6 -1
- pydantic_ai/models/mistral.py +6 -1
- pydantic_ai/models/openai.py +415 -11
- pydantic_ai/providers/bedrock.py +11 -0
- pydantic_ai/tools.py +34 -3
- {pydantic_ai_slim-0.0.46.dist-info → pydantic_ai_slim-0.0.48.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.0.46.dist-info → pydantic_ai_slim-0.0.48.dist-info}/RECORD +22 -21
- {pydantic_ai_slim-0.0.46.dist-info → pydantic_ai_slim-0.0.48.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.46.dist-info → pydantic_ai_slim-0.0.48.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/openai.py
CHANGED
@@ -1,12 +1,15 @@
 from __future__ import annotations as _annotations
 
 import base64
+import warnings
 from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Literal, Union, cast, overload
 
+from openai import NotGiven
+from openai.types import Reasoning
 from typing_extensions import assert_never
 
 from pydantic_ai.providers import Provider, infer_provider
@@ -42,7 +45,7 @@ from . import (
 
 try:
     from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
-    from openai.types import ChatModel, chat
+    from openai.types import ChatModel, chat, responses
     from openai.types.chat import (
         ChatCompletionChunk,
         ChatCompletionContentPartImageParam,
@@ -52,6 +55,9 @@ try:
     )
     from openai.types.chat.chat_completion_content_part_image_param import ImageURL
     from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
+    from openai.types.responses.response_input_param import FunctionCallOutput, Message
+    from openai.types.shared import ReasoningEffort
+    from openai.types.shared_params import Reasoning
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenAI model, '
@@ -74,16 +80,20 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']
 
 
 class OpenAIModelSettings(ModelSettings, total=False):
-    """Settings used for an OpenAI model request.
+    """Settings used for an OpenAI model request.
 
-
+    ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    openai_reasoning_effort: ReasoningEffort
     """
     Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+
     Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
     result in faster responses and fewer tokens used on reasoning in a response.
     """
 
-
+    openai_user: str
     """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.
 
     See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
@@ -175,8 +185,7 @@ class OpenAIModel(Model):
         stream: Literal[True],
         model_settings: OpenAIModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> AsyncStream[ChatCompletionChunk]:
-        pass
+    ) -> AsyncStream[ChatCompletionChunk]: ...
 
     @overload
     async def _completions_create(
@@ -185,8 +194,7 @@ class OpenAIModel(Model):
         stream: Literal[False],
         model_settings: OpenAIModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> chat.ChatCompletion:
-        pass
+    ) -> chat.ChatCompletion: ...
 
     async def _completions_create(
         self,
@@ -229,7 +237,7 @@ class OpenAIModel(Model):
                 frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                 logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
-                user=model_settings.get('
+                user=model_settings.get('openai_user', NOT_GIVEN),
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
@@ -245,7 +253,7 @@ class OpenAIModel(Model):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
+                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=response.model, timestamp=timestamp)
 
     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
@@ -396,6 +404,311 @@ class OpenAIModel(Model):
         return chat.ChatCompletionUserMessageParam(role='user', content=content)
 
 
+@dataclass(init=False)
+class OpenAIResponsesModel(Model):
+    """A model that uses the OpenAI Responses API.
+
+    The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the
+    new API for OpenAI models.
+
+    The Responses API has built-in tools, that you can use instead of building your own:
+
+    - [Web search](https://platform.openai.com/docs/guides/tools-web-search)
+    - [File search](https://platform.openai.com/docs/guides/tools-file-search)
+    - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use)
+
+    If you are interested in the differences between the Responses API and the Chat Completions API,
+    see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).
+    """
+
+    client: AsyncOpenAI = field(repr=False)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+
+    _model_name: OpenAIModelName = field(repr=False)
+    _system: str = field(default='openai', repr=False)
+
+    def __init__(
+        self,
+        model_name: OpenAIModelName,
+        *,
+        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
+    ):
+        """Initialize an OpenAI Responses model.
+
+        Args:
+            model_name: The name of the OpenAI model to use.
+            provider: The provider to use. Defaults to `'openai'`.
+        """
+        self._model_name = model_name
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
+
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._system
+
+    async def request(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
+        check_allow_model_requests()
+        response = await self._responses_create(
+            messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
+        return self._process_response(response), _map_usage(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        check_allow_model_requests()
+        response = await self._responses_create(
+            messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+        )
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    def _process_response(self, response: responses.Response) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
+        items: list[ModelResponsePart] = []
+        items.append(TextPart(response.output_text))
+        for item in response.output:
+            if item.type == 'function_call':
+                items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
+        return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+
+    async def _process_streamed_response(
+        self, response: AsyncStream[responses.ResponseStreamEvent]
+    ) -> OpenAIResponsesStreamedResponse:
+        """Process a streamed response, and prepare a streaming response to return."""
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):  # pragma: no cover
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        assert isinstance(first_chunk, responses.ResponseCreatedEvent)
+        return OpenAIResponsesStreamedResponse(
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.response.created_at, tz=timezone.utc),
+        )
+
+    @overload
+    async def _responses_create(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        stream: Literal[False],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> responses.Response: ...
+
+    @overload
+    async def _responses_create(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        stream: Literal[True],
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncStream[responses.ResponseStreamEvent]: ...
+
+    async def _responses_create(
+        self,
+        messages: list[ModelRequest | ModelResponse],
+        stream: bool,
+        model_settings: OpenAIModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
+        tools = self._get_tools(model_request_parameters)
+
+        # standalone function to make it easier to override
+        if not tools:
+            tool_choice: Literal['none', 'required', 'auto'] | None = None
+        elif not model_request_parameters.allow_text_result:
+            tool_choice = 'required'
+        else:
+            tool_choice = 'auto'
+
+        system_prompt, openai_messages = await self._map_message(messages)
+
+        reasoning_effort = model_settings.get('openai_reasoning_effort', NOT_GIVEN)
+        if not isinstance(reasoning_effort, NotGiven):
+            reasoning = Reasoning(effort=reasoning_effort)
+        else:
+            reasoning = NOT_GIVEN
+
+        try:
+            return await self.client.responses.create(
+                input=openai_messages,
+                model=self._model_name,
+                instructions=system_prompt,
+                parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                tools=tools or NOT_GIVEN,
+                tool_choice=tool_choice or NOT_GIVEN,
+                max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                stream=stream,
+                temperature=model_settings.get('temperature', NOT_GIVEN),
+                top_p=model_settings.get('top_p', NOT_GIVEN),
+                timeout=model_settings.get('timeout', NOT_GIVEN),
+                reasoning=reasoning,
+                user=model_settings.get('user', NOT_GIVEN),
+            )
+        except APIStatusError as e:
+            if (status_code := e.status_code) >= 400:
+                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+            raise
+
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
+        return {
+            'name': f.name,
+            'parameters': f.parameters_json_schema,
+            'type': 'function',
+            'description': f.description,
+            'strict': True,
+        }
+
+    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
+        """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
+        system_prompt: str = ''
+        openai_messages: list[responses.ResponseInputItemParam] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        openai_messages.append(await self._map_user_prompt(part))
+                    elif isinstance(part, ToolReturnPart):
+                        openai_messages.append(
+                            FunctionCallOutput(
+                                type='function_call_output',
+                                call_id=_guard_tool_call_id(t=part),
+                                output=part.model_response_str(),
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        # TODO(Marcelo): How do we test this conditional branch?
+                        if part.tool_name is None:  # pragma: no cover
+                            openai_messages.append(
+                                Message(role='user', content=[{'type': 'input_text', 'text': part.model_response()}])
+                            )
+                        else:
+                            openai_messages.append(
+                                FunctionCallOutput(
+                                    type='function_call_output',
+                                    call_id=_guard_tool_call_id(t=part),
+                                    output=part.model_response(),
+                                )
+                            )
+                    else:
+                        assert_never(part)
+            elif isinstance(message, ModelResponse):
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content))
+                    elif isinstance(item, ToolCallPart):
+                        openai_messages.append(self._map_tool_call(item))
+                    else:
+                        assert_never(item)
+            else:
+                assert_never(message)
+        return system_prompt, openai_messages
+
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
+        return responses.ResponseFunctionToolCallParam(
+            arguments=t.args_as_json_str(),
+            call_id=_guard_tool_call_id(t=t),
+            name=t.tool_name,
+            type='function_call',
+        )
+
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
+        content: str | list[responses.ResponseInputContentParam]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(responses.ResponseInputTextParam(text=item, type='input_text'))
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        content.append(
+                            responses.ResponseInputImageParam(
+                                image_url=f'data:{item.media_type};base64,{base64_encoded}',
+                                type='input_image',
+                                detail='auto',
+                            )
+                        )
+                    elif item.is_document:
+                        content.append(
+                            responses.ResponseInputFileParam(
+                                type='input_file',
+                                file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                                # NOTE: Type wise it's not necessary to include the filename, but it's required by the
+                                # API itself. If we add empty string, the server sends a 500 error - which OpenAI needs
+                                # to fix. In any case, we add a placeholder name.
+                                filename=f'filename.{item.format}',
+                            )
+                        )
+                    elif item.is_audio:
+                        raise NotImplementedError('Audio as binary content is not supported for OpenAI Responses API.')
+                    else:  # pragma: no cover
+                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                elif isinstance(item, ImageUrl):
+                    content.append(
+                        responses.ResponseInputImageParam(image_url=item.url, type='input_image', detail='auto')
+                    )
+                elif isinstance(item, AudioUrl):  # pragma: no cover
+                    client = cached_async_http_client()
+                    response = await client.get(item.url)
+                    response.raise_for_status()
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    content.append(
+                        responses.ResponseInputFileParam(
+                            type='input_file',
+                            file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                        )
+                    )
+                elif isinstance(item, DocumentUrl):  # pragma: no cover
+                    client = cached_async_http_client()
+                    response = await client.get(item.url)
+                    response.raise_for_status()
+                    base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                    content.append(
+                        responses.ResponseInputFileParam(
+                            type='input_file',
+                            file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                            filename=f'filename.{item.format}',
+                        )
+                    )
+                else:
+                    assert_never(item)
+        return responses.EasyInputMessageParam(role='user', content=content)
+
+
 @dataclass
 class OpenAIStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for OpenAI models."""
@@ -439,10 +752,101 @@ class OpenAIStreamedResponse(StreamedResponse):
         return self._timestamp
 
 
-def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
+@dataclass
+class OpenAIResponsesStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for OpenAI Responses API."""
+
+    _model_name: OpenAIModelName
+    _response: AsyncIterable[responses.ResponseStreamEvent]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
+        async for chunk in self._response:
+            if isinstance(chunk, responses.ResponseCompletedEvent):
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseContentPartAddedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseContentPartDoneEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseCreatedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseFailedEvent):  # pragma: no cover
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDeltaEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=chunk.item_id,
+                    tool_name=None,
+                    args=chunk.delta,
+                    tool_call_id=chunk.item_id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDoneEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseIncompleteEvent):  # pragma: no cover
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseInProgressEvent):
+                self._usage += _map_usage(chunk.response)
+
+            elif isinstance(chunk, responses.ResponseOutputItemAddedEvent):
+                if isinstance(chunk.item, responses.ResponseFunctionToolCall):
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id=chunk.item.id,
+                        tool_name=chunk.item.name,
+                        args=chunk.item.arguments,
+                        tool_call_id=chunk.item.id,
+                    )
+
+            elif isinstance(chunk, responses.ResponseOutputItemDoneEvent):
+                # NOTE: We only need this if the tool call deltas don't include the final info.
+                pass
+
+            elif isinstance(chunk, responses.ResponseTextDeltaEvent):
+                yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta)
+
+            elif isinstance(chunk, responses.ResponseTextDoneEvent):
+                pass  # there's nothing we need to do here
+
+            else:  # pragma: no cover
+                warnings.warn(
+                    f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
+                    UserWarning,
+                )
+
+    @property
+    def model_name(self) -> OpenAIModelName:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._timestamp
+
+
+def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage:
     response_usage = response.usage
     if response_usage is None:
         return usage.Usage()
+    elif isinstance(response_usage, responses.ResponseUsage):
+        details: dict[str, int] = {}
+        return usage.Usage(
+            request_tokens=response_usage.input_tokens,
+            response_tokens=response_usage.output_tokens,
+            total_tokens=response_usage.total_tokens,
+            details={
+                'reasoning_tokens': response_usage.output_tokens_details.reasoning_tokens,
+                'cached_tokens': response_usage.input_tokens_details.cached_tokens,
+            },
+        )
     else:
         details: dict[str, int] = {}
         if response_usage.completion_tokens_details is not None:
pydantic_ai/providers/bedrock.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations
 
+import os
 from typing import overload
 
 from pydantic_ai.exceptions import UserError
@@ -8,6 +9,7 @@ from pydantic_ai.providers import Provider
 try:
     import boto3
     from botocore.client import BaseClient
+    from botocore.config import Config
    from botocore.exceptions import NoRegionError
 except ImportError as _import_error:
     raise ImportError(
@@ -42,6 +44,8 @@ class BedrockProvider(Provider[BaseClient]):
         aws_access_key_id: str | None = None,
         aws_secret_access_key: str | None = None,
         aws_session_token: str | None = None,
+        aws_read_timeout: float | None = None,
+        aws_connect_timeout: float | None = None,
     ) -> None: ...
 
     def __init__(
@@ -52,6 +56,8 @@ class BedrockProvider(Provider[BaseClient]):
         aws_access_key_id: str | None = None,
         aws_secret_access_key: str | None = None,
         aws_session_token: str | None = None,
+        aws_read_timeout: float | None = None,
+        aws_connect_timeout: float | None = None,
     ) -> None:
         """Initialize the Bedrock provider.
 
@@ -61,17 +67,22 @@ class BedrockProvider(Provider[BaseClient]):
             aws_access_key_id: The AWS access key ID.
             aws_secret_access_key: The AWS secret access key.
             aws_session_token: The AWS session token.
+            aws_read_timeout: The read timeout for Bedrock client.
+            aws_connect_timeout: The connect timeout for Bedrock client.
         """
         if bedrock_client is not None:
             self._client = bedrock_client
         else:
             try:
+                read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
+                connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
                 self._client = boto3.client(  # type: ignore[reportUnknownMemberType]
                     'bedrock-runtime',
                     aws_access_key_id=aws_access_key_id,
                     aws_secret_access_key=aws_secret_access_key,
                     aws_session_token=aws_session_token,
                     region_name=region_name,
+                    config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
                 )
             except NoRegionError as exc:  # pragma: no cover
                 raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
pydantic_ai/tools.py
CHANGED
@@ -2,10 +2,12 @@ from __future__ import annotations as _annotations
 
 import dataclasses
 import inspect
+import json
 from collections.abc import Awaitable, Sequence
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
+from opentelemetry.trace import Tracer
 from pydantic import ValidationError
 from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
 from pydantic_core import SchemaValidator, core_schema
@@ -147,8 +149,8 @@ class GenerateToolJsonSchema(GenerateJsonSchema):
     def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
         s = super().typed_dict_schema(schema)
         total = schema.get('total')
-        if total is
-            s['additionalProperties'] =
+        if 'additionalProperties' not in s and (total is True or total is None):
+            s['additionalProperties'] = False
         return s
 
     def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
@@ -286,9 +288,38 @@ class Tool(Generic[AgentDepsT]):
         return tool_def
 
     async def run(
+        self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer
+    ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
+        """Run the tool function asynchronously.
+
+        This method wraps `_run` in an OpenTelemetry span.
+
+        See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
+        """
+        span_attributes = {
+            'gen_ai.tool.name': self.name,
+            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
+            'gen_ai.tool.call.id': message.tool_call_id,
+            'tool_arguments': message.args_as_json_str(),
+            'logfire.msg': f'running tool: {self.name}',
+            # add the JSON schema so these attributes are formatted nicely in Logfire
+            'logfire.json_schema': json.dumps(
+                {
+                    'type': 'object',
+                    'properties': {
+                        'tool_arguments': {'type': 'object'},
+                        'gen_ai.tool.name': {},
+                        'gen_ai.tool.call.id': {},
+                    },
+                }
+            ),
+        }
+        with tracer.start_as_current_span('running tool', attributes=span_attributes):
+            return await self._run(message, run_context)
+
+    async def _run(
         self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
     ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
-        """Run the tool function asynchronously."""
         try:
             if isinstance(message.args, str):
                 args_dict = self._validator.validate_json(message.args)
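`Tool.run` now wraps tool execution in an OpenTelemetry span following the GenAI "execute tool" semantic conventions. The sketch below reproduces the same span-wrapping pattern using only `opentelemetry-api`; the tool name, call id, and arguments are made-up example values, not part of this release:

import json

from opentelemetry import trace

tracer = trace.get_tracer('example')

# Attributes mirroring those set by the new Tool.run wrapper.
span_attributes = {
    'gen_ai.tool.name': 'get_weather',         # hypothetical tool name
    'gen_ai.tool.call.id': 'call_123',         # hypothetical tool call id
    'tool_arguments': json.dumps({'city': 'London'}),
    'logfire.msg': 'running tool: get_weather',
}
with tracer.start_as_current_span('running tool', attributes=span_attributes):
    pass  # the tool function itself would run here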
{pydantic_ai_slim-0.0.46.dist-info → pydantic_ai_slim-0.0.48.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.46
+Version: 0.0.48
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.46
+Requires-Dist: pydantic-graph==0.0.48
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -41,13 +41,15 @@ Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
 Requires-Dist: prompt-toolkit>=3; extra == 'cli'
 Requires-Dist: rich>=13; extra == 'cli'
 Provides-Extra: cohere
-Requires-Dist: cohere>=5.13.11; extra == 'cohere'
+Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
+Provides-Extra: evals
+Requires-Dist: pydantic-evals==0.0.48; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire
-Requires-Dist: logfire>=
+Requires-Dist: logfire>=3.11.0; extra == 'logfire'
 Provides-Extra: mcp
 Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
 Provides-Extra: mistral
|