pydantic-ai-slim 0.0.29__py3-none-any.whl → 0.0.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/_agent_graph.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations
 
 import asyncio
 import dataclasses
-from abc import ABC
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -10,7 +9,7 @@ from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -55,6 +54,7 @@ else:
     logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -107,8 +107,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     run_span: logfire_api.LogfireSpan
 
 
+class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """The base class for all agent nodes.
+
+    Using subclass of `BaseNode` for all nodes reduces the amount of boilerplate of generics everywhere
+    """
+
+
+def is_agent_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
+) -> TypeGuard[AgentNode[T, S]]:
+    """Check if the provided node is an instance of `AgentNode`.
+
+    Usage:
+
+        if is_agent_node(node):
+            # `node` is an AgentNode
+            ...
+
+    This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
+    """
+    return isinstance(node, AgentNode)
+
+
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent]
 
     system_prompts: tuple[str, ...]
@@ -215,7 +238,7 @@ async def _prepare_request_parameters(
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -236,12 +259,30 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
 
         return await self._make_request(ctx)
 
+    @asynccontextmanager
+    async def stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
+        async with self._stream(ctx) as streamed_response:
+            agent_stream = result.AgentStream[DepsT, T](
+                streamed_response,
+                ctx.deps.result_schema,
+                ctx.deps.result_validators,
+                build_run_context(ctx),
+                ctx.deps.usage_limits,
+            )
+            yield agent_stream
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in agent_stream:
+                pass
+
     @asynccontextmanager
     async def _stream(
         self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     ) -> AsyncIterator[models.StreamedResponse]:
-        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
@@ -319,7 +360,7 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], res
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
     """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -575,7 +616,7 @@ async def process_function_tools(
         for task in done:
             index = tasks.index(task)
             result = task.result()
-            yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+            yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
             if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                 results_by_index[index] = result
             else:
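
The new `is_agent_node` helper uses `typing_extensions.TypeGuard` so that narrowing keeps the node's generic parameters, which a bare `isinstance` check loses. A minimal sketch of the pattern (the `Node`/`SpecialNode` names are illustrative, not from the package):

    from typing import Generic, TypeVar
    from typing_extensions import TypeGuard

    T = TypeVar('T')

    class Node(Generic[T]): ...

    class SpecialNode(Node[T]): ...

    def is_special_node(node: Node[T]) -> TypeGuard[SpecialNode[T]]:
        # A plain `isinstance(node, SpecialNode)` narrows to an unparametrized
        # `SpecialNode`; the TypeGuard return type keeps `T` bound for the caller.
        return isinstance(node, SpecialNode)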
pydantic_ai/agent.py CHANGED
@@ -9,9 +9,9 @@ from types import FrameType
 from typing import Any, Callable, Generic, cast, final, overload
 
 import logfire_api
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import TypeGuard, TypeVar, deprecated
 
-from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext
+from pydantic_graph import End, Graph, GraphRun, GraphRunContext
 from pydantic_graph._utils import get_event_loop
 
 from . import (
@@ -46,7 +46,6 @@ HandleResponseNode = _agent_graph.HandleResponseNode
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode
 
-
 __all__ = (
     'Agent',
     'AgentRun',
@@ -71,6 +70,7 @@ else:
     logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 RunResultDataT = TypeVar('RunResultDataT')
 """Type variable for the result data of a run where `result_type` was customized on the run call."""
@@ -646,10 +646,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         ) as agent_run:
             first_node = agent_run.next_node  # start with the first node
             assert isinstance(first_node, _agent_graph.UserPromptNode)  # the first node should be a user prompt node
-            node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node)
+            node = first_node
             while True:
-                if isinstance(node, _agent_graph.ModelRequestNode):
-                    node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node)
+                if self.is_model_request_node(node):
                     graph_ctx = agent_run.ctx
                     async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]
 
@@ -717,9 +716,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                     )
                     break
                 next_node = await agent_run.next(node)
-                if not isinstance(next_node, BaseNode):
+                if not isinstance(next_node, _agent_graph.AgentNode):
                     raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
-                node = cast(BaseNode[Any, Any, Any], next_node)
+                node = cast(_agent_graph.AgentNode[Any, Any], next_node)
 
             if not yielded:
                 raise exceptions.AgentRunError('Agent run finished without producing a final result')
@@ -1173,6 +1172,46 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             return self._result_schema  # pyright: ignore[reportReturnType]
 
+    @staticmethod
+    def is_model_request_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]:
+        """Check if the node is a `ModelRequestNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.ModelRequestNode)
+
+    @staticmethod
+    def is_handle_response_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.HandleResponseNode[T, S]]:
+        """Check if the node is a `HandleResponseNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.HandleResponseNode)
+
+    @staticmethod
+    def is_user_prompt_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]:
+        """Check if the node is a `UserPromptNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.UserPromptNode)
+
+    @staticmethod
+    def is_end_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[End[result.FinalResult[S]]]:
+        """Check if the node is a `End`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, End)
+
 
 @dataclasses.dataclass(repr=False)
 class AgentRun(Generic[AgentDepsT, ResultDataT]):
@@ -1244,15 +1283,17 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
     @property
     def next_node(
         self,
-    ) -> (
-        BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]]
-        | End[FinalResult[ResultDataT]]
-    ):
+    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
         """The next node that will be run in the agent graph.
 
         This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
         """
-        return self._graph_run.next_node
+        next_node = self._graph_run.next_node
+        if isinstance(next_node, End):
+            return next_node
+        if _agent_graph.is_agent_node(next_node):
+            return next_node
+        raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}')  # pragma: no cover
 
     @property
     def result(self) -> AgentRunResult[ResultDataT] | None:
@@ -1273,45 +1314,24 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
 
     def __aiter__(
         self,
-    ) -> AsyncIterator[
-        BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ]
-        | End[FinalResult[ResultDataT]]
-    ]:
+    ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]:
         """Provide async-iteration over the nodes in the agent run."""
         return self
 
     async def __anext__(
         self,
-    ) -> (
-        BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ]
-        | End[FinalResult[ResultDataT]]
-    ):
+    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
         """Advance to the next node automatically based on the last returned node."""
-        return await self._graph_run.__anext__()
+        next_node = await self._graph_run.__anext__()
+        if _agent_graph.is_agent_node(next_node):
+            return next_node
+        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
+        return next_node
 
     async def next(
         self,
-        node: BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ],
-    ) -> (
-        BaseNode[
-            _agent_graph.GraphAgentState,
-            _agent_graph.GraphAgentDeps[AgentDepsT, Any],
-            FinalResult[ResultDataT],
-        ]
-        | End[FinalResult[ResultDataT]]
-    ):
+        node: _agent_graph.AgentNode[AgentDepsT, ResultDataT],
+    ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
         """Manually drive the agent run by passing in the node you want to run next.
 
         This lets you inspect or mutate the node before continuing execution, or skip certain nodes
@@ -1378,7 +1398,11 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
         """
         # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
         # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
-        return await self._graph_run.next(node)
+        next_node = await self._graph_run.next(node)
+        if _agent_graph.is_agent_node(next_node):
+            return next_node
+        assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
+        return next_node
 
     def usage(self) -> _usage.Usage:
         """Get usage statistics for the run so far, including token usage, model requests, and so on."""
pydantic_ai/messages.py CHANGED
@@ -533,9 +533,24 @@ class PartDeltaEvent:
     """Event type identifier, used as a discriminator."""
 
 
+@dataclass
+class FinalResultEvent:
+    """An event indicating the response to the current model request matches the result schema."""
+
+    tool_name: str | None
+    """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
+    event_kind: Literal['final_result'] = 'final_result'
+    """Event type identifier, used as a discriminator."""
+
+
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
 
+AgentStreamEvent = Annotated[
+    Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
+]
+"""An event in the agent stream."""
+
 
 @dataclass
 class FunctionToolCallEvent:
@@ -558,7 +573,7 @@ class FunctionToolResultEvent:
 
     result: ToolReturnPart | RetryPromptPart
     """The result of the call to the function tool."""
-    call_id: str
+    tool_call_id: str
     """An ID used to match the result to its original call."""
     event_kind: Literal['function_tool_result'] = 'function_tool_result'
     """Event type identifier, used as a discriminator."""
pydantic_ai/models/__init__.py CHANGED
@@ -84,6 +84,8 @@ KnownModelName = Literal[
     'gpt-4-turbo-2024-04-09',
     'gpt-4-turbo-preview',
     'gpt-4-vision-preview',
+    'gpt-4.5-preview',
+    'gpt-4.5-preview-2025-02-27',
     'gpt-4o',
     'gpt-4o-2024-05-13',
     'gpt-4o-2024-08-06',
@@ -138,6 +140,8 @@ KnownModelName = Literal[
     'openai:gpt-4-turbo-2024-04-09',
     'openai:gpt-4-turbo-preview',
     'openai:gpt-4-vision-preview',
+    'openai:gpt-4.5-preview',
+    'openai:gpt-4.5-preview-2025-02-27',
     'openai:gpt-4o',
     'openai:gpt-4o-2024-05-13',
     'openai:gpt-4o-2024-08-06',
pydantic_ai/models/function.py CHANGED
@@ -177,6 +177,8 @@ class DeltaToolCall:
     """Incremental change to the name of the tool."""
     json_args: str | None = None
     """Incremental change to the arguments as JSON"""
+    tool_call_id: str | None = None
+    """Incremental change to the tool call ID."""
 
 
 DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
@@ -224,7 +226,7 @@ class FunctionStreamedResponse(StreamedResponse):
                        vendor_part_id=dtc_index,
                        tool_name=delta_tool_call.name,
                        args=delta_tool_call.json_args,
-                        tool_call_id=None,
+                        tool_call_id=delta_tool_call.tool_call_id,
                    )
                    if maybe_event is not None:
                        yield maybe_event
@@ -280,7 +282,16 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
         return 0
     if isinstance(content, str):
         return len(re.split(r'[\s",.:]+', content.strip()))
-    # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
     else:  # pragma: no cover
-        assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
-        return 0
+        tokens = 0
+        for part in content:
+            if isinstance(part, str):
+                tokens += len(re.split(r'[\s",.:]+', part.strip()))
+            # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
+            if isinstance(part, (AudioUrl, ImageUrl)):
+                tokens += 0
+            elif isinstance(part, BinaryContent):
+                tokens += len(part.data)
+            else:
+                tokens += 0
+        return tokens
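
The rewritten branch walks sequence content instead of returning 0: string parts are estimated by splitting on whitespace and punctuation, URL parts still contribute nothing, and binary content counts one token per byte. A self-contained sketch of that arithmetic (not the library function itself):

    import re

    def estimate(parts: list) -> int:
        """Mirror of the new mixed-content branch in _estimate_string_tokens."""
        tokens = 0
        for part in parts:
            if isinstance(part, str):
                tokens += len(re.split(r'[\s",.:]+', part.strip()))
            elif isinstance(part, bytes):  # stands in for BinaryContent.data
                tokens += len(part)
            # AudioUrl/ImageUrl parts still add 0
        return tokens

    assert estimate(['hello world', b'abc']) == 5  # 2 words + 3 bytes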
pydantic_ai/models/instrumented.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
@@ -9,6 +10,7 @@ from typing import Any, Callable, Literal
 import logfire_api
 from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
 from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
+from opentelemetry.util.types import AttributeValue
 
 from ..messages import (
     ModelMessage,
@@ -46,40 +48,42 @@ MODEL_SETTING_ATTRIBUTES: tuple[
     'frequency_penalty',
 )
 
-NOT_GIVEN = object()
-
 
 @dataclass
 class InstrumentedModel(WrapperModel):
-    """Model which is instrumented with logfire."""
+    """Model which is instrumented with OpenTelemetry."""
 
     tracer: Tracer = field(repr=False)
     event_logger: EventLogger = field(repr=False)
+    event_mode: Literal['attributes', 'logs'] = 'attributes'
 
     def __init__(
         self,
         wrapped: Model | KnownModelName,
         tracer_provider: TracerProvider | None = None,
         event_logger_provider: EventLoggerProvider | None = None,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
     ):
         super().__init__(wrapped)
         tracer_provider = tracer_provider or get_tracer_provider()
         event_logger_provider = event_logger_provider or get_event_logger_provider()
         self.tracer = tracer_provider.get_tracer('pydantic-ai')
         self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+        self.event_mode = event_mode
 
     @classmethod
     def from_logfire(
         cls,
         wrapped: Model | KnownModelName,
         logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+        event_mode: Literal['attributes', 'logs'] = 'attributes',
     ) -> InstrumentedModel:
         if hasattr(logfire_instance.config, 'get_event_logger_provider'):
             event_provider = logfire_instance.config.get_event_logger_provider()
         else:
             event_provider = None
         tracer_provider = logfire_instance.config.get_tracer_provider()
-        return cls(wrapped, tracer_provider, event_provider)
+        return cls(wrapped, tracer_provider, event_provider, event_mode)
 
     async def request(
         self,
@@ -111,7 +115,7 @@ class InstrumentedModel(WrapperModel):
             finish(response_stream.get(), response_stream.usage())
 
     @contextmanager
-    def _instrument(
+    def _instrument(  # noqa: C901
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
@@ -126,7 +130,7 @@ class InstrumentedModel(WrapperModel):
         # - server.port: to parse from the base_url
         # - error.type: unclear if we should do something here or just always rely on span exceptions
         # - gen_ai.request.stop_sequences/top_k: model_settings doesn't include these
-        attributes: dict[str, Any] = {
+        attributes: dict[str, AttributeValue] = {
             'gen_ai.operation.name': operation,
             'gen_ai.system': system,
             'gen_ai.request.model': model_name,
@@ -134,10 +138,11 @@
 
         if model_settings:
             for key in MODEL_SETTING_ATTRIBUTES:
-                if (value := model_settings.get(key, NOT_GIVEN)) is not NOT_GIVEN:
+                if isinstance(value := model_settings.get(key), (float, int)):
                     attributes[f'gen_ai.request.{key}'] = value
 
-        emit_event = partial(self._emit_event, system)
+        events_list = []
+        emit_event = partial(self._emit_event, system, events_list)
 
         with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
             if span.is_recording():
@@ -167,22 +172,36 @@ class InstrumentedModel(WrapperModel):
                )
                span.set_attributes(
                    {
-                        k: v
-                        for k, v in {
-                            # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
-                            # https://github.com/pydantic/pydantic-ai/issues/886
-                            'gen_ai.response.model': response.model_name or model_name,
-                            'gen_ai.usage.input_tokens': usage.request_tokens,
-                            'gen_ai.usage.output_tokens': usage.response_tokens,
-                        }.items()
-                        if v is not None
+                        # TODO finish_reason (https://github.com/open-telemetry/semantic-conventions/issues/1277), id
+                        # https://github.com/pydantic/pydantic-ai/issues/886
+                        'gen_ai.response.model': response.model_name or model_name,
+                        **usage.opentelemetry_attributes(),
                    }
                )
+                if events_list:
+                    attr_name = 'events'
+                    span.set_attributes(
+                        {
+                            attr_name: json.dumps(events_list),
+                            'logfire.json_schema': json.dumps(
+                                {
+                                    'type': 'object',
+                                    'properties': {attr_name: {'type': 'array'}},
+                                }
+                            ),
+                        }
+                    )
 
             yield finish
 
-    def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
-        self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
+    def _emit_event(
+        self, system: str, events_list: list[dict[str, Any]], event_name: str, body: dict[str, Any]
+    ) -> None:
+        attributes = {'gen_ai.system': system}
+        if self.event_mode == 'logs':
+            self.event_logger.emit(Event(event_name, body=body, attributes=attributes))
+        else:
+            events_list.append({'event.name': event_name, **body, **attributes})
 
 
 def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
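
The new `event_mode` parameter chooses where chat events go: `'attributes'` (the default) JSON-serializes them onto the span's `events` attribute, while `'logs'` routes each one through the OpenTelemetry event logger as before. A construction sketch (model name is a placeholder):

    from pydantic_ai.models.instrumented import InstrumentedModel

    # Default: events are attached to the span as a JSON `events` attribute.
    model = InstrumentedModel('openai:gpt-4o')

    # Alternative: emit each event via the OpenTelemetry event logger.
    model_logs = InstrumentedModel('openai:gpt-4o', event_mode='logs')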
pydantic_ai/result.py CHANGED
@@ -7,9 +7,10 @@ from datetime import datetime
 from typing import Generic, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, assert_type
 
 from . import _result, _utils, exceptions, messages as _messages, models
+from .messages import AgentStreamEvent, FinalResultEvent
 from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits
 
@@ -51,6 +52,125 @@ Usage `ResultValidatorFunc[AgentDepsT, T]`.
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 
+@dataclass
+class AgentStream(Generic[AgentDepsT, ResultDataT]):
+    _raw_stream_response: models.StreamedResponse
+    _result_schema: _result.ResultSchema[ResultDataT] | None
+    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
+    _run_ctx: RunContext[AgentDepsT]
+    _usage_limits: UsageLimits | None
+
+    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
+    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
+    _initial_run_ctx_usage: Usage = field(init=False)
+
+    def __post_init__(self):
+        self._initial_run_ctx_usage = copy(self._run_ctx.usage)
+
+    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
+        """Asynchronously stream the (validated) agent outputs."""
+        async for response in self.stream_responses(debounce_by=debounce_by):
+            if self._final_result_event is not None:
+                yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
+        if self._final_result_event is not None:
+            yield await self._validate_response(
+                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
+            )
+
+    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
+        """Asynchronously stream the (unvalidated) model responses for the agent."""
+        # if the message currently has any parts with content, yield before streaming
+        msg = self._raw_stream_response.get()
+        for part in msg.parts:
+            if part.has_content():
+                yield msg
+                break
+
+        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
+            async for _items in group_iter:
+                yield self._raw_stream_response.get()  # current state of the response
+
+    def usage(self) -> Usage:
+        """Return the usage of the whole run.
+
+        !!! note
+            This won't return the full usage until the stream is finished.
+        """
+        return self._initial_run_ctx_usage + self._raw_stream_response.usage()
+
+    async def _validate_response(
+        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
+    ) -> ResultDataT:
+        """Validate a structured result message."""
+        if self._result_schema is not None and result_tool_name is not None:
+            match = self._result_schema.find_named_tool(message.parts, result_tool_name)
+            if match is None:
+                raise exceptions.UnexpectedModelBehavior(
+                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
+                )
+
+            call, result_tool = match
+            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
+
+            for validator in self._result_validators:
+                result_data = await validator.validate(result_data, call, self._run_ctx)
+            return result_data
+        else:
+            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
+            for validator in self._result_validators:
+                text = await validator.validate(
+                    text,
+                    None,
+                    self._run_ctx,
+                )
+            # Since there is no result tool, we can assume that str is compatible with ResultDataT
+            return cast(ResultDataT, text)
+
+    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
+        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
+
+        This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches
+        on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
+        first match is found.
+        """
+        if self._agent_stream_iterator is not None:
+            return self._agent_stream_iterator
+
+        async def aiter():
+            result_schema = self._result_schema
+            allow_text_result = result_schema is None or result_schema.allow_text_result
+
+            def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
+                """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
+                if isinstance(e, _messages.PartStartEvent):
+                    new_part = e.part
+                    if isinstance(new_part, _messages.ToolCallPart):
+                        if result_schema is not None and (match := result_schema.find_tool([new_part])):
+                            call, _ = match
+                            return _messages.FinalResultEvent(tool_name=call.tool_name)
+                    elif allow_text_result:
+                        assert_type(e, _messages.PartStartEvent)
+                        return _messages.FinalResultEvent(tool_name=None)
+
+            usage_checking_stream = _get_usage_checking_stream_response(
+                self._raw_stream_response, self._usage_limits, self.usage
+            )
+            async for event in usage_checking_stream:
+                yield event
+                if (final_result_event := _get_final_result_event(event)) is not None:
+                    self._final_result_event = final_result_event
+                    yield final_result_event
+                    break
+
+            # If we broke out of the above loop, we need to yield the rest of the events
+            # If we didn't, this will just be a no-op
+            async for event in usage_checking_stream:
+                yield event
+
+        self._agent_stream_iterator = aiter()
+        return self._agent_stream_iterator
+
+
 @dataclass
 class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
     """Result of a streamed run that returns structured data via a tool call."""
pydantic_ai/usage.py CHANGED
@@ -56,6 +56,16 @@ class Usage:
         new_usage.incr(other)
         return new_usage
 
+    def opentelemetry_attributes(self) -> dict[str, int]:
+        """Get the token limits as OpenTelemetry attributes."""
+        result = {
+            'gen_ai.usage.input_tokens': self.request_tokens,
+            'gen_ai.usage.output_tokens': self.response_tokens,
+        }
+        for key, value in (self.details or {}).items():
+            result[f'gen_ai.usage.details.{key}'] = value
+        return {k: v for k, v in result.items() if v is not None}
+
 
 @dataclass
 class UsageLimits:
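
`opentelemetry_attributes` filters out `None` values and flattens `details` under a `gen_ai.usage.details.` prefix. For example, a sketch assuming `Usage` is imported from `pydantic_ai.usage`:

    from pydantic_ai.usage import Usage

    usage = Usage(request_tokens=12, response_tokens=34, details={'cached_tokens': 5})
    assert usage.opentelemetry_attributes() == {
        'gen_ai.usage.input_tokens': 12,
        'gen_ai.usage.output_tokens': 34,
        'gen_ai.usage.details.cached_tokens': 5,
    }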
pydantic_ai_slim-0.0.30.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.29
+Version: 0.0.30
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: logfire-api>=1.2.0
-Requires-Dist: pydantic-graph==0.0.29
+Requires-Dist: pydantic-graph==0.0.30
 Requires-Dist: pydantic>=2.10
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.40.0; extra == 'anthropic'
@@ -44,7 +44,7 @@ Requires-Dist: logfire>=2.3; extra == 'logfire'
 Provides-Extra: mistral
 Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
 Provides-Extra: openai
-Requires-Dist: openai>=1.61.0; extra == 'openai'
+Requires-Dist: openai>=1.65.1; extra == 'openai'
 Provides-Extra: tavily
 Requires-Dist: tavily-python>=0.5.0; extra == 'tavily'
 Provides-Extra: vertexai
pydantic_ai_slim-0.0.30.dist-info/RECORD CHANGED
@@ -1,36 +1,36 @@
 pydantic_ai/__init__.py,sha256=Rmpjmorf8YY1PtlkXRRNN-J3ZoQDSh7chaibVGyqY0k,937
-pydantic_ai/_agent_graph.py,sha256=lNKTtUyVY14M0WODP5K1NUaE9zJA716-9rZutapSg8A,29042
+pydantic_ai/_agent_graph.py,sha256=gvJQ17A2glk8p2w2TCSfHwvWNp0vla1sQb0EZXOZbxU,30284
 pydantic_ai/_griffe.py,sha256=RYRKiLbgG97QxnazbAwlnc74XxevGHLQet-FGfq9qls,3960
 pydantic_ai/_parts_manager.py,sha256=ARfDQY1_5AIY5rNl_M2fAYHEFCe03ZxdhgjHf9qeIKw,11872
 pydantic_ai/_pydantic.py,sha256=dROz3Hmfdi0C2exq88FhefDRVo_8S3rtkXnoUHzsz0c,8753
 pydantic_ai/_result.py,sha256=tN1pVulf_EM4bkBvpNUWPnUXezLY-sBrJEVCFdy2nLU,10264
 pydantic_ai/_system_prompt.py,sha256=602c2jyle2R_SesOrITBDETZqsLk4BZ8Cbo8yEhmx04,1120
 pydantic_ai/_utils.py,sha256=w9BYSfFZiaX757fRtMRclOL1uYzyQnxV_lxqbU2WTPs,9435
-pydantic_ai/agent.py,sha256=wCXvGwPykn-cYf1_4bR4XqYofqwI6X2lj4L-0-MbMp4,63685
+pydantic_ai/agent.py,sha256=FeKELTSFKDkt6-UlmkezKnQTdnx1in6VckivqsfzfA4,65382
 pydantic_ai/exceptions.py,sha256=1ujJeB3jDDQ-pH5ydBYrgStvR35-GlEW0bYGTGEr4ME,3127
 pydantic_ai/format_as_xml.py,sha256=QE7eMlg5-YUMw1_2kcI3h0uKYPZZyGkgXFDtfZTMeeI,4480
-pydantic_ai/messages.py,sha256=U-RgeRsMR-Ew6IoeBDrnQVONX9AwxyVd0aTnAxEA7EM,20918
+pydantic_ai/messages.py,sha256=k8sX-V1cTeqXh1u6oJbqExZPYt3E7F3UCIudxvjKRO8,21486
 pydantic_ai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-pydantic_ai/result.py,sha256=Rqbog6efO1l_bFJSuAd-_ZZLoQa_rz4motOGeR_5N3I,16803
+pydantic_ai/result.py,sha256=Df_tPeqCQnLa0i0vVA-BGCJDx37ebD_3ojAmHnXE2yU,22767
 pydantic_ai/settings.py,sha256=ntuWnke9UA18aByDxk9OIhN0tAgOaPdqCEkRf-wlp8Y,3059
 pydantic_ai/tools.py,sha256=IPZuZJCSQUppz1uyLVwpfFLGoMirB8YtKWXIDQGR444,13414
-pydantic_ai/usage.py,sha256=60d9f6M7YEYuKMbqDGDogX4KsA73fhDtWyDXYXoIPaI,4948
+pydantic_ai/usage.py,sha256=VmpU_o_RjFI65J81G1wfCwDIAYBclMjeWfLtslntFOw,5406
 pydantic_ai/common_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pydantic_ai/common_tools/duckduckgo.py,sha256=-kSa1gGn5-NIYvtxFWrFcX2XdmfEmGxI3_wAqrb6jLI,2230
 pydantic_ai/common_tools/tavily.py,sha256=Lz35037ggkdWKa_Stj0yXBkiN_hygDefEevoRDUclF0,2560
-pydantic_ai/models/__init__.py,sha256=Qw_g58KzGUmuDKOBa2u3yFrNbgCXkdRSNtkhseLC1VM,13758
+pydantic_ai/models/__init__.py,sha256=2A3CpdMnvllnVVX8PmlUcBs0HMGcG4RurOXsRKl0BPc,13886
 pydantic_ai/models/anthropic.py,sha256=bFtE6hku9L4l4pKJg8XER37T2ST2htArho5lPjEohAk,20637
 pydantic_ai/models/cohere.py,sha256=6F6eWPGVT7mpMXlRugbVbR-a8Q1zmb1SKS_fWOoBL80,11514
 pydantic_ai/models/fallback.py,sha256=smHwNIpxu19JsgYYjY0nmzl3yox7yQRJ0Ir08zdhnk0,4207
-pydantic_ai/models/function.py,sha256=EMlASu436RE-XzOTuHGkIqkS8J4WItUvwwaL08LLkX8,10948
+pydantic_ai/models/function.py,sha256=THIwVJ8qI3efYLNlYXlYze_J8hc7MHB-NMb3kpknq0g,11373
 pydantic_ai/models/gemini.py,sha256=2hDTMIMf899dp-MS0tLT7m1GkXsL9KIRMBklGM0VLB4,34223
 pydantic_ai/models/groq.py,sha256=Z4sZJDu5Yxa2tZiAPp9EjSVMz4uwLhS3fW7kFSc09gI,16406
-pydantic_ai/models/instrumented.py,sha256=cvjHgQE_gJOH-YVQyvx9tBpGNB_Iuc8N8THn0TL0Rjk,8791
+pydantic_ai/models/instrumented.py,sha256=xUZEn2VG8hP3hny0L5kZgXC5UnFdlUJ0DgXOxFmYhEo,9654
 pydantic_ai/models/mistral.py,sha256=ZJ4xPcL9wJIQ5io34yP2fPyXy8GZrSvsW4itZiKPYFw,27448
 pydantic_ai/models/openai.py,sha256=koIcK_pDHmV-JFq_-VIzU-edAqGKOOzkSk5QSYWvfoc,20156
 pydantic_ai/models/test.py,sha256=Ux20cmuJFkhvI9L1N7ItHNFcd-j284TBEsrM53eWRag,16873
 pydantic_ai/models/vertexai.py,sha256=9Kp_1KMBlbP8_HRJTuFnrkkFmlJ7yFhADQYjxOgIh9Y,9523
 pydantic_ai/models/wrapper.py,sha256=Zr3fgiUBpt2N9gXds6iSwaMEtEsFKr9WwhpHjSoHa7o,1410
-pydantic_ai_slim-0.0.29.dist-info/METADATA,sha256=f7eLYpKWzEGmzTC8U29ZO5T4aTYMTW0Shmqa1zuEtG0,3062
-pydantic_ai_slim-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydantic_ai_slim-0.0.29.dist-info/RECORD,,
+pydantic_ai_slim-0.0.30.dist-info/METADATA,sha256=JDT77S9uw0w87WpAbXqK_c65849A7PeF1_dhJRGamiM,3062
+pydantic_ai_slim-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydantic_ai_slim-0.0.30.dist-info/RECORD,,