pydantic-ai-slim 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Note: this release of pydantic-ai-slim has been flagged as potentially problematic.

pydantic_ai/models/google.py CHANGED
@@ -126,6 +126,8 @@ _FINISH_REASON_MAP: dict[GoogleFinishReason, FinishReason | None] = {
     GoogleFinishReason.MALFORMED_FUNCTION_CALL: 'error',
     GoogleFinishReason.IMAGE_SAFETY: 'content_filter',
     GoogleFinishReason.UNEXPECTED_TOOL_CALL: 'error',
+    GoogleFinishReason.IMAGE_PROHIBITED_CONTENT: 'content_filter',
+    GoogleFinishReason.NO_IMAGE: 'error',
 }
 
 
@@ -453,23 +455,28 @@ class GoogleModel(Model):
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates:
             raise UnexpectedModelBehavior('Expected at least one candidate in Gemini response')  # pragma: no cover
+
         candidate = response.candidates[0]
-        if candidate.content is None or candidate.content.parts is None:
-            if candidate.finish_reason == 'SAFETY':
-                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
-            else:
-                raise UnexpectedModelBehavior(
-                    'Content field missing from Gemini response', str(response)
-                )  # pragma: no cover
-        parts = candidate.content.parts or []
 
         vendor_id = response.response_id
         vendor_details: dict[str, Any] | None = None
         finish_reason: FinishReason | None = None
-        if raw_finish_reason := candidate.finish_reason:  # pragma: no branch
+        raw_finish_reason = candidate.finish_reason
+        if raw_finish_reason:  # pragma: no branch
            vendor_details = {'finish_reason': raw_finish_reason.value}
            finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
+        if candidate.content is None or candidate.content.parts is None:
+            if finish_reason == 'content_filter' and raw_finish_reason:
+                raise UnexpectedModelBehavior(
+                    f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
+                )
+            else:
+                raise UnexpectedModelBehavior(
+                    'Content field missing from Gemini response', response.model_dump_json()
+                )  # pragma: no cover
+        parts = candidate.content.parts or []
+
         usage = _metadata_as_usage(response)
         return _process_response_from_parts(
             parts,
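
With `_process_response` reworked as above, a missing `content` field is now reported via the mapped finish reason, so any reason that maps to `content_filter` (including the new `IMAGE_PROHIBITED_CONTENT`) produces a named error rather than the old `SAFETY`-only special case. A minimal sketch of what a caller sees (model name and prompt are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import UnexpectedModelBehavior

agent = Agent('google-gla:gemini-2.5-flash')  # illustrative model name

try:
    agent.run_sync('...')  # a prompt that trips an image-safety filter
except UnexpectedModelBehavior as e:
    # As of 1.5.0 the message names the raw finish reason, e.g.
    # "Content filter 'IMAGE_PROHIBITED_CONTENT' triggered", and the error
    # body carries the response as JSON rather than its repr().
    print(e)
```
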
@@ -623,7 +630,8 @@ class GeminiStreamedResponse(StreamedResponse):
             if chunk.response_id:  # pragma: no branch
                 self.provider_response_id = chunk.response_id
 
-            if raw_finish_reason := candidate.finish_reason:
+            raw_finish_reason = candidate.finish_reason
+            if raw_finish_reason:
                 self.provider_details = {'finish_reason': raw_finish_reason.value}
                 self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
 
@@ -641,13 +649,17 @@ class GeminiStreamedResponse(StreamedResponse):
             # )
 
             if candidate.content is None or candidate.content.parts is None:
-                if candidate.finish_reason == 'STOP':  # pragma: no cover
+                if self.finish_reason == 'stop':  # pragma: no cover
                     # Normal completion - skip this chunk
                     continue
-                elif candidate.finish_reason == 'SAFETY':  # pragma: no cover
-                    raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))
+                elif self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
+                    raise UnexpectedModelBehavior(
+                        f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
+                    )
                 else:  # pragma: no cover
-                    raise UnexpectedModelBehavior('Content field missing from streaming Gemini response', str(chunk))
+                    raise UnexpectedModelBehavior(
+                        'Content field missing from streaming Gemini response', chunk.model_dump_json()
+                    )
 
             parts = candidate.content.parts
             if not parts:
pydantic_ai/models/openai.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import base64
+import json
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
@@ -17,7 +18,7 @@ from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
-from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
@@ -109,6 +110,11 @@ Using this more broad type for the model name instead of the ChatModel definitio
 allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """
 
+MCP_SERVER_TOOL_CONNECTOR_URI_SCHEME: Literal['x-openai-connector'] = 'x-openai-connector'
+"""
+Prefix for OpenAI connector IDs. OpenAI supports either a URL or a connector ID when passing MCP configuration to a model,
+by using that prefix like `x-openai-connector:<connector-id>` in a URL, you can pass a connector ID to a model.
+"""
 
 _CHAT_FINISH_REASON_MAP: dict[
     Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason
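
The new constant lets a connector ID ride in `MCPServerTool.url`. A sketch of both forms, assuming the constructor keywords mirror the attributes read by the code below (`id`, `url`, `authorization_token`); the connector ID is illustrative:

```python
from pydantic_ai.builtin_tools import MCPServerTool

# A regular remote MCP server; the URL is forwarded as `server_url`:
docs = MCPServerTool(id='deepwiki', url='https://mcp.deepwiki.com/mcp')

# An OpenAI connector; the `x-openai-connector:` prefix tells the model code
# to strip the scheme and send the remainder as `connector_id` instead:
drive = MCPServerTool(
    id='google_drive',
    url='x-openai-connector:connector_googledrive',  # illustrative connector ID
    authorization_token='<oauth-token>',
)
```
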
@@ -1061,13 +1067,16 @@ class OpenAIResponsesModel(Model):
             elif isinstance(item, responses.ResponseFileSearchToolCall):  # pragma: no cover
                 # Pydantic AI doesn't yet support the FileSearch built-in tool
                 pass
-            elif isinstance(  # pragma: no cover
-                item,
-                responses.response_output_item.McpCall
-                | responses.response_output_item.McpListTools
-                | responses.response_output_item.McpApprovalRequest,
-            ):
-                # Pydantic AI supports MCP natively
+            elif isinstance(item, responses.response_output_item.McpCall):
+                call_part, return_part = _map_mcp_call(item, self.system)
+                items.append(call_part)
+                items.append(return_part)
+            elif isinstance(item, responses.response_output_item.McpListTools):
+                call_part, return_part = _map_mcp_list_tools(item, self.system)
+                items.append(call_part)
+                items.append(return_part)
+            elif isinstance(item, responses.response_output_item.McpApprovalRequest):  # pragma: no cover
+                # Pydantic AI doesn't yet support McpApprovalRequest (explicit tool usage approval)
                 pass
 
         finish_reason: FinishReason | None = None
@@ -1256,6 +1265,32 @@ class OpenAIResponsesModel(Model):
             elif isinstance(tool, CodeExecutionTool):
                 has_image_generating_tool = True
                 tools.append({'type': 'code_interpreter', 'container': {'type': 'auto'}})
+            elif isinstance(tool, MCPServerTool):
+                mcp_tool = responses.tool_param.Mcp(
+                    type='mcp',
+                    server_label=tool.id,
+                    require_approval='never',
+                )
+
+                if tool.authorization_token:  # pragma: no branch
+                    mcp_tool['authorization'] = tool.authorization_token
+
+                if tool.allowed_tools is not None:  # pragma: no branch
+                    mcp_tool['allowed_tools'] = tool.allowed_tools
+
+                if tool.description:  # pragma: no branch
+                    mcp_tool['server_description'] = tool.description
+
+                if tool.headers:  # pragma: no branch
+                    mcp_tool['headers'] = tool.headers
+
+                if tool.url.startswith(MCP_SERVER_TOOL_CONNECTOR_URI_SCHEME + ':'):
+                    _, connector_id = tool.url.split(':', maxsplit=1)
+                    mcp_tool['connector_id'] = connector_id  # pyright: ignore[reportGeneralTypeIssues]
+                else:
+                    mcp_tool['server_url'] = tool.url
+
+                tools.append(mcp_tool)
             elif isinstance(tool, ImageGenerationTool):  # pragma: no branch
                 has_image_generating_tool = True
                 tools.append(
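
Tracing the branch above, an `MCPServerTool` turns into a Responses API `mcp` tool param shaped roughly like this (a hand-written sketch of the resulting dict, not captured output):

```python
# Given: MCPServerTool(id='deepwiki', url='https://mcp.deepwiki.com/mcp',
#                      allowed_tools=['ask_question'], description='DeepWiki docs')
expected = {
    'type': 'mcp',
    'server_label': 'deepwiki',              # tool.id
    'require_approval': 'never',             # hardcoded above
    'allowed_tools': ['ask_question'],       # tool.allowed_tools
    'server_description': 'DeepWiki docs',   # tool.description
    'server_url': 'https://mcp.deepwiki.com/mcp',  # no connector prefix, so server_url
}
```
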
@@ -1428,7 +1463,7 @@ class OpenAIResponsesModel(Model):
                     type='web_search_call',
                 )
                 openai_messages.append(web_search_item)
-                elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id:  # pragma: no branch
+                elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id:
                     # The cast is necessary because of https://github.com/openai/openai-python/issues/2648
                     image_generation_item = cast(
                         responses.response_input_item_param.ImageGenerationCall,
@@ -1438,6 +1473,37 @@ class OpenAIResponsesModel(Model):
                         },
                     )
                     openai_messages.append(image_generation_item)
+                elif (  # pragma: no branch
+                    item.tool_name.startswith(MCPServerTool.kind)
+                    and item.tool_call_id
+                    and (server_id := item.tool_name.split(':', 1)[1])
+                    and (args := item.args_as_dict())
+                    and (action := args.get('action'))
+                ):
+                    if action == 'list_tools':
+                        mcp_list_tools_item = responses.response_input_item_param.McpListTools(
+                            id=item.tool_call_id,
+                            type='mcp_list_tools',
+                            server_label=server_id,
+                            tools=[],  # These can be read server-side
+                        )
+                        openai_messages.append(mcp_list_tools_item)
+                    elif (  # pragma: no branch
+                        action == 'call_tool'
+                        and (tool_name := args.get('tool_name'))
+                        and (tool_args := args.get('tool_args'))
+                    ):
+                        mcp_call_item = responses.response_input_item_param.McpCall(
+                            id=item.tool_call_id,
+                            server_label=server_id,
+                            name=tool_name,
+                            arguments=to_json(tool_args).decode(),
+                            error=None,  # These can be read server-side
+                            output=None,  # These can be read server-side
+                            type='mcp_call',
+                        )
+                        openai_messages.append(mcp_call_item)
+
             elif isinstance(item, BuiltinToolReturnPart):
                 if item.provider_name == self.system and send_item_ids:
                     if (
@@ -1456,9 +1522,12 @@ class OpenAIResponsesModel(Model):
                         and (status := content.get('status'))
                     ):
                         web_search_item['status'] = status
-                elif item.tool_name == ImageGenerationTool.kind:  # pragma: no branch
+                elif item.tool_name == ImageGenerationTool.kind:
                     # Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`.
                     pass
+                elif item.tool_name.startswith(MCPServerTool.kind):  # pragma: no branch
+                    # MCP call result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
+                    pass
             elif isinstance(item, FilePart):
                 # This was generated by the `ImageGenerationTool` or `CodeExecutionTool`,
                 # and does not need to be sent back separately from the corresponding `BuiltinToolReturnPart`.
@@ -1772,7 +1841,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     args_json = call_part.args_as_json_str()
                     # Drop the final `"}` so that we can add code deltas
                     args_json_delta = args_json[:-2]
-                    assert args_json_delta.endswith('code":"')
+                    assert args_json_delta.endswith('"code":"'), f'Expected {args_json_delta!r} to end in `"code":"`'
 
                     yield self._parts_manager.handle_part(
                         vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None)
@@ -1786,7 +1855,28 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 elif isinstance(chunk.item, responses.response_output_item.ImageGenerationCall):
                     call_part, _, _ = _map_image_generation_tool_call(chunk.item, self.provider_name)
                     yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
+                elif isinstance(chunk.item, responses.response_output_item.McpCall):
+                    call_part, _ = _map_mcp_call(chunk.item, self.provider_name)
 
+                    args_json = call_part.args_as_json_str()
+                    # Drop the final `{}}` so that we can add tool args deltas
+                    args_json_delta = args_json[:-3]
+                    assert args_json_delta.endswith('"tool_args":'), (
+                        f'Expected {args_json_delta!r} to end in `"tool_args":"`'
+                    )
+
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None)
+                    )
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=f'{chunk.item.id}-call',
+                        args=args_json_delta,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                elif isinstance(chunk.item, responses.response_output_item.McpListTools):
+                    call_part, _ = _map_mcp_list_tools(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
                 else:
                     warnings.warn(  # pragma: no cover
                         f'Handling of this item type is not yet implemented. Please report on our GitHub: {chunk}',
@@ -1827,6 +1917,13 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-file', part=file_part)
                     yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
 
+                elif isinstance(chunk.item, responses.response_output_item.McpCall):
+                    _, return_part = _map_mcp_call(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
+                elif isinstance(chunk.item, responses.response_output_item.McpListTools):
+                    _, return_part = _map_mcp_list_tools(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
+
             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
                 yield self._parts_manager.handle_thinking_delta(
                     vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
@@ -1921,6 +2018,40 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 )
                 yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item_id}-file', part=file_part)
 
+            elif isinstance(chunk, responses.ResponseMcpCallArgumentsDoneEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=f'{chunk.item_id}-call',
+                    args='}',
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseMcpCallArgumentsDeltaEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=f'{chunk.item_id}-call',
+                    args=chunk.delta,
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsCompletedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsFailedEvent):  # pragma: no cover
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallFailedEvent):  # pragma: no cover
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallCompletedEvent):
+                pass  # there's nothing we need to do here
+
             else:  # pragma: no cover
                 warnings.warn(
                     f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
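
The streaming branches above lean on an invariant: the call part's initial args JSON ends in `"tool_args":` once the trailing `{}}` is trimmed, so the raw argument deltas plus the final `}` from `ResponseMcpCallArgumentsDoneEvent` reassemble into valid JSON. A self-contained sketch of that invariant (values illustrative):

```python
import json

# Initial args serialized from the McpCall part, `tool_args` still empty:
args_json = '{"action": "call_tool", "tool_name": "search", "tool_args": {}}'
prefix = args_json[:-3]  # drop the trailing `{}}`, as the streaming code does
deltas = ['{"query": ', '"pydantic"}']  # ResponseMcpCallArgumentsDeltaEvent.delta chunks
final = prefix + ''.join(deltas) + '}'  # the Done event contributes the closing `}`
assert json.loads(final)['tool_args'] == {'query': 'pydantic'}
```
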
@@ -1990,7 +2121,6 @@ def _map_usage(
 def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
     # When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
     # Before our `ToolCallPart` gained the `id` field alongside `tool_call_id` field, we combined the two fields into a single string stored on `tool_call_id`.
-
     if '|' in combined_id:
         call_id, id = combined_id.split('|', 1)
         return call_id, id
@@ -2030,7 +2160,7 @@ def _map_code_interpreter_tool_call(
         tool_call_id=item.id,
         args={
             'container_id': item.container_id,
-            'code': item.code,
+            'code': item.code or '',
         },
         provider_name=provider_name,
     ),
@@ -2122,3 +2252,50 @@ def _map_image_generation_tool_call(
     ),
     file_part,
 )
+
+
+def _map_mcp_list_tools(
+    item: responses.response_output_item.McpListTools, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
+    tool_name = ':'.join([MCPServerTool.kind, item.server_label])
+    return (
+        BuiltinToolCallPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            provider_name=provider_name,
+            args={'action': 'list_tools'},
+        ),
+        BuiltinToolReturnPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            content=item.model_dump(mode='json', include={'tools', 'error'}),
+            provider_name=provider_name,
+        ),
+    )
+
+
+def _map_mcp_call(
+    item: responses.response_output_item.McpCall, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
+    tool_name = ':'.join([MCPServerTool.kind, item.server_label])
+    return (
+        BuiltinToolCallPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            args={
+                'action': 'call_tool',
+                'tool_name': item.name,
+                'tool_args': json.loads(item.arguments) if item.arguments else {},
+            },
+            provider_name=provider_name,
+        ),
+        BuiltinToolReturnPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            content={
+                'output': item.output,
+                'error': item.error,
+            },
+            provider_name=provider_name,
+        ),
+    )
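
Both helpers pack the server label into the part's `tool_name` as `f'{MCPServerTool.kind}:{server_label}'`, which is exactly what the history-mapping code earlier splits back apart. A rough round-trip sketch (label illustrative; assumes `MCPServerTool.kind` itself contains no colon):

```python
from pydantic_ai.builtin_tools import MCPServerTool

server_label = 'deepwiki'
tool_name = ':'.join([MCPServerTool.kind, server_label])

# When replaying message history, the model code recovers the label with:
assert tool_name.split(':', 1)[1] == server_label
```
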
pydantic_ai/providers/google.py CHANGED
@@ -13,7 +13,8 @@ from pydantic_ai.providers import Provider
 
 try:
     from google.auth.credentials import Credentials
-    from google.genai import Client
+    from google.genai._api_client import BaseApiClient
+    from google.genai.client import Client, DebugConfig
     from google.genai.types import HttpOptions
 except ImportError as _import_error:
     raise ImportError(
@@ -114,7 +115,7 @@ class GoogleProvider(Provider[Client]):
                 base_url=base_url,
                 headers={'User-Agent': get_user_agent()},
                 httpx_async_client=http_client,
-                # TODO: Remove once https://github.com/googleapis/python-genai/pull/1509#issuecomment-3430028790 is solved.
+                # TODO: Remove once https://github.com/googleapis/python-genai/issues/1565 is solved.
                 async_client_args={'transport': httpx.AsyncHTTPTransport()},
             )
         if not vertexai:
@@ -186,9 +187,37 @@ More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/
 
 
 class _SafelyClosingClient(Client):
+    @staticmethod
+    def _get_api_client(
+        vertexai: bool | None = None,
+        api_key: str | None = None,
+        credentials: Credentials | None = None,
+        project: str | None = None,
+        location: str | None = None,
+        debug_config: DebugConfig | None = None,
+        http_options: HttpOptions | None = None,
+    ) -> BaseApiClient:
+        return _NonClosingApiClient(
+            vertexai=vertexai,
+            api_key=api_key,
+            credentials=credentials,
+            project=project,
+            location=location,
+            http_options=http_options,
+        )
+
     def close(self) -> None:
         # This is called from `Client.__del__`, even if `Client.__init__` raised an error before `self._api_client` is set, which would raise an `AttributeError` here.
+        # TODO: Remove once https://github.com/googleapis/python-genai/issues/1567 is solved.
         try:
             super().close()
         except AttributeError:
             pass
+
+
+class _NonClosingApiClient(BaseApiClient):
+    async def aclose(self) -> None:
+        # The original implementation also calls `await self._async_httpx_client.aclose()`, but we don't want to close our `cached_async_http_client` or the one the user passed in.
+        # TODO: Remove once https://github.com/googleapis/python-genai/issues/1566 is solved.
+        if self._aiohttp_session:
+            await self._aiohttp_session.close()  # pragma: no cover
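
The point of `_NonClosingApiClient` is that tearing down the `google.genai` client must not close an `httpx.AsyncClient` that pydantic-ai caches or that the caller owns. A sketch of the scenario it protects, assuming `GoogleProvider` accepts an `http_client` argument as the constructor code above suggests:

```python
import httpx

from pydantic_ai.providers.google import GoogleProvider

http_client = httpx.AsyncClient()  # application-owned, shared across providers

provider = GoogleProvider(api_key='...', http_client=http_client)
# If the underlying genai Client is closed or garbage-collected,
# _NonClosingApiClient.aclose() skips `http_client.aclose()`, so the shared
# client remains usable for subsequent requests.
```
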
pydantic_ai/run.py CHANGED
@@ -1,12 +1,14 @@
 from __future__ import annotations as _annotations
 
 import dataclasses
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Sequence
 from copy import deepcopy
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Generic, Literal, overload
 
-from pydantic_graph import End, GraphRun, GraphRunContext
+from pydantic_graph import BaseNode, End, GraphRunContext
+from pydantic_graph.beta.graph import EndMarker, GraphRun, GraphTask, JoinItem
+from pydantic_graph.beta.step import NodeStep
 
 from . import (
     _agent_graph,
@@ -112,12 +114,8 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
 
         This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
         """
-        next_node = self._graph_run.next_node
-        if isinstance(next_node, End):
-            return next_node
-        if _agent_graph.is_agent_node(next_node):
-            return next_node
-        raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}')  # pragma: no cover
+        task = self._graph_run.next_task
+        return self._task_to_node(task)
 
     @property
     def result(self) -> AgentRunResult[OutputDataT] | None:
@@ -126,13 +124,13 @@
         Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated
         with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult].
         """
-        graph_run_result = self._graph_run.result
-        if graph_run_result is None:
+        graph_run_output = self._graph_run.output
+        if graph_run_output is None:
             return None
         return AgentRunResult(
-            graph_run_result.output.output,
-            graph_run_result.output.tool_name,
-            graph_run_result.state,
+            graph_run_output.output,
+            graph_run_output.tool_name,
+            self._graph_run.state,
             self._graph_run.deps.new_message_index,
             self._traceparent(required=False),
         )
@@ -147,11 +145,28 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
         self,
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         """Advance to the next node automatically based on the last returned node."""
-        next_node = await self._graph_run.__anext__()
-        if _agent_graph.is_agent_node(node=next_node):
-            return next_node
-        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
-        return next_node
+        task = await anext(self._graph_run)
+        return self._task_to_node(task)
+
+    def _task_to_node(
+        self, task: EndMarker[FinalResult[OutputDataT]] | JoinItem | Sequence[GraphTask]
+    ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
+        if isinstance(task, Sequence) and len(task) == 1:
+            first_task = task[0]
+            if isinstance(first_task.inputs, BaseNode):  # pragma: no branch
+                base_node: BaseNode[
+                    _agent_graph.GraphAgentState,
+                    _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT],
+                    FinalResult[OutputDataT],
+                ] = first_task.inputs  # type: ignore[reportUnknownMemberType]
+                if _agent_graph.is_agent_node(node=base_node):  # pragma: no branch
+                    return base_node
+        if isinstance(task, EndMarker):
+            return End(task.value)
+        raise exceptions.AgentRunError(f'Unexpected node: {task}')  # pragma: no cover
+
+    def _node_to_task(self, node: _agent_graph.AgentNode[AgentDepsT, OutputDataT]) -> GraphTask:
+        return GraphTask(NodeStep(type(node)).id, inputs=node, fork_stack=())
 
     async def next(
         self,
@@ -222,11 +237,12 @@ class AgentRun(Generic[AgentDepsT, OutputDataT]):
         """
         # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
        # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
-        next_node = await self._graph_run.next(node)
-        if _agent_graph.is_agent_node(next_node):
-            return next_node
-        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
-        return next_node
+        task = [self._node_to_task(node)]
+        try:
+            task = await self._graph_run.next(task)
+        except StopAsyncIteration:
+            pass
+        return self._task_to_node(task)
 
     # TODO (v2): Make this a property
     def usage(self) -> _usage.RunUsage:
@@ -234,7 +250,7 @@
         return self._graph_run.state.usage
 
     def __repr__(self) -> str:  # pragma: no cover
-        result = self._graph_run.result
+        result = self._graph_run.output
         result_repr = '<run not finished>' if result is None else repr(result.output)
         return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>'
 
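
Despite the internal switch to the beta task-based `GraphRun`, `_task_to_node` preserves the public iteration contract: callers still receive agent nodes and a final `End`, while `EndMarker` and `GraphTask` stay internal. The documented `Agent.iter` pattern is unchanged (model name illustrative):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # illustrative model name

async def main() -> None:
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            print(type(node).__name__)  # agent nodes, then End
    print(agent_run.result.output)

asyncio.run(main())
```
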
pydantic_ai_slim-1.3.0.dist-info/METADATA → pydantic_ai_slim-1.5.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 1.3.0
+Version: 1.5.0
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Project-URL: Homepage, https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim
 Project-URL: Source, https://github.com/pydantic/pydantic-ai/tree/main/pydantic_ai_slim
@@ -33,7 +33,7 @@ Requires-Dist: genai-prices>=0.0.35
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==1.3.0
+Requires-Dist: pydantic-graph==1.5.0
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
@@ -57,7 +57,7 @@ Requires-Dist: dbos>=1.14.0; extra == 'dbos'
 Provides-Extra: duckduckgo
 Requires-Dist: ddgs>=9.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==1.3.0; extra == 'evals'
+Requires-Dist: pydantic-evals==1.5.0; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.46.0; extra == 'google'
 Provides-Extra: groq