pydantic-ai-slim 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.
pydantic_ai/agent.py CHANGED
@@ -26,6 +26,7 @@ from .result import ResultData
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
+    DocstringFormat,
     RunContext,
     Tool,
     ToolDefinition,
@@ -242,9 +243,10 @@ class Agent(Generic[AgentDeps, ResultData]):

         agent = Agent('openai:gpt-4o')

-        result_sync = agent.run_sync('What is the capital of Italy?')
-        print(result_sync.data)
-        #> Rome
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
         ```

         Args:
@@ -382,10 +384,9 @@ class Agent(Generic[AgentDeps, ResultData]):

         agent = Agent('openai:gpt-4o')

-        async def main():
-            result = await agent.run('What is the capital of France?')
-            print(result.data)
-            #> Paris
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
         ```

         Args:
@@ -535,7 +536,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 model_req_span.__exit__(None, None, None)

                 with _logfire.span('handle model response') as handle_span:
-                    maybe_final_result = await self._handle_streamed_model_response(
+                    maybe_final_result = await self._handle_streamed_response(
                         model_response, run_context, result_schema
                     )

@@ -774,6 +775,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...

     def tool(
@@ -783,6 +786,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Any:
         """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

@@ -820,6 +825,9 @@ class Agent(Generic[AgentDeps, ResultData]):
             prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                 tool from a given step. This is useful if you want to customise a tool at call time,
                 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
+                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
+            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
         """
         if func is None:

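The two new decorator options control how a tool's docstring is parsed into its schema. A minimal sketch of how they might be used, assuming `'google'` is one of the accepted `DocstringFormat` values; the `lookup_order` tool is made up for illustration:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('openai:gpt-4o', deps_type=str)


# Hypothetical tool: docstring_format='google' forces Google-style parsing instead
# of the default 'auto' inference, and require_parameter_descriptions=True should
# raise at registration time if any parameter lacks a description.
@agent.tool(docstring_format='google', require_parameter_descriptions=True)
async def lookup_order(ctx: RunContext[str], order_id: str) -> str:
    """Look up an order for the current customer.

    Args:
        order_id: The identifier of the order to look up.
    """
    return f'order {order_id} for customer {ctx.deps}'
```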
@@ -827,13 +835,13 @@ class Agent(Generic[AgentDeps, ResultData]):
                 func_: ToolFuncContext[AgentDeps, ToolParams],
             ) -> ToolFuncContext[AgentDeps, ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, True, retries, prepare)
+                self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
                 return func_

             return tool_decorator
         else:
             # noinspection PyTypeChecker
-            self._register_function(func, True, retries, prepare)
+            self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
             return func

     @overload
@@ -846,6 +854,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

     def tool_plain(
@@ -855,6 +865,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

@@ -892,17 +904,22 @@ class Agent(Generic[AgentDeps, ResultData]):
             prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                 tool from a given step. This is useful if you want to customise a tool at call time,
                 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
+                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
+            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
         """
         if func is None:

             def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, False, retries, prepare)
+                self._register_function(
+                    func_, False, retries, prepare, docstring_format, require_parameter_descriptions
+                )
                 return func_

             return tool_decorator
         else:
-            self._register_function(func, False, retries, prepare)
+            self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
             return func

     def _register_function(
@@ -911,10 +928,19 @@ class Agent(Generic[AgentDeps, ResultData]):
         takes_ctx: bool,
         retries: int | None,
         prepare: ToolPrepareFunc[AgentDeps] | None,
+        docstring_format: DocstringFormat,
+        require_parameter_descriptions: bool,
     ) -> None:
         """Private utility to register a function as a tool."""
         retries_ = retries if retries is not None else self._default_retries
-        tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
+        tool = Tool(
+            func,
+            takes_ctx=takes_ctx,
+            max_retries=retries_,
+            prepare=prepare,
+            docstring_format=docstring_format,
+            require_parameter_descriptions=require_parameter_descriptions,
+        )
         self._register_tool(tool)

     def _register_tool(self, tool: Tool[AgentDeps]) -> None:
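`_register_function` simply forwards the new options to `Tool`, so the same knobs should be available when constructing a `Tool` directly. A rough sketch, assuming `Tool` is importable from the package root and using an invented `roll_die` function:

```python
import random

from pydantic_ai import Tool


def roll_die(sides: int = 6) -> int:
    """Roll a die.

    Args:
        sides: Number of sides on the die.
    """
    return random.randint(1, sides)


# Per the hunk above, Tool accepts docstring_format and
# require_parameter_descriptions alongside takes_ctx, max_retries and prepare.
die_tool = Tool(
    roll_die,
    takes_ctx=False,
    docstring_format='auto',
    require_parameter_descriptions=True,
)
```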
@@ -1100,7 +1126,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         final_result: _MarkFinalResult[RunResultData] | None = None

         parts: list[_messages.ModelRequestPart] = []
-        if result_schema := result_schema:
+        if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
                 call, result_tool = match
                 try:
@@ -1179,76 +1205,58 @@ class Agent(Generic[AgentDeps, ResultData]):
             parts.extend(task_results)
         return parts

-    async def _handle_streamed_model_response(
+    async def _handle_streamed_response(
         self,
-        model_response: models.EitherStreamedResponse,
+        streamed_response: models.StreamedResponse,
         run_context: RunContext[AgentDeps],
         result_schema: _result.ResultSchema[RunResultData] | None,
-    ) -> (
-        _MarkFinalResult[models.EitherStreamedResponse]
-        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
-    ):
+    ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
         """Process a streamed response from the model.

         Returns:
             Either a final result or a tuple of the model response and the tool responses for the next request.
             If a final result is returned, the conversation should end.
         """
-        if isinstance(model_response, models.StreamTextResponse):
-            # plain string response
-            if self._allow_text_result(result_schema):
-                return _MarkFinalResult(model_response, None)
-            else:
-                self._incr_result_retry(run_context)
-                response = _messages.RetryPromptPart(
-                    content='Plain text responses are not permitted, please call one of the functions instead.',
-                )
-                # stream the response, so usage is correct
-                async for _ in model_response:
-                    pass
-
-                text = ''.join(model_response.get(final=True))
-                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
-        elif isinstance(model_response, models.StreamStructuredResponse):
-            if result_schema is not None:
-                # if there's a result schema, iterate over the stream until we find at least one tool
-                # NOTE: this means we ignore any other tools called here
-                structured_msg = model_response.get()
-                while not structured_msg.parts:
-                    try:
-                        await model_response.__anext__()
-                    except StopAsyncIteration:
-                        break
-                    structured_msg = model_response.get()
-
-                if match := result_schema.find_tool(structured_msg.parts):
-                    call, _ = match
-                    return _MarkFinalResult(model_response, call.tool_name)
-
-            # the model is calling a tool function, consume the response to get the next message
-            async for _ in model_response:
-                pass
-            model_response_msg = model_response.get()
-            if not model_response_msg.parts:
-                raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-
-            # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-            parts: list[_messages.ModelRequestPart] = []
-            for item in model_response_msg.parts:
-                if isinstance(item, _messages.ToolCallPart):
-                    call = item
-                    if tool := self._function_tools.get(call.tool_name):
-                        tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-                    else:
-                        parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))
+        received_text = False
+
+        async for maybe_part_event in streamed_response:
+            if isinstance(maybe_part_event, _messages.PartStartEvent):
+                new_part = maybe_part_event.part
+                if isinstance(new_part, _messages.TextPart):
+                    received_text = True
+                    if self._allow_text_result(result_schema):
+                        return _MarkFinalResult(streamed_response, None)
+                elif isinstance(new_part, _messages.ToolCallPart):
+                    if result_schema is not None and (match := result_schema.find_tool([new_part])):
+                        call, _ = match
+                        return _MarkFinalResult(streamed_response, call.tool_name)
+                else:
+                    assert_never(new_part)

-            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                parts.extend(task_results)
-            return model_response_msg, parts
-        else:
-            assert_never(model_response)
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+        parts: list[_messages.ModelRequestPart] = []
+        model_response = streamed_response.get()
+        if not model_response.parts:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+        for p in model_response.parts:
+            if isinstance(p, _messages.ToolCallPart):
+                if tool := self._function_tools.get(p.tool_name):
+                    tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
+                else:
+                    parts.append(self._unknown_tool(p.tool_name, run_context, result_schema))
+
+        if received_text and not tasks and not parts:
+            # Can only get here if self._allow_text_result returns `False` for the provided result_schema
+            self._incr_result_retry(run_context)
+            model_response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return streamed_response.get(), [model_response]
+
+        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+            task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+            parts.extend(task_results)
+        return model_response, parts

     async def _validate_result(
         self,
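The rewritten method no longer branches on `StreamTextResponse` vs `StreamStructuredResponse`; it consumes a single stream of part events and classifies each part as it starts. A simplified, self-contained sketch of that classification, using only the event types added to `pydantic_ai.messages` in this release (the `events` list itself is invented):

```python
from pydantic_ai import messages as _messages

# Invented sample stream: a text part followed by a tool call part.
events: list[_messages.ModelResponseStreamEvent] = [
    _messages.PartStartEvent(index=0, part=_messages.TextPart('Thinking...')),
    _messages.PartDeltaEvent(index=0, delta=_messages.TextPartDelta(' still thinking')),
    _messages.PartStartEvent(
        index=1,
        part=_messages.ToolCallPart.from_raw_args('get_weather', '{"city": "London"}'),
    ),
]

received_text = False
tool_calls: list[_messages.ToolCallPart] = []
for event in events:
    # Mirror the new logic: only PartStartEvent decides what kind of part arrived;
    # PartDeltaEvent merely extends a part that has already started.
    if isinstance(event, _messages.PartStartEvent):
        if isinstance(event.part, _messages.TextPart):
            received_text = True
        elif isinstance(event.part, _messages.ToolCallPart):
            tool_calls.append(event.part)

print(received_text, [call.tool_name for call in tool_calls])
#> True ['get_weather']
```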
@@ -37,7 +37,8 @@ def format_as_xml(
         none_str: String to use for `None` values.
         indent: Indentation string to use for pretty printing.

-    Returns: XML representation of the object.
+    Returns:
+        XML representation of the object.

     Example:
     ```python {title="format_as_xml_example.py" lint="skip"}
pydantic_ai/messages.py CHANGED
@@ -1,14 +1,15 @@
 from __future__ import annotations as _annotations

-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
-from typing import Annotated, Any, Literal, Union, cast
+from typing import Annotated, Any, Literal, Union, cast, overload

 import pydantic
 import pydantic_core
 from typing_extensions import Self, assert_never

 from ._utils import now_utc as _now_utc
+from .exceptions import UnexpectedModelBehavior


 @dataclass
@@ -72,12 +73,14 @@ class ToolReturnPart:
     """Part type identifier, this is available on all parts as a discriminator."""

     def model_response_str(self) -> str:
+        """Return a string representation of the content for the model."""
         if isinstance(self.content, str):
             return self.content
         else:
             return tool_return_ta.dump_json(self.content).decode()

     def model_response_object(self) -> dict[str, Any]:
+        """Return a dictionary representation of the content, wrapping non-dict types appropriately."""
         # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
         if isinstance(self.content, dict):
             return tool_return_ta.dump_python(self.content, mode='json')  # pyright: ignore[reportUnknownMemberType]
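A tiny sketch of the behaviour the new docstrings describe, assuming `ToolReturnPart` can be constructed from just `tool_name` and `content` (the weather values are invented):

```python
from pydantic_ai.messages import ToolReturnPart

# String content is passed back to the model unchanged...
print(ToolReturnPart(tool_name='get_weather', content='sunny').model_response_str())
#> sunny

# ...anything else is JSON-serialised.
print(ToolReturnPart(tool_name='get_weather', content={'temp_c': 21}).model_response_str())
#> {"temp_c":21}
```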
@@ -124,6 +127,7 @@ class RetryPromptPart:
     """Part type identifier, this is available on all parts as a discriminator."""

     def model_response(self) -> str:
+        """Return a string message describing why the retry is requested."""
         if isinstance(self.content, str):
             description = self.content
         else:
@@ -159,6 +163,10 @@ class TextPart:
     part_kind: Literal['text'] = 'text'
     """Part type identifier, this is available on all parts as a discriminator."""

+    def has_content(self) -> bool:
+        """Return `True` if the text content is non-empty."""
+        return bool(self.content)
+

 @dataclass
 class ArgsJson:
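Per the added body, the new helper simply reports whether the text is non-empty:

```python
from pydantic_ai.messages import TextPart

print(TextPart('Hello').has_content(), TextPart('').has_content())
#> True False
```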
@@ -197,7 +205,7 @@ class ToolCallPart:

     @classmethod
     def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
-        """Create a `ToolCallPart` from raw arguments."""
+        """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
         if isinstance(args, str):
             return cls(tool_name, ArgsJson(args), tool_call_id)
         elif isinstance(args, dict):
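A short sketch of the conversion the updated docstring spells out; the string branch is shown in the hunk, and the dict branch producing `ArgsDict` is inferred from the rest of this diff:

```python
from pydantic_ai.messages import ToolCallPart

json_call = ToolCallPart.from_raw_args('get_weather', '{"city": "Paris"}')
dict_call = ToolCallPart.from_raw_args('get_weather', {'city': 'Paris'})

print(type(json_call.args).__name__, type(dict_call.args).__name__)
#> ArgsJson ArgsDict
```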
@@ -226,6 +234,7 @@ class ToolCallPart:
             return pydantic_core.to_json(self.args.args_dict).decode()

     def has_content(self) -> bool:
+        """Return `True` if the arguments contain any data."""
         if isinstance(self.args, ArgsDict):
             return any(self.args.args_dict.values())
         else:
@@ -254,17 +263,217 @@ class ModelResponse:

     @classmethod
     def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
-        return cls([TextPart(content)], timestamp=timestamp or _now_utc())
+        """Create a `ModelResponse` containing a single `TextPart`."""
+        return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())

     @classmethod
     def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
+        """Create a `ModelResponse` containing a single `ToolCallPart`."""
         return cls([tool_call])


-ModelMessage = Union[ModelRequest, ModelResponse]
-"""Any message send to or returned by a model."""
+ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
+"""Any message sent to or returned by a model."""

-ModelMessagesTypeAdapter = pydantic.TypeAdapter(
-    list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
-)
+ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True))
 """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
+
+
+@dataclass
+class TextPartDelta:
+    """A partial update (delta) for a `TextPart` to append new text content."""
+
+    content_delta: str
+    """The incremental text content to add to the existing `TextPart` content."""
+
+    part_delta_kind: Literal['text'] = 'text'
+    """Part delta type identifier, used as a discriminator."""
+
+    def apply(self, part: ModelResponsePart) -> TextPart:
+        """Apply this text delta to an existing `TextPart`.
+
+        Args:
+            part: The existing model response part, which must be a `TextPart`.
+
+        Returns:
+            A new `TextPart` with updated text content.
+
+        Raises:
+            ValueError: If `part` is not a `TextPart`.
+        """
+        if not isinstance(part, TextPart):
+            raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
+        return replace(part, content=part.content + self.content_delta)
+
+
+@dataclass
+class ToolCallPartDelta:
+    """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
+
+    tool_name_delta: str | None = None
+    """Incremental text to add to the existing tool name, if any."""
+
+    args_delta: str | dict[str, Any] | None = None
+    """Incremental data to add to the tool arguments.
+
+    If this is a string, it will be appended to existing JSON arguments.
+    If this is a dict, it will be merged with existing dict arguments.
+    """
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI.
+
+    Note this is never treated as a delta — it can replace None, but otherwise if a
+    non-matching value is provided an error will be raised."""
+
+    part_delta_kind: Literal['tool_call'] = 'tool_call'
+    """Part delta type identifier, used as a discriminator."""
+
+    def as_part(self) -> ToolCallPart | None:
+        """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
+
+        Returns:
+            A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
+        """
+        if self.tool_name_delta is None or self.args_delta is None:
+            return None
+
+        return ToolCallPart.from_raw_args(
+            self.tool_name_delta,
+            self.args_delta,
+            self.tool_call_id,
+        )
+
+    @overload
+    def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
+
+    @overload
+    def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...
+
+    def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
+        """Apply this delta to a part or delta, returning a new part or delta with the changes applied.
+
+        Args:
+            part: The existing model response part or delta to update.
+
+        Returns:
+            Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
+
+        Raises:
+            ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
+            UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
+        """
+        if isinstance(part, ToolCallPart):
+            return self._apply_to_part(part)
+
+        if isinstance(part, ToolCallPartDelta):
+            return self._apply_to_delta(part)
+
+        raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
+
+    def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
+        """Internal helper to apply this delta to another delta."""
+        if self.tool_name_delta:
+            # Append incremental text to the existing tool_name_delta
+            updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
+            delta = replace(delta, tool_name_delta=updated_tool_name_delta)
+
+        if isinstance(self.args_delta, str):
+            if isinstance(delta.args_delta, dict):
+                raise UnexpectedModelBehavior(
+                    f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
+                )
+            updated_args_delta = (delta.args_delta or '') + self.args_delta
+            delta = replace(delta, args_delta=updated_args_delta)
+        elif isinstance(self.args_delta, dict):
+            if isinstance(delta.args_delta, str):
+                raise UnexpectedModelBehavior(
+                    f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
+                )
+            updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
+            delta = replace(delta, args_delta=updated_args_delta)
+
+        if self.tool_call_id:
+            # Set the tool_call_id if it wasn't present, otherwise error if it has changed
+            if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
+                raise UnexpectedModelBehavior(
+                    f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
+                )
+            delta = replace(delta, tool_call_id=self.tool_call_id)
+
+        # If we now have enough data to create a full ToolCallPart, do so
+        if delta.tool_name_delta is not None and delta.args_delta is not None:
+            return ToolCallPart.from_raw_args(
+                delta.tool_name_delta,
+                delta.args_delta,
+                delta.tool_call_id,
+            )
+
+        return delta
+
+    def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
+        """Internal helper to apply this delta directly to a `ToolCallPart`."""
+        if self.tool_name_delta:
+            # Append incremental text to the existing tool_name
+            tool_name = part.tool_name + self.tool_name_delta
+            part = replace(part, tool_name=tool_name)
+
+        if isinstance(self.args_delta, str):
+            if not isinstance(part.args, ArgsJson):
+                raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
+            updated_json = part.args.args_json + self.args_delta
+            part = replace(part, args=ArgsJson(updated_json))
+        elif isinstance(self.args_delta, dict):
+            if not isinstance(part.args, ArgsDict):
+                raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
+            updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
+            part = replace(part, args=ArgsDict(updated_dict))
+
+        if self.tool_call_id:
+            # Replace the tool_call_id entirely if given
+            if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
+                raise UnexpectedModelBehavior(
+                    f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
+                )
+            part = replace(part, tool_call_id=self.tool_call_id)
+        return part
+
+
+ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
+"""A partial update (delta) for any model response part."""
+
+
+@dataclass
+class PartStartEvent:
+    """An event indicating that a new part has started.
+
+    If multiple `PartStartEvent`s are received with the same index,
+    the new one should fully replace the old one.
+    """
+
+    index: int
+    """The index of the part within the overall response parts list."""
+
+    part: ModelResponsePart
+    """The newly started `ModelResponsePart`."""
+
+    event_kind: Literal['part_start'] = 'part_start'
+    """Event type identifier, used as a discriminator."""
+
+
+@dataclass
+class PartDeltaEvent:
+    """An event indicating a delta update for an existing part."""
+
+    index: int
+    """The index of the part within the overall response parts list."""
+
+    delta: ModelResponsePartDelta
+    """The delta to apply to the specified part."""
+
+    event_kind: Literal['part_delta'] = 'part_delta'
+    """Event type identifier, used as a discriminator."""
+
+
+ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
+"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""