pydantic-ai-slim 0.8.1__py3-none-any.whl → 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (70)
  1. pydantic_ai/__init__.py +28 -2
  2. pydantic_ai/_agent_graph.py +310 -140
  3. pydantic_ai/_function_schema.py +5 -5
  4. pydantic_ai/_griffe.py +2 -1
  5. pydantic_ai/_otel_messages.py +2 -2
  6. pydantic_ai/_output.py +31 -35
  7. pydantic_ai/_parts_manager.py +4 -4
  8. pydantic_ai/_run_context.py +3 -1
  9. pydantic_ai/_system_prompt.py +2 -2
  10. pydantic_ai/_tool_manager.py +3 -22
  11. pydantic_ai/_utils.py +14 -26
  12. pydantic_ai/ag_ui.py +7 -8
  13. pydantic_ai/agent/__init__.py +70 -9
  14. pydantic_ai/agent/abstract.py +35 -4
  15. pydantic_ai/agent/wrapper.py +6 -0
  16. pydantic_ai/builtin_tools.py +2 -2
  17. pydantic_ai/common_tools/duckduckgo.py +4 -2
  18. pydantic_ai/durable_exec/temporal/__init__.py +4 -2
  19. pydantic_ai/durable_exec/temporal/_agent.py +23 -2
  20. pydantic_ai/durable_exec/temporal/_function_toolset.py +53 -6
  21. pydantic_ai/durable_exec/temporal/_logfire.py +1 -1
  22. pydantic_ai/durable_exec/temporal/_mcp_server.py +2 -1
  23. pydantic_ai/durable_exec/temporal/_model.py +2 -2
  24. pydantic_ai/durable_exec/temporal/_run_context.py +2 -1
  25. pydantic_ai/durable_exec/temporal/_toolset.py +2 -1
  26. pydantic_ai/exceptions.py +45 -2
  27. pydantic_ai/format_prompt.py +2 -2
  28. pydantic_ai/mcp.py +2 -2
  29. pydantic_ai/messages.py +73 -25
  30. pydantic_ai/models/__init__.py +5 -4
  31. pydantic_ai/models/anthropic.py +5 -5
  32. pydantic_ai/models/bedrock.py +58 -56
  33. pydantic_ai/models/cohere.py +3 -3
  34. pydantic_ai/models/fallback.py +2 -2
  35. pydantic_ai/models/function.py +25 -23
  36. pydantic_ai/models/gemini.py +9 -12
  37. pydantic_ai/models/google.py +3 -3
  38. pydantic_ai/models/groq.py +4 -4
  39. pydantic_ai/models/huggingface.py +4 -4
  40. pydantic_ai/models/instrumented.py +30 -16
  41. pydantic_ai/models/mcp_sampling.py +3 -1
  42. pydantic_ai/models/mistral.py +6 -6
  43. pydantic_ai/models/openai.py +18 -27
  44. pydantic_ai/models/test.py +24 -4
  45. pydantic_ai/output.py +27 -32
  46. pydantic_ai/profiles/__init__.py +3 -3
  47. pydantic_ai/profiles/groq.py +1 -1
  48. pydantic_ai/profiles/openai.py +25 -4
  49. pydantic_ai/providers/anthropic.py +2 -3
  50. pydantic_ai/providers/bedrock.py +3 -2
  51. pydantic_ai/result.py +144 -41
  52. pydantic_ai/retries.py +10 -29
  53. pydantic_ai/run.py +12 -5
  54. pydantic_ai/tools.py +126 -22
  55. pydantic_ai/toolsets/__init__.py +4 -1
  56. pydantic_ai/toolsets/_dynamic.py +4 -4
  57. pydantic_ai/toolsets/abstract.py +18 -2
  58. pydantic_ai/toolsets/approval_required.py +32 -0
  59. pydantic_ai/toolsets/combined.py +7 -12
  60. pydantic_ai/toolsets/{deferred.py → external.py} +11 -5
  61. pydantic_ai/toolsets/filtered.py +1 -1
  62. pydantic_ai/toolsets/function.py +13 -4
  63. pydantic_ai/toolsets/wrapper.py +2 -1
  64. pydantic_ai/usage.py +7 -5
  65. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/METADATA +5 -6
  66. pydantic_ai_slim-1.0.0b1.dist-info/RECORD +120 -0
  67. pydantic_ai_slim-0.8.1.dist-info/RECORD +0 -119
  68. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/WHEEL +0 -0
  69. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/entry_points.txt +0 -0
  70. {pydantic_ai_slim-0.8.1.dist-info → pydantic_ai_slim-1.0.0b1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/messages.py CHANGED
@@ -3,22 +3,19 @@ from __future__ import annotations as _annotations
3
3
  import base64
4
4
  from abc import ABC, abstractmethod
5
5
  from collections.abc import Sequence
6
- from dataclasses import dataclass, field, replace
6
+ from dataclasses import KW_ONLY, dataclass, field, replace
7
7
  from datetime import datetime
8
8
  from mimetypes import guess_type
9
- from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, cast, overload
9
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload
10
10
 
11
11
  import pydantic
12
12
  import pydantic_core
13
13
  from genai_prices import calc_price, types as genai_types
14
14
  from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
15
- from typing_extensions import TypeAlias, deprecated
15
+ from typing_extensions import deprecated
16
16
 
17
17
  from . import _otel_messages, _utils
18
- from ._utils import (
19
- generate_tool_call_id as _generate_tool_call_id,
20
- now_utc as _now_utc,
21
- )
18
+ from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
22
19
  from .exceptions import UnexpectedModelBehavior
23
20
  from .usage import RequestUsage
24
21
 
@@ -65,6 +62,8 @@ class SystemPromptPart:
65
62
  content: str
66
63
  """The content of the prompt."""
67
64
 
65
+ _: KW_ONLY
66
+
68
67
  timestamp: datetime = field(default_factory=_now_utc)
69
68
  """The timestamp of the prompt."""
70
69
 
@@ -96,6 +95,8 @@ class FileUrl(ABC):
96
95
  url: str
97
96
  """The URL of the file."""
98
97
 
98
+ _: KW_ONLY
99
+
99
100
  force_download: bool = False
100
101
  """If the model supports it:
101
102
 
@@ -153,6 +154,8 @@ class VideoUrl(FileUrl):
153
154
  url: str
154
155
  """The URL of the video."""
155
156
 
157
+ _: KW_ONLY
158
+
156
159
  kind: Literal['video-url'] = 'video-url'
157
160
  """Type identifier, this is available on all parts as a discriminator."""
158
161
 
@@ -224,6 +227,8 @@ class AudioUrl(FileUrl):
224
227
  url: str
225
228
  """The URL of the audio file."""
226
229
 
230
+ _: KW_ONLY
231
+
227
232
  kind: Literal['audio-url'] = 'audio-url'
228
233
  """Type identifier, this is available on all parts as a discriminator."""
229
234
 
@@ -282,6 +287,8 @@ class ImageUrl(FileUrl):
282
287
  url: str
283
288
  """The URL of the image."""
284
289
 
290
+ _: KW_ONLY
291
+
285
292
  kind: Literal['image-url'] = 'image-url'
286
293
  """Type identifier, this is available on all parts as a discriminator."""
287
294
 
@@ -335,6 +342,8 @@ class DocumentUrl(FileUrl):
335
342
  url: str
336
343
  """The URL of the document."""
337
344
 
345
+ _: KW_ONLY
346
+
338
347
  kind: Literal['document-url'] = 'document-url'
339
348
  """Type identifier, this is available on all parts as a discriminator."""
340
349
 
@@ -406,6 +415,8 @@ class BinaryContent:
406
415
  media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
407
416
  """The media type of the binary data."""
408
417
 
418
+ _: KW_ONLY
419
+
409
420
  identifier: str | None = None
410
421
  """Identifier for the binary content, such as a URL or unique ID.
411
422
 
@@ -462,7 +473,8 @@ class BinaryContent:
462
473
  __repr__ = _utils.dataclasses_no_defaults_repr
463
474
 
464
475
 
465
- UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent'
476
+ MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent
477
+ UserContent: TypeAlias = str | MultiModalContent
466
478
 
467
479
 
468
480
  @dataclass(repr=False)
@@ -478,17 +490,19 @@ class ToolReturn:
478
490
  return_value: Any
479
491
  """The return value to be used in the tool response."""
480
492
 
493
+ _: KW_ONLY
494
+
481
495
  content: str | Sequence[UserContent] | None = None
482
496
  """The content to be sent to the model as a UserPromptPart."""
483
497
 
484
498
  metadata: Any = None
485
499
  """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""
486
500
 
501
+ kind: Literal['tool-return'] = 'tool-return'
502
+
487
503
  __repr__ = _utils.dataclasses_no_defaults_repr
488
504
 
489
505
 
490
- # Ideally this would be a Union of types, but Python 3.9 requires it to be a string, and strings don't work with `isinstance``.
491
- MultiModalContentTypes = (ImageUrl, AudioUrl, DocumentUrl, VideoUrl, BinaryContent)
492
506
  _document_format_lookup: dict[str, DocumentFormat] = {
493
507
  'application/pdf': 'pdf',
494
508
  'text/plain': 'txt',
@@ -536,6 +550,8 @@ class UserPromptPart:
536
550
  content: str | Sequence[UserContent]
537
551
  """The content of the prompt."""
538
552
 
553
+ _: KW_ONLY
554
+
539
555
  timestamp: datetime = field(default_factory=_now_utc)
540
556
  """The timestamp of the prompt."""
541
557
 
@@ -562,7 +578,7 @@ class UserPromptPart:
562
578
  parts.append(
563
579
  _otel_messages.TextPart(type='text', **({'content': part} if settings.include_content else {}))
564
580
  )
565
- elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)):
581
+ elif isinstance(part, ImageUrl | AudioUrl | DocumentUrl | VideoUrl):
566
582
  parts.append(
567
583
  _otel_messages.MediaUrlPart(
568
584
  type=part.kind,
@@ -599,6 +615,8 @@ class BaseToolReturnPart:
599
615
  tool_call_id: str
600
616
  """The tool call identifier, this is used by some models including OpenAI."""
601
617
 
618
+ _: KW_ONLY
619
+
602
620
  metadata: Any = None
603
621
  """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""
604
622
 
@@ -654,6 +672,8 @@ class BaseToolReturnPart:
654
672
  class ToolReturnPart(BaseToolReturnPart):
655
673
  """A tool return message, this encodes the result of running a tool."""
656
674
 
675
+ _: KW_ONLY
676
+
657
677
  part_kind: Literal['tool-return'] = 'tool-return'
658
678
  """Part type identifier, this is available on all parts as a discriminator."""
659
679
 
@@ -662,6 +682,8 @@ class ToolReturnPart(BaseToolReturnPart):
662
682
  class BuiltinToolReturnPart(BaseToolReturnPart):
663
683
  """A tool return message from a built-in tool."""
664
684
 
685
+ _: KW_ONLY
686
+
665
687
  provider_name: str | None = None
666
688
  """The name of the provider that generated the response."""
667
689
 
@@ -695,6 +717,8 @@ class RetryPromptPart:
695
717
  error details.
696
718
  """
697
719
 
720
+ _: KW_ONLY
721
+
698
722
  tool_name: str | None = None
699
723
  """The name of the tool that was called, if any."""
700
724
 
@@ -753,7 +777,7 @@ class RetryPromptPart:
753
777
 
754
778
 
755
779
  ModelRequestPart = Annotated[
756
- Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
780
+ SystemPromptPart | UserPromptPart | ToolReturnPart | RetryPromptPart, pydantic.Discriminator('part_kind')
757
781
  ]
758
782
  """A message part sent by Pydantic AI to a model."""
759
783
 
@@ -765,6 +789,8 @@ class ModelRequest:
765
789
  parts: list[ModelRequestPart]
766
790
  """The parts of the user message."""
767
791
 
792
+ _: KW_ONLY
793
+
768
794
  instructions: str | None = None
769
795
  """The instructions for the model."""
770
796
 
@@ -786,6 +812,8 @@ class TextPart:
786
812
  content: str
787
813
  """The text content of the response."""
788
814
 
815
+ _: KW_ONLY
816
+
789
817
  part_kind: Literal['text'] = 'text'
790
818
  """Part type identifier, this is available on all parts as a discriminator."""
791
819
 
@@ -803,6 +831,8 @@ class ThinkingPart:
803
831
  content: str
804
832
  """The thinking content of the response."""
805
833
 
834
+ _: KW_ONLY
835
+
806
836
  id: str | None = None
807
837
  """The identifier of the thinking part."""
808
838
 
@@ -881,6 +911,8 @@ class BaseToolCallPart:
881
911
  class ToolCallPart(BaseToolCallPart):
882
912
  """A tool call from a model."""
883
913
 
914
+ _: KW_ONLY
915
+
884
916
  part_kind: Literal['tool-call'] = 'tool-call'
885
917
  """Part type identifier, this is available on all parts as a discriminator."""
886
918
 
@@ -889,6 +921,8 @@ class ToolCallPart(BaseToolCallPart):
889
921
  class BuiltinToolCallPart(BaseToolCallPart):
890
922
  """A tool call to a built-in tool."""
891
923
 
924
+ _: KW_ONLY
925
+
892
926
  provider_name: str | None = None
893
927
  """The name of the provider that generated the response."""
894
928
 
@@ -897,7 +931,7 @@ class BuiltinToolCallPart(BaseToolCallPart):
897
931
 
898
932
 
899
933
  ModelResponsePart = Annotated[
900
- Union[TextPart, ToolCallPart, BuiltinToolCallPart, BuiltinToolReturnPart, ThinkingPart],
934
+ TextPart | ToolCallPart | BuiltinToolCallPart | BuiltinToolReturnPart | ThinkingPart,
901
935
  pydantic.Discriminator('part_kind'),
902
936
  ]
903
937
  """A message part returned by a model."""
@@ -910,6 +944,8 @@ class ModelResponse:
910
944
  parts: list[ModelResponsePart]
911
945
  """The parts of the model message."""
912
946
 
947
+ _: KW_ONLY
948
+
913
949
  usage: RequestUsage = field(default_factory=RequestUsage)
914
950
  """Usage information for the request.
915
951
 
@@ -970,14 +1006,14 @@ class ModelResponse:
970
1006
  body.setdefault('tool_calls', []).append(
971
1007
  {
972
1008
  'id': part.tool_call_id,
973
- 'type': 'function', # TODO https://github.com/pydantic/pydantic-ai/issues/888
1009
+ 'type': 'function',
974
1010
  'function': {
975
1011
  'name': part.tool_name,
976
1012
  **({'arguments': part.args} if settings.include_content else {}),
977
1013
  },
978
1014
  }
979
1015
  )
980
- elif isinstance(part, (TextPart, ThinkingPart)):
1016
+ elif isinstance(part, TextPart | ThinkingPart):
981
1017
  kind = part.part_kind
982
1018
  body.setdefault('content', []).append(
983
1019
  {'kind': kind, **({'text': part.content} if settings.include_content else {})}
@@ -1038,7 +1074,7 @@ class ModelResponse:
1038
1074
  __repr__ = _utils.dataclasses_no_defaults_repr
1039
1075
 
1040
1076
 
1041
- ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
1077
+ ModelMessage = Annotated[ModelRequest | ModelResponse, pydantic.Discriminator('kind')]
1042
1078
  """Any message sent to or returned by a model."""
1043
1079
 
1044
1080
  ModelMessagesTypeAdapter = pydantic.TypeAdapter(
@@ -1054,6 +1090,8 @@ class TextPartDelta:
1054
1090
  content_delta: str
1055
1091
  """The incremental text content to add to the existing `TextPart` content."""
1056
1092
 
1093
+ _: KW_ONLY
1094
+
1057
1095
  part_delta_kind: Literal['text'] = 'text'
1058
1096
  """Part delta type identifier, used as a discriminator."""
1059
1097
 
@@ -1076,7 +1114,7 @@ class TextPartDelta:
1076
1114
  __repr__ = _utils.dataclasses_no_defaults_repr
1077
1115
 
1078
1116
 
1079
- @dataclass(repr=False)
1117
+ @dataclass(repr=False, kw_only=True)
1080
1118
  class ThinkingPartDelta:
1081
1119
  """A partial update (delta) for a `ThinkingPart` to append new thinking content."""
1082
1120
 
@@ -1128,7 +1166,7 @@ class ThinkingPartDelta:
1128
1166
  __repr__ = _utils.dataclasses_no_defaults_repr
1129
1167
 
1130
1168
 
1131
- @dataclass(repr=False)
1169
+ @dataclass(repr=False, kw_only=True)
1132
1170
  class ToolCallPartDelta:
1133
1171
  """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
1134
1172
 
@@ -1248,12 +1286,12 @@ class ToolCallPartDelta:
1248
1286
 
1249
1287
 
1250
1288
  ModelResponsePartDelta = Annotated[
1251
- Union[TextPartDelta, ThinkingPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')
1289
+ TextPartDelta | ThinkingPartDelta | ToolCallPartDelta, pydantic.Discriminator('part_delta_kind')
1252
1290
  ]
1253
1291
  """A partial update (delta) for any model response part."""
1254
1292
 
1255
1293
 
1256
- @dataclass(repr=False)
1294
+ @dataclass(repr=False, kw_only=True)
1257
1295
  class PartStartEvent:
1258
1296
  """An event indicating that a new part has started.
1259
1297
 
@@ -1273,7 +1311,7 @@ class PartStartEvent:
1273
1311
  __repr__ = _utils.dataclasses_no_defaults_repr
1274
1312
 
1275
1313
 
1276
- @dataclass(repr=False)
1314
+ @dataclass(repr=False, kw_only=True)
1277
1315
  class PartDeltaEvent:
1278
1316
  """An event indicating a delta update for an existing part."""
1279
1317
 
@@ -1289,7 +1327,7 @@ class PartDeltaEvent:
1289
1327
  __repr__ = _utils.dataclasses_no_defaults_repr
1290
1328
 
1291
1329
 
1292
- @dataclass(repr=False)
1330
+ @dataclass(repr=False, kw_only=True)
1293
1331
  class FinalResultEvent:
1294
1332
  """An event indicating the response to the current model request matches the output schema and will produce a result."""
1295
1333
 
@@ -1304,7 +1342,7 @@ class FinalResultEvent:
1304
1342
 
1305
1343
 
1306
1344
  ModelResponseStreamEvent = Annotated[
1307
- Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
1345
+ PartStartEvent | PartDeltaEvent | FinalResultEvent, pydantic.Discriminator('event_kind')
1308
1346
  ]
1309
1347
  """An event in the model response stream, starting a new part, applying a delta to an existing one, or indicating the final result."""
1310
1348
 
@@ -1315,6 +1353,9 @@ class FunctionToolCallEvent:
1315
1353
 
1316
1354
  part: ToolCallPart
1317
1355
  """The (function) tool call to make."""
1356
+
1357
+ _: KW_ONLY
1358
+
1318
1359
  event_kind: Literal['function_tool_call'] = 'function_tool_call'
1319
1360
  """Event type identifier, used as a discriminator."""
1320
1361
 
@@ -1338,6 +1379,9 @@ class FunctionToolResultEvent:
1338
1379
 
1339
1380
  result: ToolReturnPart | RetryPromptPart
1340
1381
  """The result of the call to the function tool."""
1382
+
1383
+ _: KW_ONLY
1384
+
1341
1385
  event_kind: Literal['function_tool_result'] = 'function_tool_result'
1342
1386
  """Event type identifier, used as a discriminator."""
1343
1387
 
@@ -1356,6 +1400,8 @@ class BuiltinToolCallEvent:
1356
1400
  part: BuiltinToolCallPart
1357
1401
  """The built-in tool call to make."""
1358
1402
 
1403
+ _: KW_ONLY
1404
+
1359
1405
  event_kind: Literal['builtin_tool_call'] = 'builtin_tool_call'
1360
1406
  """Event type identifier, used as a discriminator."""
1361
1407
 
@@ -1367,15 +1413,17 @@ class BuiltinToolResultEvent:
1367
1413
  result: BuiltinToolReturnPart
1368
1414
  """The result of the call to the built-in tool."""
1369
1415
 
1416
+ _: KW_ONLY
1417
+
1370
1418
  event_kind: Literal['builtin_tool_result'] = 'builtin_tool_result'
1371
1419
  """Event type identifier, used as a discriminator."""
1372
1420
 
1373
1421
 
1374
1422
  HandleResponseEvent = Annotated[
1375
- Union[FunctionToolCallEvent, FunctionToolResultEvent, BuiltinToolCallEvent, BuiltinToolResultEvent],
1423
+ FunctionToolCallEvent | FunctionToolResultEvent | BuiltinToolCallEvent | BuiltinToolResultEvent,
1376
1424
  pydantic.Discriminator('event_kind'),
1377
1425
  ]
1378
1426
  """An event yielded when handling a model response, indicating tool calls and results."""
1379
1427
 
1380
- AgentStreamEvent = Annotated[Union[ModelResponseStreamEvent, HandleResponseEvent], pydantic.Discriminator('event_kind')]
1428
+ AgentStreamEvent = Annotated[ModelResponseStreamEvent | HandleResponseEvent, pydantic.Discriminator('event_kind')]
1381
1429
  """An event in the agent stream: model response stream events and response-handling events."""
@@ -14,10 +14,10 @@ from contextlib import asynccontextmanager, contextmanager
14
14
  from dataclasses import dataclass, field, replace
15
15
  from datetime import datetime
16
16
  from functools import cache, cached_property
17
- from typing import Any, Generic, TypeVar, overload
17
+ from typing import Any, Generic, Literal, TypeVar, overload
18
18
 
19
19
  import httpx
20
- from typing_extensions import Literal, TypeAliasType, TypedDict
20
+ from typing_extensions import TypeAliasType, TypedDict
21
21
 
22
22
  from .. import _utils
23
23
  from .._output import OutputObjectDefinition
@@ -367,7 +367,7 @@ KnownModelName = TypeAliasType(
367
367
  """
368
368
 
369
369
 
370
- @dataclass(repr=False)
370
+ @dataclass(repr=False, kw_only=True)
371
371
  class ModelRequestParameters:
372
372
  """Configuration for an agent's request to a model, specifically related to tools and output handling."""
373
373
 
@@ -552,6 +552,7 @@ class StreamedResponse(ABC):
552
552
  """Streamed response from an LLM when calling a tool."""
553
553
 
554
554
  model_request_parameters: ModelRequestParameters
555
+
555
556
  final_result_event: FinalResultEvent | None = field(default=None, init=False)
556
557
 
557
558
  _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
@@ -920,5 +921,5 @@ def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestPar
920
921
  elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)):
921
922
  if tool_def.kind == 'output':
922
923
  return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id)
923
- elif tool_def.kind == 'deferred':
924
+ elif tool_def.defer:
924
925
  return FinalResultEvent(tool_name=None, tool_call_id=None)
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
6
6
  from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from datetime import datetime, timezone
9
- from typing import Any, Literal, Union, cast, overload
9
+ from typing import Any, Literal, cast, overload
10
10
 
11
11
  from typing_extensions import assert_never
12
12
 
@@ -99,7 +99,7 @@ except ImportError as _import_error:
99
99
  LatestAnthropicModelNames = ModelParam
100
100
  """Latest Anthropic models."""
101
101
 
102
- AnthropicModelName = Union[str, LatestAnthropicModelNames]
102
+ AnthropicModelName = str | LatestAnthropicModelNames
103
103
  """Possible Anthropic model names.
104
104
 
105
105
  Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
@@ -290,7 +290,7 @@ class AnthropicModel(Model):
290
290
  for item in response.content:
291
291
  if isinstance(item, BetaTextBlock):
292
292
  items.append(TextPart(content=item.text))
293
- elif isinstance(item, (BetaWebSearchToolResultBlock, BetaCodeExecutionToolResultBlock)):
293
+ elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock):
294
294
  items.append(
295
295
  BuiltinToolReturnPart(
296
296
  provider_name='anthropic',
@@ -327,7 +327,7 @@ class AnthropicModel(Model):
327
327
  )
328
328
 
329
329
  return ModelResponse(
330
- items,
330
+ parts=items,
331
331
  usage=_map_usage(response),
332
332
  model_name=response.model,
333
333
  provider_response_id=response.id,
@@ -654,7 +654,7 @@ class AnthropicStreamedResponse(StreamedResponse):
654
654
  elif isinstance(event, BetaRawMessageDeltaEvent):
655
655
  pass
656
656
 
657
- elif isinstance(event, (BetaRawContentBlockStopEvent, BetaRawMessageStopEvent)): # pragma: no branch
657
+ elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
658
658
  current_block = None
659
659
 
660
660
  @property
@@ -8,7 +8,7 @@ from contextlib import asynccontextmanager
8
8
  from dataclasses import dataclass, field
9
9
  from datetime import datetime
10
10
  from itertools import count
11
- from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast, overload
11
+ from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
12
12
 
13
13
  import anyio
14
14
  import anyio.to_thread
@@ -125,7 +125,7 @@ LatestBedrockModelNames = Literal[
125
125
  ]
126
126
  """Latest Bedrock models."""
127
127
 
128
- BedrockModelName = Union[str, LatestBedrockModelNames]
128
+ BedrockModelName = str | LatestBedrockModelNames
129
129
  """Possible Bedrock model names.
130
130
 
131
131
  Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
@@ -303,7 +303,7 @@ class BedrockConverseModel(Model):
303
303
  )
304
304
  response_id = response.get('ResponseMetadata', {}).get('RequestId', None)
305
305
  return ModelResponse(
306
- items,
306
+ parts=items,
307
307
  usage=u,
308
308
  model_name=self.model_name,
309
309
  provider_response_id=response_id,
@@ -490,7 +490,7 @@ class BedrockConverseModel(Model):
490
490
  else:
491
491
  # NOTE: We don't pass the thinking part to Bedrock for models other than Claude since it raises an error.
492
492
  pass
493
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):
493
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):
494
494
  pass
495
495
  else:
496
496
  assert isinstance(item, ToolCallPart)
@@ -546,7 +546,7 @@ class BedrockConverseModel(Model):
546
546
  content.append({'video': {'format': format, 'source': {'bytes': item.data}}})
547
547
  else:
548
548
  raise NotImplementedError('Binary content is not supported yet.')
549
- elif isinstance(item, (ImageUrl, DocumentUrl, VideoUrl)):
549
+ elif isinstance(item, ImageUrl | DocumentUrl | VideoUrl):
550
550
  downloaded_item = await download_item(item, data_format='bytes', type_format='extension')
551
551
  format = downloaded_item['data_type']
552
552
  if item.kind == 'image-url':
@@ -610,60 +610,62 @@ class BedrockStreamedResponse(StreamedResponse):
610
610
  chunk: ConverseStreamOutputTypeDef
611
611
  tool_id: str | None = None
612
612
  async for chunk in _AsyncIteratorWrapper(self._event_stream):
613
- # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support.
614
- if 'messageStart' in chunk:
615
- continue
616
- if 'messageStop' in chunk:
617
- continue
618
- if 'metadata' in chunk:
619
- if 'usage' in chunk['metadata']: # pragma: no branch
620
- self._usage += self._map_usage(chunk['metadata'])
621
- continue
622
- if 'contentBlockStart' in chunk:
623
- index = chunk['contentBlockStart']['contentBlockIndex']
624
- start = chunk['contentBlockStart']['start']
625
- if 'toolUse' in start: # pragma: no branch
626
- tool_use_start = start['toolUse']
627
- tool_id = tool_use_start['toolUseId']
628
- tool_name = tool_use_start['name']
629
- maybe_event = self._parts_manager.handle_tool_call_delta(
630
- vendor_part_id=index,
631
- tool_name=tool_name,
632
- args=None,
633
- tool_call_id=tool_id,
634
- )
635
- if maybe_event: # pragma: no branch
636
- yield maybe_event
637
- if 'contentBlockDelta' in chunk:
638
- index = chunk['contentBlockDelta']['contentBlockIndex']
639
- delta = chunk['contentBlockDelta']['delta']
640
- if 'reasoningContent' in delta:
641
- if text := delta['reasoningContent'].get('text'):
642
- yield self._parts_manager.handle_thinking_delta(
613
+ match chunk:
614
+ case {'messageStart': _}:
615
+ continue
616
+ case {'messageStop': _}:
617
+ continue
618
+ case {'metadata': metadata}:
619
+ if 'usage' in metadata: # pragma: no branch
620
+ self._usage += self._map_usage(metadata)
621
+ continue
622
+ case {'contentBlockStart': content_block_start}:
623
+ index = content_block_start['contentBlockIndex']
624
+ start = content_block_start['start']
625
+ if 'toolUse' in start: # pragma: no branch
626
+ tool_use_start = start['toolUse']
627
+ tool_id = tool_use_start['toolUseId']
628
+ tool_name = tool_use_start['name']
629
+ maybe_event = self._parts_manager.handle_tool_call_delta(
643
630
  vendor_part_id=index,
644
- content=text,
645
- signature=delta['reasoningContent'].get('signature'),
631
+ tool_name=tool_name,
632
+ args=None,
633
+ tool_call_id=tool_id,
646
634
  )
647
- else: # pragma: no cover
648
- warnings.warn(
649
- f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
650
- 'Please report this to the maintainers.',
651
- UserWarning,
635
+ if maybe_event: # pragma: no branch
636
+ yield maybe_event
637
+ case {'contentBlockDelta': content_block_delta}:
638
+ index = content_block_delta['contentBlockIndex']
639
+ delta = content_block_delta['delta']
640
+ if 'reasoningContent' in delta:
641
+ if text := delta['reasoningContent'].get('text'):
642
+ yield self._parts_manager.handle_thinking_delta(
643
+ vendor_part_id=index,
644
+ content=text,
645
+ signature=delta['reasoningContent'].get('signature'),
646
+ )
647
+ else: # pragma: no cover
648
+ warnings.warn(
649
+ f'Only text reasoning content is supported yet, but you got {delta["reasoningContent"]}. '
650
+ 'Please report this to the maintainers.',
651
+ UserWarning,
652
+ )
653
+ if 'text' in delta:
654
+ maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
655
+ if maybe_event is not None: # pragma: no branch
656
+ yield maybe_event
657
+ if 'toolUse' in delta:
658
+ tool_use = delta['toolUse']
659
+ maybe_event = self._parts_manager.handle_tool_call_delta(
660
+ vendor_part_id=index,
661
+ tool_name=tool_use.get('name'),
662
+ args=tool_use.get('input'),
663
+ tool_call_id=tool_id,
652
664
  )
653
- if 'text' in delta:
654
- maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
655
- if maybe_event is not None: # pragma: no branch
656
- yield maybe_event
657
- if 'toolUse' in delta:
658
- tool_use = delta['toolUse']
659
- maybe_event = self._parts_manager.handle_tool_call_delta(
660
- vendor_part_id=index,
661
- tool_name=tool_use.get('name'),
662
- args=tool_use.get('input'),
663
- tool_call_id=tool_id,
664
- )
665
- if maybe_event: # pragma: no branch
666
- yield maybe_event
665
+ if maybe_event: # pragma: no branch
666
+ yield maybe_event
667
+ case _:
668
+ pass # pyright wants match statements to be exhaustive
667
669
 
668
670
  @property
669
671
  def model_name(self) -> str:
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  from collections.abc import Iterable
4
4
  from dataclasses import dataclass, field
5
- from typing import Literal, Union, cast
5
+ from typing import Literal, cast
6
6
 
7
7
  from typing_extensions import assert_never
8
8
 
@@ -72,7 +72,7 @@ LatestCohereModelNames = Literal[
72
72
  ]
73
73
  """Latest Cohere models."""
74
74
 
75
- CohereModelName = Union[str, LatestCohereModelNames]
75
+ CohereModelName = str | LatestCohereModelNames
76
76
  """Possible Cohere model names.
77
77
 
78
78
  Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
@@ -228,7 +228,7 @@ class CohereModel(Model):
228
228
  pass
229
229
  elif isinstance(item, ToolCallPart):
230
230
  tool_calls.append(self._map_tool_call(item))
231
- elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)): # pragma: no cover
231
+ elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
232
232
  # This is currently never returned from cohere
233
233
  pass
234
234
  else:
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import AsyncIterator
3
+ from collections.abc import AsyncIterator, Callable
4
4
  from contextlib import AsyncExitStack, asynccontextmanager, suppress
5
5
  from dataclasses import dataclass, field
6
- from typing import TYPE_CHECKING, Any, Callable
6
+ from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from opentelemetry.trace import get_current_span
9
9