pydantic-ai-slim 0.0.46__py3-none-any.whl → 0.0.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of pydantic-ai-slim might be problematic.

@@ -1,12 +1,15 @@
  from __future__ import annotations as _annotations

  import base64
+ import warnings
  from collections.abc import AsyncIterable, AsyncIterator
  from contextlib import asynccontextmanager
  from dataclasses import dataclass, field
  from datetime import datetime, timezone
  from typing import Literal, Union, cast, overload

+ from openai import NotGiven
+ from openai.types import Reasoning
  from typing_extensions import assert_never

  from pydantic_ai.providers import Provider, infer_provider
@@ -42,7 +45,7 @@ from . import (

  try:
      from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream
-     from openai.types import ChatModel, chat
+     from openai.types import ChatModel, chat, responses
      from openai.types.chat import (
          ChatCompletionChunk,
          ChatCompletionContentPartImageParam,
@@ -52,6 +55,9 @@ try:
      )
      from openai.types.chat.chat_completion_content_part_image_param import ImageURL
      from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
+     from openai.types.responses.response_input_param import FunctionCallOutput, Message
+     from openai.types.shared import ReasoningEffort
+     from openai.types.shared_params import Reasoning
  except ImportError as _import_error:
      raise ImportError(
          'Please install `openai` to use the OpenAI model, '
@@ -74,16 +80,20 @@ OpenAISystemPromptRole = Literal['system', 'developer', 'user']


  class OpenAIModelSettings(ModelSettings, total=False):
-     """Settings used for an OpenAI model request."""
+     """Settings used for an OpenAI model request.

-     openai_reasoning_effort: chat.ChatCompletionReasoningEffort
+     ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+     """
+
+     openai_reasoning_effort: ReasoningEffort
      """
      Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+
      Currently supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
      result in faster responses and fewer tokens used on reasoning in a response.
      """

-     user: str
+     openai_user: str
      """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

      See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
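
A minimal sketch of how the renamed settings would be passed, assuming the pydantic-ai `Agent` API and an OpenAI reasoning model (model name and prompt are placeholders):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModelSettings

agent = Agent('openai:o3-mini')

# both fields now carry the `openai_` prefix so they can be merged with
# settings intended for other providers
settings = OpenAIModelSettings(openai_reasoning_effort='low', openai_user='user-1234')
result = agent.run_sync('Summarise the latest release notes.', model_settings=settings)
print(result.data)
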
@@ -175,8 +185,7 @@ class OpenAIModel(Model):
          stream: Literal[True],
          model_settings: OpenAIModelSettings,
          model_request_parameters: ModelRequestParameters,
-     ) -> AsyncStream[ChatCompletionChunk]:
-         pass
+     ) -> AsyncStream[ChatCompletionChunk]: ...

      @overload
      async def _completions_create(
@@ -185,8 +194,7 @@ class OpenAIModel(Model):
          stream: Literal[False],
          model_settings: OpenAIModelSettings,
          model_request_parameters: ModelRequestParameters,
-     ) -> chat.ChatCompletion:
-         pass
+     ) -> chat.ChatCompletion: ...

      async def _completions_create(
          self,
@@ -229,7 +237,7 @@ class OpenAIModel(Model):
                  frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                  logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                  reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
-                 user=model_settings.get('user', NOT_GIVEN),
+                 user=model_settings.get('openai_user', NOT_GIVEN),
              )
          except APIStatusError as e:
              if (status_code := e.status_code) >= 400:
@@ -245,7 +253,7 @@ class OpenAIModel(Model):
              items.append(TextPart(choice.message.content))
          if choice.message.tool_calls is not None:
              for c in choice.message.tool_calls:
-                 items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
+                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
          return ModelResponse(items, model_name=response.model, timestamp=timestamp)

      async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
@@ -396,6 +404,311 @@ class OpenAIModel(Model):
          return chat.ChatCompletionUserMessageParam(role='user', content=content)


+ @dataclass(init=False)
+ class OpenAIResponsesModel(Model):
+     """A model that uses the OpenAI Responses API.
+
+     The [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) is the
+     new API for OpenAI models.
+
+     The Responses API has built-in tools, that you can use instead of building your own:
+
+     - [Web search](https://platform.openai.com/docs/guides/tools-web-search)
+     - [File search](https://platform.openai.com/docs/guides/tools-file-search)
+     - [Computer use](https://platform.openai.com/docs/guides/tools-computer-use)
+
+     If you are interested in the differences between the Responses API and the Chat Completions API,
+     see the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).
+     """
+
+     client: AsyncOpenAI = field(repr=False)
+     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+
+     _model_name: OpenAIModelName = field(repr=False)
+     _system: str = field(default='openai', repr=False)
+
+     def __init__(
+         self,
+         model_name: OpenAIModelName,
+         *,
+         provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
+     ):
+         """Initialize an OpenAI Responses model.
+
+         Args:
+             model_name: The name of the OpenAI model to use.
+             provider: The provider to use. Defaults to `'openai'`.
+         """
+         self._model_name = model_name
+         if isinstance(provider, str):
+             provider = infer_provider(provider)
+         self.client = provider.client
+
+     @property
+     def model_name(self) -> OpenAIModelName:
+         """The model name."""
+         return self._model_name
+
+     @property
+     def system(self) -> str:
+         """The system / model provider."""
+         return self._system
+
+     async def request(
+         self,
+         messages: list[ModelRequest | ModelResponse],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> tuple[ModelResponse, usage.Usage]:
+         check_allow_model_requests()
+         response = await self._responses_create(
+             messages, False, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+         )
+         return self._process_response(response), _map_usage(response)
+
+     @asynccontextmanager
+     async def request_stream(
+         self,
+         messages: list[ModelMessage],
+         model_settings: ModelSettings | None,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncIterator[StreamedResponse]:
+         check_allow_model_requests()
+         response = await self._responses_create(
+             messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters
+         )
+         async with response:
+             yield await self._process_streamed_response(response)
+
+     def _process_response(self, response: responses.Response) -> ModelResponse:
+         """Process a non-streamed response, and prepare a message to return."""
+         timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
+         items: list[ModelResponsePart] = []
+         items.append(TextPart(response.output_text))
+         for item in response.output:
+             if item.type == 'function_call':
+                 items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
+         return ModelResponse(items, model_name=response.model, timestamp=timestamp)
+
+     async def _process_streamed_response(
+         self, response: AsyncStream[responses.ResponseStreamEvent]
+     ) -> OpenAIResponsesStreamedResponse:
+         """Process a streamed response, and prepare a streaming response to return."""
+         peekable_response = _utils.PeekableAsyncStream(response)
+         first_chunk = await peekable_response.peek()
+         if isinstance(first_chunk, _utils.Unset):  # pragma: no cover
+             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+         assert isinstance(first_chunk, responses.ResponseCreatedEvent)
+         return OpenAIResponsesStreamedResponse(
+             _model_name=self._model_name,
+             _response=peekable_response,
+             _timestamp=datetime.fromtimestamp(first_chunk.response.created_at, tz=timezone.utc),
+         )
+
+     @overload
+     async def _responses_create(
+         self,
+         messages: list[ModelRequest | ModelResponse],
+         stream: Literal[False],
+         model_settings: OpenAIModelSettings,
+         model_request_parameters: ModelRequestParameters,
+     ) -> responses.Response: ...
+
+     @overload
+     async def _responses_create(
+         self,
+         messages: list[ModelRequest | ModelResponse],
+         stream: Literal[True],
+         model_settings: OpenAIModelSettings,
+         model_request_parameters: ModelRequestParameters,
+     ) -> AsyncStream[responses.ResponseStreamEvent]: ...
+
+     async def _responses_create(
+         self,
+         messages: list[ModelRequest | ModelResponse],
+         stream: bool,
+         model_settings: OpenAIModelSettings,
+         model_request_parameters: ModelRequestParameters,
+     ) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
+         tools = self._get_tools(model_request_parameters)
+
+         # standalone function to make it easier to override
+         if not tools:
+             tool_choice: Literal['none', 'required', 'auto'] | None = None
+         elif not model_request_parameters.allow_text_result:
+             tool_choice = 'required'
+         else:
+             tool_choice = 'auto'
+
+         system_prompt, openai_messages = await self._map_message(messages)
+
+         reasoning_effort = model_settings.get('openai_reasoning_effort', NOT_GIVEN)
+         if not isinstance(reasoning_effort, NotGiven):
+             reasoning = Reasoning(effort=reasoning_effort)
+         else:
+             reasoning = NOT_GIVEN
+
+         try:
+             return await self.client.responses.create(
+                 input=openai_messages,
+                 model=self._model_name,
+                 instructions=system_prompt,
+                 parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
+                 tools=tools or NOT_GIVEN,
+                 tool_choice=tool_choice or NOT_GIVEN,
+                 max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                 stream=stream,
+                 temperature=model_settings.get('temperature', NOT_GIVEN),
+                 top_p=model_settings.get('top_p', NOT_GIVEN),
+                 timeout=model_settings.get('timeout', NOT_GIVEN),
+                 reasoning=reasoning,
+                 user=model_settings.get('user', NOT_GIVEN),
+             )
+         except APIStatusError as e:
+             if (status_code := e.status_code) >= 400:
+                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
+             raise
+
+     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]:
+         tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+         if model_request_parameters.result_tools:
+             tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+         return tools
+
+     @staticmethod
+     def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
+         return {
+             'name': f.name,
+             'parameters': f.parameters_json_schema,
+             'type': 'function',
+             'description': f.description,
+             'strict': True,
+         }
+
+     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
+         """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`."""
+         system_prompt: str = ''
+         openai_messages: list[responses.ResponseInputItemParam] = []
+         for message in messages:
+             if isinstance(message, ModelRequest):
+                 for part in message.parts:
+                     if isinstance(part, SystemPromptPart):
+                         system_prompt += part.content
+                     elif isinstance(part, UserPromptPart):
+                         openai_messages.append(await self._map_user_prompt(part))
+                     elif isinstance(part, ToolReturnPart):
+                         openai_messages.append(
+                             FunctionCallOutput(
+                                 type='function_call_output',
+                                 call_id=_guard_tool_call_id(t=part),
+                                 output=part.model_response_str(),
+                             )
+                         )
+                     elif isinstance(part, RetryPromptPart):
+                         # TODO(Marcelo): How do we test this conditional branch?
+                         if part.tool_name is None:  # pragma: no cover
+                             openai_messages.append(
+                                 Message(role='user', content=[{'type': 'input_text', 'text': part.model_response()}])
+                             )
+                         else:
+                             openai_messages.append(
+                                 FunctionCallOutput(
+                                     type='function_call_output',
+                                     call_id=_guard_tool_call_id(t=part),
+                                     output=part.model_response(),
+                                 )
+                             )
+                     else:
+                         assert_never(part)
+             elif isinstance(message, ModelResponse):
+                 for item in message.parts:
+                     if isinstance(item, TextPart):
+                         openai_messages.append(responses.EasyInputMessageParam(role='assistant', content=item.content))
+                     elif isinstance(item, ToolCallPart):
+                         openai_messages.append(self._map_tool_call(item))
+                     else:
+                         assert_never(item)
+             else:
+                 assert_never(message)
+         return system_prompt, openai_messages
+
+     @staticmethod
+     def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
+         return responses.ResponseFunctionToolCallParam(
+             arguments=t.args_as_json_str(),
+             call_id=_guard_tool_call_id(t=t),
+             name=t.tool_name,
+             type='function_call',
+         )
+
+     @staticmethod
+     async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam:
+         content: str | list[responses.ResponseInputContentParam]
+         if isinstance(part.content, str):
+             content = part.content
+         else:
+             content = []
+             for item in part.content:
+                 if isinstance(item, str):
+                     content.append(responses.ResponseInputTextParam(text=item, type='input_text'))
+                 elif isinstance(item, BinaryContent):
+                     base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                     if item.is_image:
+                         content.append(
+                             responses.ResponseInputImageParam(
+                                 image_url=f'data:{item.media_type};base64,{base64_encoded}',
+                                 type='input_image',
+                                 detail='auto',
+                             )
+                         )
+                     elif item.is_document:
+                         content.append(
+                             responses.ResponseInputFileParam(
+                                 type='input_file',
+                                 file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                                 # NOTE: Type wise it's not necessary to include the filename, but it's required by the
+                                 # API itself. If we add empty string, the server sends a 500 error - which OpenAI needs
+                                 # to fix. In any case, we add a placeholder name.
+                                 filename=f'filename.{item.format}',
+                             )
+                         )
+                     elif item.is_audio:
+                         raise NotImplementedError('Audio as binary content is not supported for OpenAI Responses API.')
+                     else:  # pragma: no cover
+                         raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                 elif isinstance(item, ImageUrl):
+                     content.append(
+                         responses.ResponseInputImageParam(image_url=item.url, type='input_image', detail='auto')
+                     )
+                 elif isinstance(item, AudioUrl):  # pragma: no cover
+                     client = cached_async_http_client()
+                     response = await client.get(item.url)
+                     response.raise_for_status()
+                     base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                     content.append(
+                         responses.ResponseInputFileParam(
+                             type='input_file',
+                             file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                         )
+                     )
+                 elif isinstance(item, DocumentUrl):  # pragma: no cover
+                     client = cached_async_http_client()
+                     response = await client.get(item.url)
+                     response.raise_for_status()
+                     base64_encoded = base64.b64encode(response.content).decode('utf-8')
+                     content.append(
+                         responses.ResponseInputFileParam(
+                             type='input_file',
+                             file_data=f'data:{item.media_type};base64,{base64_encoded}',
+                             filename=f'filename.{item.format}',
+                         )
+                     )
+                 else:
+                     assert_never(item)
+         return responses.EasyInputMessageParam(role='user', content=content)
+
+
  @dataclass
  class OpenAIStreamedResponse(StreamedResponse):
      """Implementation of `StreamedResponse` for OpenAI models."""
@@ -439,10 +752,101 @@ class OpenAIStreamedResponse(StreamedResponse):
          return self._timestamp


- def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
+ @dataclass
+ class OpenAIResponsesStreamedResponse(StreamedResponse):
+     """Implementation of `StreamedResponse` for OpenAI Responses API."""
+
+     _model_name: OpenAIModelName
+     _response: AsyncIterable[responses.ResponseStreamEvent]
+     _timestamp: datetime
+
+     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:  # noqa: C901
+         async for chunk in self._response:
+             if isinstance(chunk, responses.ResponseCompletedEvent):
+                 self._usage += _map_usage(chunk.response)
+
+             elif isinstance(chunk, responses.ResponseContentPartAddedEvent):
+                 pass  # there's nothing we need to do here
+
+             elif isinstance(chunk, responses.ResponseContentPartDoneEvent):
+                 pass  # there's nothing we need to do here
+
+             elif isinstance(chunk, responses.ResponseCreatedEvent):
+                 pass  # there's nothing we need to do here
+
+             elif isinstance(chunk, responses.ResponseFailedEvent):  # pragma: no cover
+                 self._usage += _map_usage(chunk.response)
+
+             elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDeltaEvent):
+                 maybe_event = self._parts_manager.handle_tool_call_delta(
+                     vendor_part_id=chunk.item_id,
+                     tool_name=None,
+                     args=chunk.delta,
+                     tool_call_id=chunk.item_id,
+                 )
+                 if maybe_event is not None:
+                     yield maybe_event
+
+             elif isinstance(chunk, responses.ResponseFunctionCallArgumentsDoneEvent):
+                 pass  # there's nothing we need to do here
+
+             elif isinstance(chunk, responses.ResponseIncompleteEvent):  # pragma: no cover
+                 self._usage += _map_usage(chunk.response)
+
+             elif isinstance(chunk, responses.ResponseInProgressEvent):
+                 self._usage += _map_usage(chunk.response)
+
+             elif isinstance(chunk, responses.ResponseOutputItemAddedEvent):
+                 if isinstance(chunk.item, responses.ResponseFunctionToolCall):
+                     yield self._parts_manager.handle_tool_call_part(
+                         vendor_part_id=chunk.item.id,
+                         tool_name=chunk.item.name,
+                         args=chunk.item.arguments,
+                         tool_call_id=chunk.item.id,
+                     )
+
+             elif isinstance(chunk, responses.ResponseOutputItemDoneEvent):
+                 # NOTE: We only need this if the tool call deltas don't include the final info.
+                 pass
+
+             elif isinstance(chunk, responses.ResponseTextDeltaEvent):
+                 yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta)
+
+             elif isinstance(chunk, responses.ResponseTextDoneEvent):
+                 pass  # there's nothing we need to do here
+
+             else:  # pragma: no cover
+                 warnings.warn(
+                     f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
+                     UserWarning,
+                 )
+
+     @property
+     def model_name(self) -> OpenAIModelName:
+         """Get the model name of the response."""
+         return self._model_name
+
+     @property
+     def timestamp(self) -> datetime:
+         """Get the timestamp of the response."""
+         return self._timestamp
+
+
+ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.Usage:
      response_usage = response.usage
      if response_usage is None:
          return usage.Usage()
+     elif isinstance(response_usage, responses.ResponseUsage):
+         details: dict[str, int] = {}
+         return usage.Usage(
+             request_tokens=response_usage.input_tokens,
+             response_tokens=response_usage.output_tokens,
+             total_tokens=response_usage.total_tokens,
+             details={
+                 'reasoning_tokens': response_usage.output_tokens_details.reasoning_tokens,
+                 'cached_tokens': response_usage.input_tokens_details.cached_tokens,
+             },
+         )
      else:
          details: dict[str, int] = {}
          if response_usage.completion_tokens_details is not None:
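
A minimal sketch of how the new Responses model could be plugged into an agent, assuming the `OpenAIResponsesModel` constructor added above and the existing `Agent` API (model name and prompt are placeholders):

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel

# constructed like OpenAIModel, but requests go through client.responses.create
# instead of the Chat Completions endpoint
model = OpenAIResponsesModel('gpt-4o', provider='openai')
agent = Agent(model)

result = agent.run_sync('What are the three laws of robotics?')
print(result.data)
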
@@ -1,5 +1,6 @@
  from __future__ import annotations as _annotations

+ import os
  from typing import overload

  from pydantic_ai.exceptions import UserError
@@ -8,6 +9,7 @@ from pydantic_ai.providers import Provider
  try:
      import boto3
      from botocore.client import BaseClient
+     from botocore.config import Config
      from botocore.exceptions import NoRegionError
  except ImportError as _import_error:
      raise ImportError(
@@ -42,6 +44,8 @@ class BedrockProvider(Provider[BaseClient]):
          aws_access_key_id: str | None = None,
          aws_secret_access_key: str | None = None,
          aws_session_token: str | None = None,
+         aws_read_timeout: float | None = None,
+         aws_connect_timeout: float | None = None,
      ) -> None: ...

      def __init__(
@@ -52,6 +56,8 @@ class BedrockProvider(Provider[BaseClient]):
          aws_access_key_id: str | None = None,
          aws_secret_access_key: str | None = None,
          aws_session_token: str | None = None,
+         aws_read_timeout: float | None = None,
+         aws_connect_timeout: float | None = None,
      ) -> None:
          """Initialize the Bedrock provider.

@@ -61,17 +67,22 @@ class BedrockProvider(Provider[BaseClient]):
              aws_access_key_id: The AWS access key ID.
              aws_secret_access_key: The AWS secret access key.
              aws_session_token: The AWS session token.
+             aws_read_timeout: The read timeout for Bedrock client.
+             aws_connect_timeout: The connect timeout for Bedrock client.
          """
          if bedrock_client is not None:
              self._client = bedrock_client
          else:
              try:
+                 read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
+                 connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
                  self._client = boto3.client(  # type: ignore[reportUnknownMemberType]
                      'bedrock-runtime',
                      aws_access_key_id=aws_access_key_id,
                      aws_secret_access_key=aws_secret_access_key,
                      aws_session_token=aws_session_token,
                      region_name=region_name,
+                     config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
                  )
              except NoRegionError as exc:  # pragma: no cover
                  raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
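
A minimal sketch of the new timeout options, assuming `BedrockProvider` is importable from `pydantic_ai.providers.bedrock`; explicit arguments take precedence, otherwise the `AWS_READ_TIMEOUT` / `AWS_CONNECT_TIMEOUT` environment variables apply, with 300s / 60s defaults:

from pydantic_ai.providers.bedrock import BedrockProvider

# explicit timeouts override the AWS_READ_TIMEOUT / AWS_CONNECT_TIMEOUT env vars
provider = BedrockProvider(
    region_name='us-east-1',
    aws_read_timeout=120.0,
    aws_connect_timeout=30.0,
)

The provider can then be handed to a Bedrock model as usual.
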
pydantic_ai/tools.py CHANGED
@@ -2,10 +2,12 @@ from __future__ import annotations as _annotations

  import dataclasses
  import inspect
+ import json
  from collections.abc import Awaitable, Sequence
  from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast

+ from opentelemetry.trace import Tracer
  from pydantic import ValidationError
  from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
  from pydantic_core import SchemaValidator, core_schema
@@ -147,8 +149,8 @@ class GenerateToolJsonSchema(GenerateJsonSchema):
      def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
          s = super().typed_dict_schema(schema)
          total = schema.get('total')
-         if total is not None:
-             s['additionalProperties'] = not total
+         if 'additionalProperties' not in s and (total is True or total is None):
+             s['additionalProperties'] = False
          return s

      def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
@@ -286,9 +288,38 @@ class Tool(Generic[AgentDepsT]):
          return tool_def

      async def run(
+         self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer
+     ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
+         """Run the tool function asynchronously.
+
+         This method wraps `_run` in an OpenTelemetry span.
+
+         See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>.
+         """
+         span_attributes = {
+             'gen_ai.tool.name': self.name,
+             # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
+             'gen_ai.tool.call.id': message.tool_call_id,
+             'tool_arguments': message.args_as_json_str(),
+             'logfire.msg': f'running tool: {self.name}',
+             # add the JSON schema so these attributes are formatted nicely in Logfire
+             'logfire.json_schema': json.dumps(
+                 {
+                     'type': 'object',
+                     'properties': {
+                         'tool_arguments': {'type': 'object'},
+                         'gen_ai.tool.name': {},
+                         'gen_ai.tool.call.id': {},
+                     },
+                 }
+             ),
+         }
+         with tracer.start_as_current_span('running tool', attributes=span_attributes):
+             return await self._run(message, run_context)
+
+     async def _run(
          self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
      ) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
-         """Run the tool function asynchronously."""
          try:
              if isinstance(message.args, str):
                  args_dict = self._validator.validate_json(message.args)
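
An illustrative sketch of the span-wrapping pattern `Tool.run` now uses, written against the public OpenTelemetry API; `traced_tool_call` is a hypothetical helper, not part of the package:

import json
from opentelemetry import trace

tracer = trace.get_tracer('pydantic-ai-example')  # hypothetical tracer name

def traced_tool_call(tool_name: str, tool_call_id: str, args: dict) -> None:
    # mirrors the attributes set in Tool.run above (gen_ai.* semantic conventions)
    attributes = {
        'gen_ai.tool.name': tool_name,
        'gen_ai.tool.call.id': tool_call_id,
        'tool_arguments': json.dumps(args),
    }
    with tracer.start_as_current_span('running tool', attributes=attributes):
        ...  # run the actual tool function inside the span
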
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pydantic-ai-slim
- Version: 0.0.46
+ Version: 0.0.48
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
  Author-email: Samuel Colvin <samuel@pydantic.dev>
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
  Requires-Dist: griffe>=1.3.2
  Requires-Dist: httpx>=0.27
  Requires-Dist: opentelemetry-api>=1.28.0
- Requires-Dist: pydantic-graph==0.0.46
+ Requires-Dist: pydantic-graph==0.0.48
  Requires-Dist: pydantic>=2.10
  Requires-Dist: typing-inspection>=0.4.0
  Provides-Extra: anthropic
@@ -41,13 +41,15 @@ Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
  Requires-Dist: prompt-toolkit>=3; extra == 'cli'
  Requires-Dist: rich>=13; extra == 'cli'
  Provides-Extra: cohere
- Requires-Dist: cohere>=5.13.11; extra == 'cohere'
+ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
  Provides-Extra: duckduckgo
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
+ Provides-Extra: evals
+ Requires-Dist: pydantic-evals==0.0.48; extra == 'evals'
  Provides-Extra: groq
  Requires-Dist: groq>=0.15.0; extra == 'groq'
  Provides-Extra: logfire
- Requires-Dist: logfire>=2.3; extra == 'logfire'
+ Requires-Dist: logfire>=3.11.0; extra == 'logfire'
  Provides-Extra: mcp
  Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
  Provides-Extra: mistral