pydantic-ai-slim 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.

pydantic_ai/agent.py CHANGED
@@ -9,9 +9,9 @@ from types import FrameType
9
9
  from typing import Any, Callable, Generic, cast, final, overload
10
10
 
11
11
  import logfire_api
12
- from typing_extensions import TypeVar, deprecated
12
+ from typing_extensions import TypeGuard, TypeVar, deprecated
13
13
 
14
- from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext
14
+ from pydantic_graph import End, Graph, GraphRun, GraphRunContext
15
15
  from pydantic_graph._utils import get_event_loop
16
16
 
17
17
  from . import (
@@ -25,6 +25,7 @@ from . import (
25
25
  result,
26
26
  usage as _usage,
27
27
  )
28
+ from .models.instrumented import InstrumentedModel
28
29
  from .result import FinalResult, ResultDataT, StreamedRunResult
29
30
  from .settings import ModelSettings, merge_model_settings
30
31
  from .tools import (
@@ -42,18 +43,17 @@ from .tools import (
42
43
  # Re-exporting like this improves auto-import behavior in PyCharm
43
44
  capture_run_messages = _agent_graph.capture_run_messages
44
45
  EndStrategy = _agent_graph.EndStrategy
45
- HandleResponseNode = _agent_graph.HandleResponseNode
46
+ CallToolsNode = _agent_graph.CallToolsNode
46
47
  ModelRequestNode = _agent_graph.ModelRequestNode
47
48
  UserPromptNode = _agent_graph.UserPromptNode
48
49
 
49
-
50
50
  __all__ = (
51
51
  'Agent',
52
52
  'AgentRun',
53
53
  'AgentRunResult',
54
54
  'capture_run_messages',
55
55
  'EndStrategy',
56
- 'HandleResponseNode',
56
+ 'CallToolsNode',
57
57
  'ModelRequestNode',
58
58
  'UserPromptNode',
59
59
  )
@@ -71,6 +71,7 @@ else:
71
71
  logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
72
72
 
73
73
  T = TypeVar('T')
74
+ S = TypeVar('S')
74
75
  NoneType = type(None)
75
76
  RunResultDataT = TypeVar('RunResultDataT')
76
77
  """Type variable for the result data of a run where `result_type` was customized on the run call."""
@@ -294,7 +295,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
294
295
  """
295
296
  if infer_name and self.name is None:
296
297
  self._infer_name(inspect.currentframe())
297
- with self.iter(
298
+ async with self.iter(
298
299
  user_prompt=user_prompt,
299
300
  result_type=result_type,
300
301
  message_history=message_history,
@@ -310,8 +311,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
310
311
  assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
311
312
  return final_result
312
313
 
313
- @contextmanager
314
- def iter(
314
+ @asynccontextmanager
315
+ async def iter(
315
316
  self,
316
317
  user_prompt: str | Sequence[_messages.UserContent],
317
318
  *,
@@ -323,7 +324,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
323
324
  usage_limits: _usage.UsageLimits | None = None,
324
325
  usage: _usage.Usage | None = None,
325
326
  infer_name: bool = True,
326
- ) -> Iterator[AgentRun[AgentDepsT, Any]]:
327
+ ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
327
328
  """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
328
329
 
329
330
  This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
@@ -344,7 +345,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
344
345
 
345
346
  async def main():
346
347
  nodes = []
347
- with agent.iter('What is the capital of France?') as agent_run:
348
+ async with agent.iter('What is the capital of France?') as agent_run:
348
349
  async for node in agent_run:
349
350
  nodes.append(node)
350
351
  print(nodes)
@@ -362,7 +363,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
362
363
  kind='request',
363
364
  )
364
365
  ),
365
- HandleResponseNode(
366
+ CallToolsNode(
366
367
  model_response=ModelResponse(
367
368
  parts=[TextPart(content='Paris', part_kind='text')],
368
369
  model_name='gpt-4o',
@@ -370,7 +371,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
370
371
  kind='response',
371
372
  )
372
373
  ),
373
- End(data=FinalResult(data='Paris', tool_name=None)),
374
+ End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
374
375
  ]
375
376
  '''
376
377
  print(agent_run.result.data)
@@ -454,7 +455,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
454
455
  system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
455
456
  )
456
457
 
457
- with graph.iter(
458
+ async with graph.iter(
458
459
  start_node,
459
460
  state=state,
460
461
  deps=graph_deps,
@@ -633,7 +634,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
633
634
  self._infer_name(frame.f_back)
634
635
 
635
636
  yielded = False
636
- with self.iter(
637
+ async with self.iter(
637
638
  user_prompt,
638
639
  result_type=result_type,
639
640
  message_history=message_history,
@@ -646,10 +647,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
646
647
  ) as agent_run:
647
648
  first_node = agent_run.next_node # start with the first node
648
649
  assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
649
- node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node)
650
+ node = first_node
650
651
  while True:
651
- if isinstance(node, _agent_graph.ModelRequestNode):
652
- node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node)
652
+ if self.is_model_request_node(node):
653
653
  graph_ctx = agent_run.ctx
654
654
  async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage]
655
655
 
@@ -662,11 +662,10 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
662
662
  new_part = maybe_part_event.part
663
663
  if isinstance(new_part, _messages.TextPart):
664
664
  if _agent_graph.allow_text_result(result_schema):
665
- return FinalResult(s, None)
666
- elif isinstance(new_part, _messages.ToolCallPart):
667
- if result_schema is not None and (match := result_schema.find_tool([new_part])):
668
- call, _ = match
669
- return FinalResult(s, call.tool_name)
665
+ return FinalResult(s, None, None)
666
+ elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
667
+ for call, _ in result_schema.find_tool([new_part]):
668
+ return FinalResult(s, call.tool_name, call.tool_call_id)
670
669
  return None
671
670
 
672
671
  final_result_details = await stream_to_final(streamed_response)
@@ -693,6 +692,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
693
692
  async for _event in _agent_graph.process_function_tools(
694
693
  tool_calls,
695
694
  final_result_details.tool_name,
695
+ final_result_details.tool_call_id,
696
696
  graph_ctx,
697
697
  parts,
698
698
  ):
@@ -717,9 +717,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
717
717
  )
718
718
  break
719
719
  next_node = await agent_run.next(node)
720
- if not isinstance(next_node, BaseNode):
720
+ if not isinstance(next_node, _agent_graph.AgentNode):
721
721
  raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
722
- node = cast(BaseNode[Any, Any, Any], next_node)
722
+ node = cast(_agent_graph.AgentNode[Any, Any], next_node)
723
723
 
724
724
  if not yielded:
725
725
  raise exceptions.AgentRunError('Agent run finished without producing a final result')
@@ -1116,6 +1116,9 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1116
1116
  else:
1117
1117
  raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
1118
1118
 
1119
+ if not isinstance(model_, InstrumentedModel):
1120
+ model_ = InstrumentedModel(model_)
1121
+
1119
1122
  return model_
1120
1123
 
1121
1124
  def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
@@ -1173,12 +1176,52 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1173
1176
  else:
1174
1177
  return self._result_schema # pyright: ignore[reportReturnType]
1175
1178
 
1179
+ @staticmethod
1180
+ def is_model_request_node(
1181
+ node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
1182
+ ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]:
1183
+ """Check if the node is a `ModelRequestNode`, narrowing the type if it is.
1184
+
1185
+ This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
1186
+ """
1187
+ return isinstance(node, _agent_graph.ModelRequestNode)
1188
+
1189
+ @staticmethod
1190
+ def is_call_tools_node(
1191
+ node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
1192
+ ) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]:
1193
+ """Check if the node is a `CallToolsNode`, narrowing the type if it is.
1194
+
1195
+ This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
1196
+ """
1197
+ return isinstance(node, _agent_graph.CallToolsNode)
1198
+
1199
+ @staticmethod
1200
+ def is_user_prompt_node(
1201
+ node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
1202
+ ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]:
1203
+ """Check if the node is a `UserPromptNode`, narrowing the type if it is.
1204
+
1205
+ This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
1206
+ """
1207
+ return isinstance(node, _agent_graph.UserPromptNode)
1208
+
1209
+ @staticmethod
1210
+ def is_end_node(
1211
+ node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
1212
+ ) -> TypeGuard[End[result.FinalResult[S]]]:
1213
+ """Check if the node is a `End`, narrowing the type if it is.
1214
+
1215
+ This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
1216
+ """
1217
+ return isinstance(node, End)
1218
+
1176
1219
 
1177
1220
  @dataclasses.dataclass(repr=False)
1178
1221
  class AgentRun(Generic[AgentDepsT, ResultDataT]):
1179
1222
  """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].
1180
1223
 
1181
- You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`.
1224
+ You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.
1182
1225
 
1183
1226
  Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
1184
1227
  [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
@@ -1193,7 +1236,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1193
1236
  async def main():
1194
1237
  nodes = []
1195
1238
  # Iterate through the run, recording each node along the way:
1196
- with agent.iter('What is the capital of France?') as agent_run:
1239
+ async with agent.iter('What is the capital of France?') as agent_run:
1197
1240
  async for node in agent_run:
1198
1241
  nodes.append(node)
1199
1242
  print(nodes)
@@ -1211,7 +1254,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1211
1254
  kind='request',
1212
1255
  )
1213
1256
  ),
1214
- HandleResponseNode(
1257
+ CallToolsNode(
1215
1258
  model_response=ModelResponse(
1216
1259
  parts=[TextPart(content='Paris', part_kind='text')],
1217
1260
  model_name='gpt-4o',
@@ -1219,7 +1262,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1219
1262
  kind='response',
1220
1263
  )
1221
1264
  ),
1222
- End(data=FinalResult(data='Paris', tool_name=None)),
1265
+ End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
1223
1266
  ]
1224
1267
  '''
1225
1268
  print(agent_run.result.data)
@@ -1244,15 +1287,17 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1244
1287
  @property
1245
1288
  def next_node(
1246
1289
  self,
1247
- ) -> (
1248
- BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]]
1249
- | End[FinalResult[ResultDataT]]
1250
- ):
1290
+ ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
1251
1291
  """The next node that will be run in the agent graph.
1252
1292
 
1253
1293
  This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
1254
1294
  """
1255
- return self._graph_run.next_node
1295
+ next_node = self._graph_run.next_node
1296
+ if isinstance(next_node, End):
1297
+ return next_node
1298
+ if _agent_graph.is_agent_node(next_node):
1299
+ return next_node
1300
+ raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
1256
1301
 
1257
1302
  @property
1258
1303
  def result(self) -> AgentRunResult[ResultDataT] | None:
@@ -1273,45 +1318,24 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1273
1318
 
1274
1319
  def __aiter__(
1275
1320
  self,
1276
- ) -> AsyncIterator[
1277
- BaseNode[
1278
- _agent_graph.GraphAgentState,
1279
- _agent_graph.GraphAgentDeps[AgentDepsT, Any],
1280
- FinalResult[ResultDataT],
1281
- ]
1282
- | End[FinalResult[ResultDataT]]
1283
- ]:
1321
+ ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]:
1284
1322
  """Provide async-iteration over the nodes in the agent run."""
1285
1323
  return self
1286
1324
 
1287
1325
  async def __anext__(
1288
1326
  self,
1289
- ) -> (
1290
- BaseNode[
1291
- _agent_graph.GraphAgentState,
1292
- _agent_graph.GraphAgentDeps[AgentDepsT, Any],
1293
- FinalResult[ResultDataT],
1294
- ]
1295
- | End[FinalResult[ResultDataT]]
1296
- ):
1327
+ ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
1297
1328
  """Advance to the next node automatically based on the last returned node."""
1298
- return await self._graph_run.__anext__()
1329
+ next_node = await self._graph_run.__anext__()
1330
+ if _agent_graph.is_agent_node(next_node):
1331
+ return next_node
1332
+ assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
1333
+ return next_node
1299
1334
 
1300
1335
  async def next(
1301
1336
  self,
1302
- node: BaseNode[
1303
- _agent_graph.GraphAgentState,
1304
- _agent_graph.GraphAgentDeps[AgentDepsT, Any],
1305
- FinalResult[ResultDataT],
1306
- ],
1307
- ) -> (
1308
- BaseNode[
1309
- _agent_graph.GraphAgentState,
1310
- _agent_graph.GraphAgentDeps[AgentDepsT, Any],
1311
- FinalResult[ResultDataT],
1312
- ]
1313
- | End[FinalResult[ResultDataT]]
1314
- ):
1337
+ node: _agent_graph.AgentNode[AgentDepsT, ResultDataT],
1338
+ ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]:
1315
1339
  """Manually drive the agent run by passing in the node you want to run next.
1316
1340
 
1317
1341
  This lets you inspect or mutate the node before continuing execution, or skip certain nodes
@@ -1326,7 +1350,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1326
1350
  agent = Agent('openai:gpt-4o')
1327
1351
 
1328
1352
  async def main():
1329
- with agent.iter('What is the capital of France?') as agent_run:
1353
+ async with agent.iter('What is the capital of France?') as agent_run:
1330
1354
  next_node = agent_run.next_node # start with the first node
1331
1355
  nodes = [next_node]
1332
1356
  while not isinstance(next_node, End):
@@ -1354,7 +1378,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1354
1378
  kind='request',
1355
1379
  )
1356
1380
  ),
1357
- HandleResponseNode(
1381
+ CallToolsNode(
1358
1382
  model_response=ModelResponse(
1359
1383
  parts=[TextPart(content='Paris', part_kind='text')],
1360
1384
  model_name='gpt-4o',
@@ -1362,7 +1386,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1362
1386
  kind='response',
1363
1387
  )
1364
1388
  ),
1365
- End(data=FinalResult(data='Paris', tool_name=None)),
1389
+ End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
1366
1390
  ]
1367
1391
  '''
1368
1392
  print('Final result:', agent_run.result.data)
@@ -1378,7 +1402,11 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
1378
1402
  """
1379
1403
  # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it
1380
1404
  # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate.
1381
- return await self._graph_run.next(node)
1405
+ next_node = await self._graph_run.next(node)
1406
+ if _agent_graph.is_agent_node(next_node):
1407
+ return next_node
1408
+ assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
1409
+ return next_node
1382
1410
 
1383
1411
  def usage(self) -> _usage.Usage:
1384
1412
  """Get usage statistics for the run so far, including token usage, model requests, and so on."""
pydantic_ai/messages.py CHANGED
@@ -8,6 +8,7 @@ from typing import Annotated, Any, Literal, Union, cast, overload
8
8
 
9
9
  import pydantic
10
10
  import pydantic_core
11
+ from opentelemetry._events import Event
11
12
  from typing_extensions import TypeAlias
12
13
 
13
14
  from ._utils import now_utc as _now_utc
@@ -33,6 +34,9 @@ class SystemPromptPart:
33
34
  part_kind: Literal['system-prompt'] = 'system-prompt'
34
35
  """Part type identifier, this is available on all parts as a discriminator."""
35
36
 
37
+ def otel_event(self) -> Event:
38
+ return Event('gen_ai.system.message', body={'content': self.content, 'role': 'system'})
39
+
36
40
 
37
41
  @dataclass
38
42
  class AudioUrl:
@@ -138,6 +142,14 @@ class UserPromptPart:
138
142
  part_kind: Literal['user-prompt'] = 'user-prompt'
139
143
  """Part type identifier, this is available on all parts as a discriminator."""
140
144
 
145
+ def otel_event(self) -> Event:
146
+ if isinstance(self.content, str):
147
+ content = self.content
148
+ else:
149
+ # TODO figure out what to record for images and audio
150
+ content = [part if isinstance(part, str) else {'kind': part.kind} for part in self.content]
151
+ return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
152
+
141
153
 
142
154
  tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
143
155
 
@@ -176,6 +188,9 @@ class ToolReturnPart:
176
188
  else:
177
189
  return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
178
190
 
191
+ def otel_event(self) -> Event:
192
+ return Event('gen_ai.tool.message', body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id})
193
+
179
194
 
180
195
  error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
181
196
 
@@ -224,6 +239,14 @@ class RetryPromptPart:
224
239
  description = f'{len(self.content)} validation errors: {json_errors.decode()}'
225
240
  return f'{description}\n\nFix the errors and try again.'
226
241
 
242
+ def otel_event(self) -> Event:
243
+ if self.tool_name is None:
244
+ return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
245
+ else:
246
+ return Event(
247
+ 'gen_ai.tool.message', body={'content': self.model_response(), 'role': 'tool', 'id': self.tool_call_id}
248
+ )
249
+
227
250
 
228
251
  ModelRequestPart = Annotated[
229
252
  Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
@@ -329,6 +352,36 @@ class ModelResponse:
329
352
  kind: Literal['response'] = 'response'
330
353
  """Message type identifier, this is available on all parts as a discriminator."""
331
354
 
355
+ def otel_events(self) -> list[Event]:
356
+ """Return OpenTelemetry events for the response."""
357
+ result: list[Event] = []
358
+
359
+ def new_event_body():
360
+ new_body: dict[str, Any] = {'role': 'assistant'}
361
+ ev = Event('gen_ai.assistant.message', body=new_body)
362
+ result.append(ev)
363
+ return new_body
364
+
365
+ body = new_event_body()
366
+ for part in self.parts:
367
+ if isinstance(part, ToolCallPart):
368
+ body.setdefault('tool_calls', []).append(
369
+ {
370
+ 'id': part.tool_call_id,
371
+ 'type': 'function', # TODO https://github.com/pydantic/pydantic-ai/issues/888
372
+ 'function': {
373
+ 'name': part.tool_name,
374
+ 'arguments': part.args,
375
+ },
376
+ }
377
+ )
378
+ elif isinstance(part, TextPart):
379
+ if body.get('content'):
380
+ body = new_event_body()
381
+ body['content'] = part.content
382
+
383
+ return result
384
+
332
385
 
333
386
  ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
334
387
  """Any message sent to or returned by a model."""
@@ -533,9 +586,26 @@ class PartDeltaEvent:
533
586
  """Event type identifier, used as a discriminator."""
534
587
 
535
588
 
589
+ @dataclass
590
+ class FinalResultEvent:
591
+ """An event indicating the response to the current model request matches the result schema."""
592
+
593
+ tool_name: str | None
594
+ """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
595
+ tool_call_id: str | None
596
+ """The tool call ID, if any, that this result is associated with."""
597
+ event_kind: Literal['final_result'] = 'final_result'
598
+ """Event type identifier, used as a discriminator."""
599
+
600
+
536
601
  ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
537
602
  """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
538
603
 
604
+ AgentStreamEvent = Annotated[
605
+ Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
606
+ ]
607
+ """An event in the agent stream."""
608
+
539
609
 
540
610
  @dataclass
541
611
  class FunctionToolCallEvent:
@@ -558,7 +628,7 @@ class FunctionToolResultEvent:
558
628
 
559
629
  result: ToolReturnPart | RetryPromptPart
560
630
  """The result of the call to the function tool."""
561
- call_id: str
631
+ tool_call_id: str
562
632
  """An ID used to match the result to its original call."""
563
633
  event_kind: Literal['function_tool_result'] = 'function_tool_result'
564
634
  """Event type identifier, used as a discriminator."""
@@ -84,6 +84,8 @@ KnownModelName = Literal[
84
84
  'gpt-4-turbo-2024-04-09',
85
85
  'gpt-4-turbo-preview',
86
86
  'gpt-4-vision-preview',
87
+ 'gpt-4.5-preview',
88
+ 'gpt-4.5-preview-2025-02-27',
87
89
  'gpt-4o',
88
90
  'gpt-4o-2024-05-13',
89
91
  'gpt-4o-2024-08-06',
@@ -138,6 +140,8 @@ KnownModelName = Literal[
138
140
  'openai:gpt-4-turbo-2024-04-09',
139
141
  'openai:gpt-4-turbo-preview',
140
142
  'openai:gpt-4-vision-preview',
143
+ 'openai:gpt-4.5-preview',
144
+ 'openai:gpt-4.5-preview-2025-02-27',
141
145
  'openai:gpt-4o',
142
146
  'openai:gpt-4o-2024-05-13',
143
147
  'openai:gpt-4o-2024-08-06',
@@ -177,6 +177,8 @@ class DeltaToolCall:
177
177
  """Incremental change to the name of the tool."""
178
178
  json_args: str | None = None
179
179
  """Incremental change to the arguments as JSON"""
180
+ tool_call_id: str | None = None
181
+ """Incremental change to the tool call ID."""
180
182
 
181
183
 
182
184
  DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
@@ -224,7 +226,7 @@ class FunctionStreamedResponse(StreamedResponse):
224
226
  vendor_part_id=dtc_index,
225
227
  tool_name=delta_tool_call.name,
226
228
  args=delta_tool_call.json_args,
227
- tool_call_id=None,
229
+ tool_call_id=delta_tool_call.tool_call_id,
228
230
  )
229
231
  if maybe_event is not None:
230
232
  yield maybe_event
@@ -280,7 +282,16 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
280
282
  return 0
281
283
  if isinstance(content, str):
282
284
  return len(re.split(r'[\s",.:]+', content.strip()))
283
- # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
284
285
  else: # pragma: no cover
285
- assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
286
- return 0
286
+ tokens = 0
287
+ for part in content:
288
+ if isinstance(part, str):
289
+ tokens += len(re.split(r'[\s",.:]+', part.strip()))
290
+ # TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
291
+ if isinstance(part, (AudioUrl, ImageUrl)):
292
+ tokens += 0
293
+ elif isinstance(part, BinaryContent):
294
+ tokens += len(part.data)
295
+ else:
296
+ tokens += 0
297
+ return tokens