pydantic-ai-slim 1.7.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

Files changed (47)
  1. pydantic_ai/__init__.py +2 -0
  2. pydantic_ai/_agent_graph.py +3 -0
  3. pydantic_ai/_cli.py +2 -2
  4. pydantic_ai/_run_context.py +8 -2
  5. pydantic_ai/_tool_manager.py +1 -0
  6. pydantic_ai/_utils.py +18 -0
  7. pydantic_ai/ag_ui.py +50 -696
  8. pydantic_ai/agent/__init__.py +13 -3
  9. pydantic_ai/agent/abstract.py +172 -9
  10. pydantic_ai/agent/wrapper.py +5 -0
  11. pydantic_ai/direct.py +16 -4
  12. pydantic_ai/durable_exec/dbos/_agent.py +31 -0
  13. pydantic_ai/durable_exec/prefect/_agent.py +28 -0
  14. pydantic_ai/durable_exec/temporal/_agent.py +28 -0
  15. pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -73
  16. pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
  17. pydantic_ai/durable_exec/temporal/_run_context.py +9 -3
  18. pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
  19. pydantic_ai/messages.py +49 -8
  20. pydantic_ai/models/__init__.py +42 -1
  21. pydantic_ai/models/google.py +5 -12
  22. pydantic_ai/models/groq.py +9 -1
  23. pydantic_ai/models/openai.py +6 -3
  24. pydantic_ai/profiles/openai.py +5 -2
  25. pydantic_ai/providers/anthropic.py +2 -2
  26. pydantic_ai/providers/openrouter.py +3 -0
  27. pydantic_ai/result.py +178 -11
  28. pydantic_ai/tools.py +10 -6
  29. pydantic_ai/ui/__init__.py +16 -0
  30. pydantic_ai/ui/_adapter.py +386 -0
  31. pydantic_ai/ui/_event_stream.py +591 -0
  32. pydantic_ai/ui/_messages_builder.py +28 -0
  33. pydantic_ai/ui/ag_ui/__init__.py +9 -0
  34. pydantic_ai/ui/ag_ui/_adapter.py +187 -0
  35. pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
  36. pydantic_ai/ui/ag_ui/app.py +148 -0
  37. pydantic_ai/ui/vercel_ai/__init__.py +16 -0
  38. pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
  39. pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
  40. pydantic_ai/ui/vercel_ai/_utils.py +16 -0
  41. pydantic_ai/ui/vercel_ai/request_types.py +275 -0
  42. pydantic_ai/ui/vercel_ai/response_types.py +230 -0
  43. {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/METADATA +10 -6
  44. {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/RECORD +47 -33
  45. {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/WHEEL +0 -0
  46. {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/entry_points.txt +0 -0
  47. {pydantic_ai_slim-1.7.0.dist-info → pydantic_ai_slim-1.11.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,57 +1,22 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections.abc import Callable
4
- from dataclasses import dataclass
5
- from typing import Annotated, Any, Literal
4
+ from typing import Any, Literal
6
5
 
7
- from pydantic import ConfigDict, Discriminator, with_config
8
6
  from temporalio import activity, workflow
9
7
  from temporalio.workflow import ActivityConfig
10
- from typing_extensions import assert_never
11
8
 
12
9
  from pydantic_ai import FunctionToolset, ToolsetTool
13
- from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
10
+ from pydantic_ai.exceptions import UserError
14
11
  from pydantic_ai.tools import AgentDepsT, RunContext
15
12
  from pydantic_ai.toolsets.function import FunctionToolsetTool
16
13
 
17
14
  from ._run_context import TemporalRunContext
18
- from ._toolset import TemporalWrapperToolset
19
-
20
-
21
- @dataclass
22
- @with_config(ConfigDict(arbitrary_types_allowed=True))
23
- class _CallToolParams:
24
- name: str
25
- tool_args: dict[str, Any]
26
- serialized_run_context: Any
27
-
28
-
29
- @dataclass
30
- class _ApprovalRequired:
31
- kind: Literal['approval_required'] = 'approval_required'
32
-
33
-
34
- @dataclass
35
- class _CallDeferred:
36
- kind: Literal['call_deferred'] = 'call_deferred'
37
-
38
-
39
- @dataclass
40
- class _ModelRetry:
41
- message: str
42
- kind: Literal['model_retry'] = 'model_retry'
43
-
44
-
45
- @dataclass
46
- class _ToolReturn:
47
- result: Any
48
- kind: Literal['tool_return'] = 'tool_return'
49
-
50
-
51
- _CallToolResult = Annotated[
52
- _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
53
- Discriminator('kind'),
54
- ]
15
+ from ._toolset import (
16
+ CallToolParams,
17
+ CallToolResult,
18
+ TemporalWrapperToolset,
19
+ )
55
20
 
56
21
 
57
22
  class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
@@ -70,7 +35,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
70
35
  self.tool_activity_config = tool_activity_config
71
36
  self.run_context_type = run_context_type
72
37
 
73
- async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
38
+ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
74
39
  name = params.name
75
40
  ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
76
41
  try:
@@ -84,15 +49,7 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
84
49
  # The tool args will already have been validated into their proper types in the `ToolManager`,
85
50
  # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
86
51
  args_dict = tool.args_validator.validate_python(params.tool_args)
87
- try:
88
- result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
89
- return _ToolReturn(result=result)
90
- except ApprovalRequired:
91
- return _ApprovalRequired()
92
- except CallDeferred:
93
- return _CallDeferred()
94
- except ModelRetry as e:
95
- return _ModelRetry(message=e.message)
52
+ return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))
96
53
 
97
54
  # Set type hint explicitly so that Temporal can take care of serialization and deserialization
98
55
  call_tool_activity.__annotations__['deps'] = deps_type
@@ -123,25 +80,18 @@ class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
123
80
 
124
81
  tool_activity_config = self.activity_config | tool_activity_config
125
82
  serialized_run_context = self.run_context_type.serialize_run_context(ctx)
126
- result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
127
- activity=self.call_tool_activity,
128
- args=[
129
- _CallToolParams(
130
- name=name,
131
- tool_args=tool_args,
132
- serialized_run_context=serialized_run_context,
133
- ),
134
- ctx.deps,
135
- ],
136
- **tool_activity_config,
83
+ return self._unwrap_call_tool_result(
84
+ await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
85
+ activity=self.call_tool_activity,
86
+ args=[
87
+ CallToolParams(
88
+ name=name,
89
+ tool_args=tool_args,
90
+ serialized_run_context=serialized_run_context,
91
+ tool_def=None,
92
+ ),
93
+ ctx.deps,
94
+ ],
95
+ **tool_activity_config,
96
+ )
137
97
  )
138
- if isinstance(result, _ApprovalRequired):
139
- raise ApprovalRequired()
140
- elif isinstance(result, _CallDeferred):
141
- raise CallDeferred()
142
- elif isinstance(result, _ModelRetry):
143
- raise ModelRetry(result.message)
144
- elif isinstance(result, _ToolReturn):
145
- return result.result
146
- else:
147
- assert_never(result)
@@ -11,11 +11,15 @@ from typing_extensions import Self
11
11
 
12
12
  from pydantic_ai import ToolsetTool
13
13
  from pydantic_ai.exceptions import UserError
14
- from pydantic_ai.mcp import MCPServer, ToolResult
14
+ from pydantic_ai.mcp import MCPServer
15
15
  from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
16
16
 
17
17
  from ._run_context import TemporalRunContext
18
- from ._toolset import TemporalWrapperToolset
18
+ from ._toolset import (
19
+ CallToolParams,
20
+ CallToolResult,
21
+ TemporalWrapperToolset,
22
+ )
19
23
 
20
24
 
21
25
  @dataclass
@@ -24,15 +28,6 @@ class _GetToolsParams:
24
28
  serialized_run_context: Any
25
29
 
26
30
 
27
- @dataclass
28
- @with_config(ConfigDict(arbitrary_types_allowed=True))
29
- class _CallToolParams:
30
- name: str
31
- tool_args: dict[str, Any]
32
- serialized_run_context: Any
33
- tool_def: ToolDefinition
34
-
35
-
36
31
  class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
37
32
  def __init__(
38
33
  self,
@@ -72,13 +67,16 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
72
67
  get_tools_activity
73
68
  )
74
69
 
75
- async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
70
+ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
76
71
  run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
77
- return await self.wrapped.call_tool(
78
- params.name,
79
- params.tool_args,
80
- run_context,
81
- self.tool_for_tool_def(params.tool_def),
72
+ assert isinstance(params.tool_def, ToolDefinition)
73
+ return await self._wrap_call_tool_result(
74
+ self.wrapped.call_tool(
75
+ params.name,
76
+ params.tool_args,
77
+ run_context,
78
+ self.tool_for_tool_def(params.tool_def),
79
+ )
82
80
  )
83
81
 
84
82
  # Set type hint explicitly so that Temporal can take care of serialization and deserialization
@@ -125,22 +123,24 @@ class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
125
123
  tool_args: dict[str, Any],
126
124
  ctx: RunContext[AgentDepsT],
127
125
  tool: ToolsetTool[AgentDepsT],
128
- ) -> ToolResult:
126
+ ) -> CallToolResult:
129
127
  if not workflow.in_workflow():
130
128
  return await super().call_tool(name, tool_args, ctx, tool)
131
129
 
132
130
  tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
133
131
  serialized_run_context = self.run_context_type.serialize_run_context(ctx)
134
- return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
135
- activity=self.call_tool_activity,
136
- args=[
137
- _CallToolParams(
138
- name=name,
139
- tool_args=tool_args,
140
- serialized_run_context=serialized_run_context,
141
- tool_def=tool.tool_def,
142
- ),
143
- ctx.deps,
144
- ],
145
- **tool_activity_config,
132
+ return self._unwrap_call_tool_result(
133
+ await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
134
+ activity=self.call_tool_activity,
135
+ args=[
136
+ CallToolParams(
137
+ name=name,
138
+ tool_args=tool_args,
139
+ serialized_run_context=serialized_run_context,
140
+ tool_def=tool.tool_def,
141
+ ),
142
+ ctx.deps,
143
+ ],
144
+ **tool_activity_config,
145
+ )
146
146
  )
@@ -2,14 +2,19 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Any
4
4
 
5
+ from typing_extensions import TypeVar
6
+
5
7
  from pydantic_ai.exceptions import UserError
6
- from pydantic_ai.tools import AgentDepsT, RunContext
8
+ from pydantic_ai.tools import RunContext
9
+
10
+ AgentDepsT = TypeVar('AgentDepsT', default=None, covariant=True)
11
+ """Type variable for the agent dependencies in `RunContext`."""
7
12
 
8
13
 
9
14
  class TemporalRunContext(RunContext[AgentDepsT]):
10
15
  """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
11
16
 
12
- By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available.
17
+ By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available.
13
18
  To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
14
19
  """
15
20
 
@@ -44,9 +49,10 @@ class TemporalRunContext(RunContext[AgentDepsT]):
44
49
  'retry': ctx.retry,
45
50
  'max_retries': ctx.max_retries,
46
51
  'run_step': ctx.run_step,
52
+ 'partial_output': ctx.partial_output,
47
53
  }
48
54
 
49
55
  @classmethod
50
- def deserialize_run_context(cls, ctx: dict[str, Any], deps: AgentDepsT) -> TemporalRunContext[AgentDepsT]:
56
+ def deserialize_run_context(cls, ctx: dict[str, Any], deps: Any) -> TemporalRunContext[Any]:
51
57
  """Deserialize the run context from a `dict[str, Any]`."""
52
58
  return cls(**ctx, deps=deps)
@@ -1,17 +1,58 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from collections.abc import Callable
5
- from typing import Any, Literal
4
+ from collections.abc import Awaitable, Callable
5
+ from dataclasses import dataclass
6
+ from typing import Annotated, Any, Literal
6
7
 
8
+ from pydantic import ConfigDict, Discriminator, with_config
7
9
  from temporalio.workflow import ActivityConfig
10
+ from typing_extensions import assert_never
8
11
 
9
12
  from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
10
- from pydantic_ai.tools import AgentDepsT
13
+ from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
14
+ from pydantic_ai.tools import AgentDepsT, ToolDefinition
11
15
 
12
16
  from ._run_context import TemporalRunContext
13
17
 
14
18
 
19
+ @dataclass
20
+ @with_config(ConfigDict(arbitrary_types_allowed=True))
21
+ class CallToolParams:
22
+ name: str
23
+ tool_args: dict[str, Any]
24
+ serialized_run_context: Any
25
+ tool_def: ToolDefinition | None
26
+
27
+
28
+ @dataclass
29
+ class _ApprovalRequired:
30
+ kind: Literal['approval_required'] = 'approval_required'
31
+
32
+
33
+ @dataclass
34
+ class _CallDeferred:
35
+ kind: Literal['call_deferred'] = 'call_deferred'
36
+
37
+
38
+ @dataclass
39
+ class _ModelRetry:
40
+ message: str
41
+ kind: Literal['model_retry'] = 'model_retry'
42
+
43
+
44
+ @dataclass
45
+ class _ToolReturn:
46
+ result: Any
47
+ kind: Literal['tool_return'] = 'tool_return'
48
+
49
+
50
+ CallToolResult = Annotated[
51
+ _ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
52
+ Discriminator('kind'),
53
+ ]
54
+
55
+
15
56
  class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
16
57
  @property
17
58
  def id(self) -> str:
@@ -30,6 +71,29 @@ class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
30
71
  # Temporalized toolsets cannot be swapped out after the fact.
31
72
  return self
32
73
 
74
+ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
75
+ try:
76
+ result = await coro
77
+ return _ToolReturn(result=result)
78
+ except ApprovalRequired:
79
+ return _ApprovalRequired()
80
+ except CallDeferred:
81
+ return _CallDeferred()
82
+ except ModelRetry as e:
83
+ return _ModelRetry(message=e.message)
84
+
85
+ def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
86
+ if isinstance(result, _ToolReturn):
87
+ return result.result
88
+ elif isinstance(result, _ApprovalRequired):
89
+ raise ApprovalRequired()
90
+ elif isinstance(result, _CallDeferred):
91
+ raise CallDeferred()
92
+ elif isinstance(result, _ModelRetry):
93
+ raise ModelRetry(result.message)
94
+ else:
95
+ assert_never(result)
96
+
33
97
 
34
98
  def temporalize_toolset(
35
99
  toolset: AbstractToolset[AgentDepsT],
pydantic_ai/messages.py CHANGED
@@ -13,7 +13,7 @@ import pydantic
13
13
  import pydantic_core
14
14
  from genai_prices import calc_price, types as genai_types
15
15
  from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
16
- from typing_extensions import Self, deprecated
16
+ from typing_extensions import deprecated
17
17
 
18
18
  from . import _otel_messages, _utils
19
19
  from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
@@ -34,6 +34,7 @@ DocumentMediaType: TypeAlias = Literal[
34
34
  'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
35
35
  'text/html',
36
36
  'text/markdown',
37
+ 'application/msword',
37
38
  'application/vnd.ms-excel',
38
39
  ]
39
40
  VideoMediaType: TypeAlias = Literal[
@@ -434,8 +435,12 @@ class DocumentUrl(FileUrl):
434
435
  return 'application/pdf'
435
436
  elif self.url.endswith('.rtf'):
436
437
  return 'application/rtf'
438
+ elif self.url.endswith('.doc'):
439
+ return 'application/msword'
437
440
  elif self.url.endswith('.docx'):
438
441
  return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
442
+ elif self.url.endswith('.xls'):
443
+ return 'application/vnd.ms-excel'
439
444
  elif self.url.endswith('.xlsx'):
440
445
  return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
441
446
 
@@ -514,16 +519,16 @@ class BinaryContent:
514
519
  vendor_metadata=bc.vendor_metadata,
515
520
  )
516
521
  else:
517
- return bc # pragma: no cover
522
+ return bc
518
523
 
519
524
  @classmethod
520
- def from_data_uri(cls, data_uri: str) -> Self:
525
+ def from_data_uri(cls, data_uri: str) -> BinaryContent:
521
526
  """Create a `BinaryContent` from a data URI."""
522
527
  prefix = 'data:'
523
528
  if not data_uri.startswith(prefix):
524
- raise ValueError('Data URI must start with "data:"') # pragma: no cover
529
+ raise ValueError('Data URI must start with "data:"')
525
530
  media_type, data = data_uri[len(prefix) :].split(';base64,', 1)
526
- return cls(data=base64.b64decode(data), media_type=media_type)
531
+ return cls.narrow_type(cls(data=base64.b64decode(data), media_type=media_type))
527
532
 
528
533
  @pydantic.computed_field
529
534
  @property
@@ -645,6 +650,7 @@ _document_format_lookup: dict[str, DocumentFormat] = {
645
650
  'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'xlsx',
646
651
  'text/html': 'html',
647
652
  'text/markdown': 'md',
653
+ 'application/msword': 'doc',
648
654
  'application/vnd.ms-excel': 'xls',
649
655
  }
650
656
  _audio_format_lookup: dict[str, AudioFormat] = {
@@ -882,7 +888,10 @@ class RetryPromptPart:
882
888
  description = self.content
883
889
  else:
884
890
  json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
885
- description = f'{len(self.content)} validation errors: {json_errors.decode()}'
891
+ plural = isinstance(self.content, list) and len(self.content) != 1
892
+ description = (
893
+ f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```'
894
+ )
886
895
  return f'{description}\n\nFix the errors and try again.'
887
896
 
888
897
  def otel_event(self, settings: InstrumentationSettings) -> Event:
@@ -1612,6 +1621,14 @@ class PartStartEvent:
1612
1621
  part: ModelResponsePart
1613
1622
  """The newly started `ModelResponsePart`."""
1614
1623
 
1624
+ previous_part_kind: (
1625
+ Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'file'] | None
1626
+ ) = None
1627
+ """The kind of the previous part, if any.
1628
+
1629
+ This is useful for UI event streams to know whether to group parts of the same kind together when emitting events.
1630
+ """
1631
+
1615
1632
  event_kind: Literal['part_start'] = 'part_start'
1616
1633
  """Event type identifier, used as a discriminator."""
1617
1634
 
@@ -1634,6 +1651,30 @@ class PartDeltaEvent:
1634
1651
  __repr__ = _utils.dataclasses_no_defaults_repr
1635
1652
 
1636
1653
 
1654
+ @dataclass(repr=False, kw_only=True)
1655
+ class PartEndEvent:
1656
+ """An event indicating that a part is complete."""
1657
+
1658
+ index: int
1659
+ """The index of the part within the overall response parts list."""
1660
+
1661
+ part: ModelResponsePart
1662
+ """The complete `ModelResponsePart`."""
1663
+
1664
+ next_part_kind: (
1665
+ Literal['text', 'thinking', 'tool-call', 'builtin-tool-call', 'builtin-tool-return', 'file'] | None
1666
+ ) = None
1667
+ """The kind of the next part, if any.
1668
+
1669
+ This is useful for UI event streams to know whether to group parts of the same kind together when emitting events.
1670
+ """
1671
+
1672
+ event_kind: Literal['part_end'] = 'part_end'
1673
+ """Event type identifier, used as a discriminator."""
1674
+
1675
+ __repr__ = _utils.dataclasses_no_defaults_repr
1676
+
1677
+
1637
1678
  @dataclass(repr=False, kw_only=True)
1638
1679
  class FinalResultEvent:
1639
1680
  """An event indicating the response to the current model request matches the output schema and will produce a result."""
@@ -1649,9 +1690,9 @@ class FinalResultEvent:
1649
1690
 
1650
1691
 
1651
1692
  ModelResponseStreamEvent = Annotated[
1652
- PartStartEvent | PartDeltaEvent | FinalResultEvent, pydantic.Discriminator('event_kind')
1693
+ PartStartEvent | PartDeltaEvent | PartEndEvent | FinalResultEvent, pydantic.Discriminator('event_kind')
1653
1694
  ]
1654
- """An event in the model response stream, starting a new part, applying a delta to an existing one, or indicating the final result."""
1695
+ """An event in the model response stream, starting a new part, applying a delta to an existing one, indicating a part is complete, or indicating the final result."""
1655
1696
 
1656
1697
 
1657
1698
  @dataclass(repr=False)
@@ -27,6 +27,7 @@ from .._run_context import RunContext
27
27
  from ..builtin_tools import AbstractBuiltinTool
28
28
  from ..exceptions import UserError
29
29
  from ..messages import (
30
+ BaseToolCallPart,
30
31
  BinaryImage,
31
32
  FilePart,
32
33
  FileUrl,
@@ -35,9 +36,12 @@ from ..messages import (
35
36
  ModelMessage,
36
37
  ModelRequest,
37
38
  ModelResponse,
39
+ ModelResponsePart,
38
40
  ModelResponseStreamEvent,
41
+ PartEndEvent,
39
42
  PartStartEvent,
40
43
  TextPart,
44
+ ThinkingPart,
41
45
  ToolCallPart,
42
46
  VideoUrl,
43
47
  )
@@ -543,7 +547,44 @@ class StreamedResponse(ABC):
543
547
  async for event in iterator:
544
548
  yield event
545
549
 
546
- self._event_iterator = iterator_with_final_event(self._get_event_iterator())
550
+ async def iterator_with_part_end(
551
+ iterator: AsyncIterator[ModelResponseStreamEvent],
552
+ ) -> AsyncIterator[ModelResponseStreamEvent]:
553
+ last_start_event: PartStartEvent | None = None
554
+
555
+ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent | None:
556
+ if not last_start_event:
557
+ return None
558
+
559
+ index = last_start_event.index
560
+ part = self._parts_manager.get_parts()[index]
561
+ if not isinstance(part, TextPart | ThinkingPart | BaseToolCallPart):
562
+ # Parts other than these 3 don't have deltas, so don't need an end part.
563
+ return None
564
+
565
+ return PartEndEvent(
566
+ index=index,
567
+ part=part,
568
+ next_part_kind=next_part.part_kind if next_part else None,
569
+ )
570
+
571
+ async for event in iterator:
572
+ if isinstance(event, PartStartEvent):
573
+ if last_start_event:
574
+ end_event = part_end_event(event.part)
575
+ if end_event:
576
+ yield end_event
577
+
578
+ event.previous_part_kind = last_start_event.part.part_kind
579
+ last_start_event = event
580
+
581
+ yield event
582
+
583
+ end_event = part_end_event()
584
+ if end_event:
585
+ yield end_event
586
+
587
+ self._event_iterator = iterator_with_part_end(iterator_with_final_event(self._get_event_iterator()))
547
588
  return self._event_iterator
548
589
 
549
590
  @abstractmethod
@@ -471,11 +471,9 @@ class GoogleModel(Model):
471
471
  raise UnexpectedModelBehavior(
472
472
  f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
473
473
  )
474
- else:
475
- raise UnexpectedModelBehavior(
476
- 'Content field missing from Gemini response', response.model_dump_json()
477
- ) # pragma: no cover
478
- parts = candidate.content.parts or []
474
+ parts = [] # pragma: no cover
475
+ else:
476
+ parts = candidate.content.parts or []
479
477
 
480
478
  usage = _metadata_as_usage(response)
481
479
  return _process_response_from_parts(
@@ -649,17 +647,12 @@ class GeminiStreamedResponse(StreamedResponse):
649
647
  # )
650
648
 
651
649
  if candidate.content is None or candidate.content.parts is None:
652
- if self.finish_reason == 'stop': # pragma: no cover
653
- # Normal completion - skip this chunk
654
- continue
655
- elif self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
650
+ if self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
656
651
  raise UnexpectedModelBehavior(
657
652
  f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
658
653
  )
659
654
  else: # pragma: no cover
660
- raise UnexpectedModelBehavior(
661
- 'Content field missing from streaming Gemini response', chunk.model_dump_json()
662
- )
655
+ continue
663
656
 
664
657
  parts = candidate.content.parts
665
658
  if not parts:
@@ -524,6 +524,8 @@ class GroqStreamedResponse(StreamedResponse):
524
524
  async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
525
525
  try:
526
526
  executed_tool_call_id: str | None = None
527
+ reasoning_index = 0
528
+ reasoning = False
527
529
  async for chunk in self._response:
528
530
  self._usage += _map_usage(chunk)
529
531
 
@@ -540,10 +542,16 @@ class GroqStreamedResponse(StreamedResponse):
540
542
  self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
541
543
 
542
544
  if choice.delta.reasoning is not None:
545
+ if not reasoning:
546
+ reasoning_index += 1
547
+ reasoning = True
548
+
543
549
  # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
544
550
  yield self._parts_manager.handle_thinking_delta(
545
- vendor_part_id='reasoning', content=choice.delta.reasoning
551
+ vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
546
552
  )
553
+ else:
554
+ reasoning = False
547
555
 
548
556
  if choice.delta.executed_tools:
549
557
  for tool in choice.delta.executed_tools:
@@ -948,6 +948,10 @@ class OpenAIResponsesModel(Model):
948
948
 
949
949
  super().__init__(settings=settings, profile=profile or provider.model_profile)
950
950
 
951
+ @property
952
+ def base_url(self) -> str:
953
+ return str(self.client.base_url)
954
+
951
955
  @property
952
956
  def model_name(self) -> OpenAIModelName:
953
957
  """The model name."""
@@ -1148,10 +1152,10 @@ class OpenAIResponsesModel(Model):
1148
1152
  + list(model_settings.get('openai_builtin_tools', []))
1149
1153
  + self._get_tools(model_request_parameters)
1150
1154
  )
1151
-
1155
+ profile = OpenAIModelProfile.from_profile(self.profile)
1152
1156
  if not tools:
1153
1157
  tool_choice: Literal['none', 'required', 'auto'] | None = None
1154
- elif not model_request_parameters.allow_text_output:
1158
+ elif not model_request_parameters.allow_text_output and profile.openai_supports_tool_choice_required:
1155
1159
  tool_choice = 'required'
1156
1160
  else:
1157
1161
  tool_choice = 'auto'
@@ -1184,7 +1188,6 @@ class OpenAIResponsesModel(Model):
1184
1188
  text = text or {}
1185
1189
  text['verbosity'] = verbosity
1186
1190
 
1187
- profile = OpenAIModelProfile.from_profile(self.profile)
1188
1191
  unsupported_model_settings = profile.openai_unsupported_model_settings
1189
1192
  for setting in unsupported_model_settings:
1190
1193
  model_settings.pop(setting, None)
@@ -62,7 +62,10 @@ class OpenAIModelProfile(ModelProfile):
62
62
 
63
63
  def openai_model_profile(model_name: str) -> ModelProfile:
64
64
  """Get the model profile for an OpenAI model."""
65
- is_reasoning_model = model_name.startswith('o') or model_name.startswith('gpt-5')
65
+ is_gpt_5 = model_name.startswith('gpt-5')
66
+ is_o_series = model_name.startswith('o')
67
+ is_reasoning_model = is_o_series or (is_gpt_5 and 'gpt-5-chat' not in model_name)
68
+
66
69
  # Check if the model supports web search (only specific search-preview models)
67
70
  supports_web_search = '-search-preview' in model_name
68
71
 
@@ -91,7 +94,7 @@ def openai_model_profile(model_name: str) -> ModelProfile:
91
94
  json_schema_transformer=OpenAIJsonSchemaTransformer,
92
95
  supports_json_schema_output=True,
93
96
  supports_json_object_output=True,
94
- supports_image_output=is_reasoning_model or '4.1' in model_name or '4o' in model_name,
97
+ supports_image_output=is_gpt_5 or 'o3' in model_name or '4.1' in model_name or '4o' in model_name,
95
98
  openai_unsupported_model_settings=openai_unsupported_model_settings,
96
99
  openai_system_prompt_role=openai_system_prompt_role,
97
100
  openai_chat_supports_web_search=supports_web_search,