pydantic-ai-slim 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.


pydantic_ai/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from importlib.metadata import version

- from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
+ from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
  from .exceptions import (
      AgentRunError,
      FallbackExceptionGroup,
@@ -18,7 +18,7 @@ __all__ = (
      # agent
      'Agent',
      'EndStrategy',
-     'HandleResponseNode',
+     'CallToolsNode',
      'ModelRequestNode',
      'UserPromptNode',
      'capture_run_messages',
pydantic_ai/_agent_graph.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations

  import asyncio
  import dataclasses
- from abc import ABC
  from collections.abc import AsyncIterator, Iterator, Sequence
  from contextlib import asynccontextmanager, contextmanager
  from contextvars import ContextVar
@@ -10,7 +9,7 @@ from dataclasses import field
  from typing import Any, Generic, Literal, Union, cast

  import logfire_api
- from typing_extensions import TypeVar, assert_never
+ from typing_extensions import TypeGuard, TypeVar, assert_never

  from pydantic_graph import BaseNode, Graph, GraphRunContext
  from pydantic_graph.nodes import End, NodeRunEndT
@@ -24,6 +23,7 @@ from . import (
      result,
      usage as _usage,
  )
+ from .models.instrumented import InstrumentedModel
  from .result import ResultDataT
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
@@ -37,7 +37,7 @@ __all__ = (
      'GraphAgentDeps',
      'UserPromptNode',
      'ModelRequestNode',
-     'HandleResponseNode',
+     'CallToolsNode',
      'build_run_context',
      'capture_run_messages',
  )
@@ -55,6 +55,7 @@ else:
      logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

  T = TypeVar('T')
+ S = TypeVar('S')
  NoneType = type(None)
  EndStrategy = Literal['early', 'exhaustive']
  """The strategy for handling multiple tool calls when a final result is found.
@@ -107,8 +108,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
      run_span: logfire_api.LogfireSpan


+ class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+     """The base class for all agent nodes.
+
+     Using a subclass of `BaseNode` for all nodes reduces the amount of boilerplate of generics everywhere.
+     """
+
+
+ def is_agent_node(
+     node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
+ ) -> TypeGuard[AgentNode[T, S]]:
+     """Check if the provided node is an instance of `AgentNode`.
+
+     Usage:
+
+         if is_agent_node(node):
+             # `node` is an AgentNode
+             ...
+
+     This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
+     """
+     return isinstance(node, AgentNode)
+
+
  @dataclasses.dataclass
- class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
      user_prompt: str | Sequence[_messages.UserContent]

      system_prompts: tuple[str, ...]
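The new `AgentNode` base class and the `is_agent_node` guard exist to support iterating over an agent run's graph nodes with their generic parameters intact. A minimal sketch of the intended pattern, assuming the public `Agent.iter` run-iteration API this release builds towards; the private `pydantic_ai._agent_graph` import path and the `result.data` attribute are likewise assumptions:

import asyncio

from pydantic_ai import Agent
from pydantic_ai._agent_graph import is_agent_node  # private module path: an assumption

agent = Agent('openai:gpt-4o')


async def main() -> None:
    # `Agent.iter` is assumed to yield each graph node as the run progresses.
    async with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            if is_agent_node(node):
                # Narrowed to `AgentNode[T, S]` with generic parameters intact,
                # unlike a plain `isinstance(node, AgentNode)` check.
                print('visiting:', type(node).__name__)
    print(agent_run.result.data)  # `.data` assumed per the 0.0.31 result API


asyncio.run(main())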
@@ -215,17 +239,17 @@ async def _prepare_request_parameters(


  @dataclasses.dataclass
- class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
      """Make a request to the model using the last message in state.message_history."""

      request: _messages.ModelRequest

-     _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
+     _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
      _did_stream: bool = field(default=False, repr=False)

      async def run(
          self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+     ) -> CallToolsNode[DepsT, NodeRunEndT]:
          if self._result is not None:
              return self._result

@@ -236,48 +260,60 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res

          return await self._make_request(ctx)

+     @asynccontextmanager
+     async def stream(
+         self,
+         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+     ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
+         async with self._stream(ctx) as streamed_response:
+             agent_stream = result.AgentStream[DepsT, T](
+                 streamed_response,
+                 ctx.deps.result_schema,
+                 ctx.deps.result_validators,
+                 build_run_context(ctx),
+                 ctx.deps.usage_limits,
+             )
+             yield agent_stream
+             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+             # otherwise usage won't be properly counted:
+             async for _ in agent_stream:
+                 pass
+
      @asynccontextmanager
      async def _stream(
          self,
          ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
      ) -> AsyncIterator[models.StreamedResponse]:
-         # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
          assert not self._did_stream, 'stream() should only be called once per node'

          model_settings, model_request_parameters = await self._prepare_request(ctx)
-         with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-             async with ctx.deps.model.request_stream(
-                 ctx.state.message_history, model_settings, model_request_parameters
-             ) as streamed_response:
-                 self._did_stream = True
-                 ctx.state.usage.incr(_usage.Usage(), requests=1)
-                 yield streamed_response
-                 # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-                 # otherwise usage won't be properly counted:
-                 async for _ in streamed_response:
-                     pass
-             model_response = streamed_response.get()
-             request_usage = streamed_response.usage()
-             span.set_attribute('response', model_response)
-             span.set_attribute('usage', request_usage)
+         async with ctx.deps.model.request_stream(
+             ctx.state.message_history, model_settings, model_request_parameters
+         ) as streamed_response:
+             self._did_stream = True
+             ctx.state.usage.incr(_usage.Usage(), requests=1)
+             yield streamed_response
+             # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+             # otherwise usage won't be properly counted:
+             async for _ in streamed_response:
+                 pass
+         model_response = streamed_response.get()
+         request_usage = streamed_response.usage()

          self._finish_handling(ctx, model_response, request_usage)
          assert self._result is not None  # this should be set by the previous line

      async def _make_request(
          self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+     ) -> CallToolsNode[DepsT, NodeRunEndT]:
          if self._result is not None:
              return self._result

          model_settings, model_request_parameters = await self._prepare_request(ctx)
-         with _logfire.span('model request', run_step=ctx.state.run_step) as span:
-             model_response, request_usage = await ctx.deps.model.request(
-                 ctx.state.message_history, model_settings, model_request_parameters
-             )
-             ctx.state.usage.incr(_usage.Usage(), requests=1)
-             span.set_attribute('response', model_response)
-             span.set_attribute('usage', request_usage)
+         model_response, request_usage = await ctx.deps.model.request(
+             ctx.state.message_history, model_settings, model_request_parameters
+         )
+         ctx.state.usage.incr(_usage.Usage(), requests=1)

          return self._finish_handling(ctx, model_response, request_usage)

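The streaming logic that used to live only in the private `_stream` is now wrapped by a public `stream` method that yields an `AgentStream`, and the trailing `async for` drains the stream so usage is counted even when the caller stops early. A rough consumption sketch under the same `Agent.iter` assumption as above; `agent_run.ctx` exposing the graph run context is also assumed:

import asyncio

from pydantic_ai import Agent, ModelRequestNode

agent = Agent('openai:gpt-4o')


async def main() -> None:
    async with agent.iter('Tell me a short joke.') as agent_run:
        async for node in agent_run:
            if isinstance(node, ModelRequestNode):
                # Even if we stopped iterating early, the context manager
                # drains the stream so usage is still counted.
                async with node.stream(agent_run.ctx) as stream:
                    async for event in stream:
                        print(event)


asyncio.run(main())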
@@ -303,7 +339,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
          ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
          response: _messages.ModelResponse,
          usage: _usage.Usage,
-     ) -> HandleResponseNode[DepsT, NodeRunEndT]:
+     ) -> CallToolsNode[DepsT, NodeRunEndT]:
          # Update usage
          ctx.state.usage.incr(usage, requests=0)
          if ctx.deps.usage_limits:
@@ -313,13 +349,13 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
          ctx.state.message_history.append(response)

          # Set the `_result` attribute since we can't use `return` in an async iterator
-         self._result = HandleResponseNode(response)
+         self._result = CallToolsNode(response)

          return self._result


  @dataclasses.dataclass
- class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
      """Process a model response, and decide whether to end the run or make a new request."""

      model_response: _messages.ModelResponse
@@ -413,8 +449,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
          final_result: result.FinalResult[NodeRunEndT] | None = None
          parts: list[_messages.ModelRequestPart] = []
          if result_schema is not None:
-             if match := result_schema.find_tool(tool_calls):
-                 call, result_tool = match
+             for call, result_tool in result_schema.find_tool(tool_calls):
                  try:
                      result_data = result_tool.validate(call)
                      result_data = await _validate_result(result_data, ctx, call)
@@ -424,12 +459,17 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
                      ctx.state.increment_retries(ctx.deps.max_result_retries)
                      parts.append(e.tool_retry)
                  else:
-                     final_result = result.FinalResult(result_data, call.tool_name)
+                     final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
+                     break

          # Then build the other request parts based on end strategy
          tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
          async for event in process_function_tools(
-             tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+             tool_calls,
+             final_result and final_result.tool_name,
+             final_result and final_result.tool_call_id,
+             ctx,
+             tool_responses,
          ):
              yield event

@@ -455,7 +495,10 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
              messages.append(_messages.ModelRequest(parts=tool_responses))

          run_span.set_attribute('usage', usage)
-         run_span.set_attribute('all_messages', messages)
+         run_span.set_attribute(
+             'all_messages_events',
+             [InstrumentedModel.event_to_dict(e) for e in InstrumentedModel.messages_to_otel_events(messages)],
+         )

          # End the run with self.data
          return End(final_result)
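The run span now records the message history as OpenTelemetry event dicts instead of attaching the raw message objects. A small sketch of the conversion the new attribute performs, built on the `event_to_dict` and `messages_to_otel_events` helpers named in the diff and a hand-assembled history:

from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.models.instrumented import InstrumentedModel

messages = [
    ModelRequest(parts=[UserPromptPart(content='hi')]),
    ModelResponse(parts=[TextPart(content='hello')]),
]

# Mirrors the new span attribute: messages -> OTel events -> plain dicts.
events = [
    InstrumentedModel.event_to_dict(e)
    for e in InstrumentedModel.messages_to_otel_events(messages)
]
print(events)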
@@ -477,7 +520,7 @@ class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], r
                  return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
              else:
                  # The following cast is safe because we know `str` is an allowed result type
-                 return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
+                 return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
          else:
              ctx.state.increment_retries(ctx.deps.max_result_retries)
              return ModelRequestNode[DepsT, NodeRunEndT](
@@ -506,6 +549,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
  async def process_function_tools(
      tool_calls: list[_messages.ToolCallPart],
      result_tool_name: str | None,
+     result_tool_call_id: str | None,
      ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
      output_parts: list[_messages.ModelRequestPart],
  ) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -525,7 +569,11 @@ async def process_function_tools(
      calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
      call_index_to_event_id: dict[int, str] = {}
      for call in tool_calls:
-         if call.tool_name == result_tool_name and not found_used_result_tool:
+         if (
+             call.tool_name == result_tool_name
+             and call.tool_call_id == result_tool_call_id
+             and not found_used_result_tool
+         ):
              found_used_result_tool = True
              output_parts.append(
                  _messages.ToolReturnPart(
@@ -552,9 +600,14 @@ async def process_function_tools(
              # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
              # validation, we don't add another part here
              if result_tool_name is not None:
+                 if found_used_result_tool:
+                     content = 'Result tool not used - a final result was already processed.'
+                 else:
+                     # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
+                     content = 'Result tool not used - result failed validation.'
                  part = _messages.ToolReturnPart(
                      tool_name=call.tool_name,
-                     content='Result tool not used - a final result was already processed.',
+                     content=content,
                      tool_call_id=call.tool_call_id,
                  )
                  output_parts.append(part)
@@ -575,7 +628,7 @@ async def process_function_tools(
          for task in done:
              index = tasks.index(task)
              result = task.result()
-             yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+             yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
              if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                  results_by_index[index] = result
              else:
@@ -675,7 +728,7 @@ def build_agent_graph(
      nodes = (
          UserPromptNode[DepsT],
          ModelRequestNode[DepsT],
-         HandleResponseNode[DepsT],
+         CallToolsNode[DepsT],
      )
      graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]](
          nodes=nodes,
pydantic_ai/_result.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
  import inspect
  import sys
  import types
- from collections.abc import Awaitable, Iterable
+ from collections.abc import Awaitable, Iterable, Iterator
  from dataclasses import dataclass, field
  from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin

@@ -127,12 +127,12 @@ class ResultSchema(Generic[ResultDataT]):
      def find_tool(
          self,
          parts: Iterable[_messages.ModelResponsePart],
-     ) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
+     ) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
          """Find a tool that matches one of the calls."""
          for part in parts:
              if isinstance(part, _messages.ToolCallPart):
                  if result := self.tools.get(part.tool_name):
-                     return part, result
+                     yield part, result

      def tool_names(self) -> list[str]:
          """Return the names of the tools."""
pydantic_ai/_utils.py CHANGED
@@ -48,6 +48,8 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:

      if schema.get('type') == 'object':
          return schema
+     elif schema.get('$ref') is not None:
+         return schema.get('$defs', {}).get(schema['$ref'][8:])  # This removes the initial "#/$defs/".
      else:
          raise UserError('Schema must be an object')
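The new `$ref` branch accepts schemas that Pydantic emits as a top-level reference into `$defs`, which is what self-referencing models produce. A sketch of the resolution, assuming Pydantic v2's schema layout for a hypothetical recursive model:

from __future__ import annotations

from pydantic import BaseModel


class Tree(BaseModel):
    # Self-referencing, so (with Pydantic v2) the schema comes out as
    # {'$ref': '#/$defs/Tree', '$defs': {'Tree': {...}}} rather than an
    # inline {'type': 'object', ...}.
    value: int
    children: list[Tree] = []


schema = Tree.model_json_schema()
# Strip the 8-character '#/$defs/' prefix, exactly as the new branch does.
resolved = schema.get('$defs', {}).get(schema['$ref'][8:])
print(resolved['type'])  # 'object'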
55