pydantic-ai-slim 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's page on the registry for more details.

@@ -17,6 +17,7 @@ from collections.abc import Hashable
17
17
  from dataclasses import dataclass, field, replace
18
18
  from typing import Any, Union
19
19
 
20
+ from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
20
21
  from pydantic_ai.exceptions import UnexpectedModelBehavior
21
22
  from pydantic_ai.messages import (
22
23
  ModelResponsePart,
@@ -69,9 +70,10 @@ class ModelResponsePartsManager:
69
70
  def handle_text_delta(
70
71
  self,
71
72
  *,
72
- vendor_part_id: Hashable | None,
73
+ vendor_part_id: VendorId | None,
73
74
  content: str,
74
- ) -> ModelResponseStreamEvent:
75
+ extract_think_tags: bool = False,
76
+ ) -> ModelResponseStreamEvent | None:
75
77
  """Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
76
78
 
77
79
  When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
@@ -83,9 +85,12 @@ class ModelResponsePartsManager:
83
85
  of text. If None, a new part will be created unless the latest part is already
84
86
  a TextPart.
85
87
  content: The text content to append to the appropriate TextPart.
88
+ extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
86
89
 
87
90
  Returns:
88
- A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
91
+ - A `PartStartEvent` if a new part was created.
92
+ - A `PartDeltaEvent` if an existing part was updated.
93
+ - `None` if no new event is emitted (e.g., the first text part was all whitespace).
89
94
 
90
95
  Raises:
91
96
  UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
@@ -104,11 +109,32 @@ class ModelResponsePartsManager:
104
109
  part_index = self._vendor_id_to_part_index.get(vendor_part_id)
105
110
  if part_index is not None:
106
111
  existing_part = self._parts[part_index]
107
- if not isinstance(existing_part, TextPart):
112
+
113
+ if extract_think_tags and isinstance(existing_part, ThinkingPart):
114
+ # We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
115
+ if content == END_THINK_TAG:
116
+ # When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
117
+ self._vendor_id_to_part_index.pop(vendor_part_id)
118
+ return None
119
+ else:
120
+ return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
121
+ elif isinstance(existing_part, TextPart):
122
+ existing_text_part_and_index = existing_part, part_index
123
+ else:
108
124
  raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
109
- existing_text_part_and_index = existing_part, part_index
125
+
126
+ if extract_think_tags and content == START_THINK_TAG:
127
+ # When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
128
+ self._vendor_id_to_part_index.pop(vendor_part_id, None)
129
+ return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
110
130
 
111
131
  if existing_text_part_and_index is None:
132
+ # If the first text delta is all whitespace, don't emit a new part yet.
133
+ # This is a workaround for models that emit `<think>\n</think>\n\n` ahead of tool calls (e.g. Ollama + Qwen3),
134
+ # which we don't want to end up treating as a final result.
135
+ if content.isspace():
136
+ return None
137
+
112
138
  # There is no existing text part that should be updated, so create a new one
113
139
  new_part_index = len(self._parts)
114
140
  part = TextPart(content=content)
pydantic_ai/ag_ui.py CHANGED
@@ -291,12 +291,12 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
291
291
  if isinstance(deps, StateHandler):
292
292
  deps.state = run_input.state
293
293
 
294
- history = _History.from_ag_ui(run_input.messages)
294
+ messages = _messages_from_ag_ui(run_input.messages)
295
295
 
296
296
  async with self.agent.iter(
297
297
  user_prompt=None,
298
298
  output_type=[output_type or self.agent.output_type, DeferredToolCalls],
299
- message_history=history.messages,
299
+ message_history=messages,
300
300
  model=model,
301
301
  deps=deps,
302
302
  model_settings=model_settings,
@@ -305,7 +305,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
305
305
  infer_name=infer_name,
306
306
  toolsets=toolsets,
307
307
  ) as run:
308
- async for event in self._agent_stream(run, history):
308
+ async for event in self._agent_stream(run):
309
309
  yield encoder.encode(event)
310
310
  except _RunError as e:
311
311
  yield encoder.encode(
@@ -327,20 +327,18 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
327
327
  async def _agent_stream(
328
328
  self,
329
329
  run: AgentRun[AgentDepsT, Any],
330
- history: _History,
331
330
  ) -> AsyncGenerator[BaseEvent, None]:
332
331
  """Run the agent streaming responses using AG-UI protocol events.
333
332
 
334
333
  Args:
335
334
  run: The agent run to process.
336
- history: The history of messages and tool calls to use for the run.
337
335
 
338
336
  Yields:
339
337
  AG-UI Server-Sent Events (SSE).
340
338
  """
341
339
  async for node in run:
340
+ stream_ctx = _RequestStreamContext()
342
341
  if isinstance(node, ModelRequestNode):
343
- stream_ctx = _RequestStreamContext()
344
342
  async with node.stream(run.ctx) as request_stream:
345
343
  async for agent_event in request_stream:
346
344
  async for msg in self._handle_model_request_event(stream_ctx, agent_event):
@@ -352,8 +350,8 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
352
350
  elif isinstance(node, CallToolsNode):
353
351
  async with node.stream(run.ctx) as handle_stream:
354
352
  async for event in handle_stream:
355
- if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart):
356
- async for msg in self._handle_tool_result_event(event.result, history.prompt_message_id):
353
+ if isinstance(event, FunctionToolResultEvent):
354
+ async for msg in self._handle_tool_result_event(stream_ctx, event):
357
355
  yield msg
358
356
 
359
357
  async def _handle_model_request_event(
@@ -382,19 +380,26 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
382
380
  yield TextMessageStartEvent(
383
381
  message_id=message_id,
384
382
  )
385
- stream_ctx.part_end = TextMessageEndEvent(
386
- message_id=message_id,
387
- )
388
383
  if part.content: # pragma: no branch
389
384
  yield TextMessageContentEvent(
390
385
  message_id=message_id,
391
386
  delta=part.content,
392
387
  )
388
+ stream_ctx.part_end = TextMessageEndEvent(
389
+ message_id=message_id,
390
+ )
393
391
  elif isinstance(part, ToolCallPart): # pragma: no branch
392
+ message_id = stream_ctx.message_id or stream_ctx.new_message_id()
394
393
  yield ToolCallStartEvent(
395
394
  tool_call_id=part.tool_call_id,
396
395
  tool_call_name=part.tool_name,
396
+ parent_message_id=message_id,
397
397
  )
398
+ if part.args:
399
+ yield ToolCallArgsEvent(
400
+ tool_call_id=part.tool_call_id,
401
+ delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
402
+ )
398
403
  stream_ctx.part_end = ToolCallEndEvent(
399
404
  tool_call_id=part.tool_call_id,
400
405
  )
@@ -407,7 +412,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
407
412
  # used to indicate the start of thinking.
408
413
  yield ThinkingTextMessageContentEvent(
409
414
  type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
410
- delta=part.content or '',
415
+ delta=part.content,
411
416
  )
412
417
  stream_ctx.part_end = ThinkingTextMessageEndEvent(
413
418
  type=EventType.THINKING_TEXT_MESSAGE_END,
@@ -435,20 +440,25 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
435
440
 
436
441
  async def _handle_tool_result_event(
437
442
  self,
438
- result: ToolReturnPart,
439
- prompt_message_id: str,
443
+ stream_ctx: _RequestStreamContext,
444
+ event: FunctionToolResultEvent,
440
445
  ) -> AsyncGenerator[BaseEvent, None]:
441
446
  """Convert a tool call result to AG-UI events.
442
447
 
443
448
  Args:
444
- result: The tool call result to process.
445
- prompt_message_id: The message ID of the prompt that initiated the tool call.
449
+ stream_ctx: The request stream context to manage state.
450
+ event: The tool call result event to process.
446
451
 
447
452
  Yields:
448
453
  AG-UI Server-Sent Events (SSE).
449
454
  """
455
+ result = event.result
456
+ if not isinstance(result, ToolReturnPart):
457
+ return
458
+
459
+ message_id = stream_ctx.new_message_id()
450
460
  yield ToolCallResultEvent(
451
- message_id=prompt_message_id,
461
+ message_id=message_id,
452
462
  type=EventType.TOOL_CALL_RESULT,
453
463
  role='tool',
454
464
  tool_call_id=result.tool_call_id,
@@ -468,75 +478,55 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
468
478
  yield item
469
479
 
470
480
 
471
- @dataclass
472
- class _History:
473
- """A simple history representation for AG-UI protocol."""
474
-
475
- prompt_message_id: str # The ID of the last user message.
476
- messages: list[ModelMessage]
477
-
478
- @classmethod
479
- def from_ag_ui(cls, messages: list[Message]) -> _History:
480
- """Convert a AG-UI history to a Pydantic AI one.
481
-
482
- Args:
483
- messages: List of AG-UI messages to convert.
484
-
485
- Returns:
486
- List of Pydantic AI model messages.
487
- """
488
- prompt_message_id = ''
489
- result: list[ModelMessage] = []
490
- tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
491
- for msg in messages:
492
- if isinstance(msg, UserMessage):
493
- prompt_message_id = msg.id
494
- result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
495
- elif isinstance(msg, AssistantMessage):
496
- if msg.tool_calls:
497
- for tool_call in msg.tool_calls:
498
- tool_calls[tool_call.id] = tool_call.function.name
499
-
500
- result.append(
501
- ModelResponse(
502
- parts=[
503
- ToolCallPart(
504
- tool_name=tool_call.function.name,
505
- tool_call_id=tool_call.id,
506
- args=tool_call.function.arguments,
507
- )
508
- for tool_call in msg.tool_calls
509
- ]
510
- )
511
- )
512
-
513
- if msg.content:
514
- result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
515
- elif isinstance(msg, SystemMessage):
516
- result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
517
- elif isinstance(msg, ToolMessage):
518
- tool_name = tool_calls.get(msg.tool_call_id)
519
- if tool_name is None: # pragma: no cover
520
- raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
481
+ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
482
+ """Convert a AG-UI history to a Pydantic AI one."""
483
+ result: list[ModelMessage] = []
484
+ tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
485
+ for msg in messages:
486
+ if isinstance(msg, UserMessage):
487
+ result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
488
+ elif isinstance(msg, AssistantMessage):
489
+ if msg.tool_calls:
490
+ for tool_call in msg.tool_calls:
491
+ tool_calls[tool_call.id] = tool_call.function.name
521
492
 
522
493
  result.append(
523
- ModelRequest(
494
+ ModelResponse(
524
495
  parts=[
525
- ToolReturnPart(
526
- tool_name=tool_name,
527
- content=msg.content,
528
- tool_call_id=msg.tool_call_id,
496
+ ToolCallPart(
497
+ tool_name=tool_call.function.name,
498
+ tool_call_id=tool_call.id,
499
+ args=tool_call.function.arguments,
529
500
  )
501
+ for tool_call in msg.tool_calls
530
502
  ]
531
503
  )
532
504
  )
533
- elif isinstance(msg, DeveloperMessage): # pragma: no branch
534
- result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
535
505
 
536
- return cls(
537
- prompt_message_id=prompt_message_id,
538
- messages=result,
539
- )
506
+ if msg.content:
507
+ result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
508
+ elif isinstance(msg, SystemMessage):
509
+ result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
510
+ elif isinstance(msg, ToolMessage):
511
+ tool_name = tool_calls.get(msg.tool_call_id)
512
+ if tool_name is None: # pragma: no cover
513
+ raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
514
+
515
+ result.append(
516
+ ModelRequest(
517
+ parts=[
518
+ ToolReturnPart(
519
+ tool_name=tool_name,
520
+ content=msg.content,
521
+ tool_call_id=msg.tool_call_id,
522
+ )
523
+ ]
524
+ )
525
+ )
526
+ elif isinstance(msg, DeveloperMessage): # pragma: no branch
527
+ result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
528
+
529
+ return result
540
530
 
541
531
 
542
532
  @runtime_checkable
pydantic_ai/mcp.py CHANGED
@@ -2,11 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  import base64
4
4
  import functools
5
+ import warnings
5
6
  from abc import ABC, abstractmethod
6
7
  from asyncio import Lock
7
8
  from collections.abc import AsyncIterator, Awaitable, Sequence
8
9
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
9
10
  from dataclasses import dataclass, field, replace
11
+ from datetime import timedelta
10
12
  from pathlib import Path
11
13
  from typing import Any, Callable
12
14
 
@@ -37,7 +39,7 @@ except ImportError as _import_error:
37
39
  ) from _import_error
38
40
 
39
41
  # after mcp imports so any import error maps to this file, not _mcp.py
40
- from . import _mcp, exceptions, messages, models
42
+ from . import _mcp, _utils, exceptions, messages, models
41
43
 
42
44
  __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
43
45
 
@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
59
61
  log_level: mcp_types.LoggingLevel | None = None
60
62
  log_handler: LoggingFnT | None = None
61
63
  timeout: float = 5
64
+ read_timeout: float = 5 * 60
62
65
  process_tool_call: ProcessToolCallback | None = None
63
66
  allow_sampling: bool = True
64
67
  max_retries: int = 1
@@ -148,7 +151,7 @@ class MCPServer(AbstractToolset[Any], ABC):
148
151
  except McpError as e:
149
152
  raise exceptions.ModelRetry(e.error.message)
150
153
 
151
- content = [self._map_tool_result_part(part) for part in result.content]
154
+ content = [await self._map_tool_result_part(part) for part in result.content]
152
155
 
153
156
  if result.isError:
154
157
  text = '\n'.join(str(part) for part in content)
@@ -208,6 +211,7 @@ class MCPServer(AbstractToolset[Any], ABC):
208
211
  write_stream=self._write_stream,
209
212
  sampling_callback=self._sampling_callback if self.allow_sampling else None,
210
213
  logging_callback=self.log_handler,
214
+ read_timeout_seconds=timedelta(seconds=self.read_timeout),
211
215
  )
212
216
  self._client = await self._exit_stack.enter_async_context(client)
213
217
 
@@ -258,8 +262,8 @@ class MCPServer(AbstractToolset[Any], ABC):
258
262
  model=self.sampling_model.model_name,
259
263
  )
260
264
 
261
- def _map_tool_result_part(
262
- self, part: mcp_types.Content
265
+ async def _map_tool_result_part(
266
+ self, part: mcp_types.ContentBlock
263
267
  ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
264
268
  # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
265
269
 
@@ -281,18 +285,29 @@ class MCPServer(AbstractToolset[Any], ABC):
281
285
  ) # pragma: no cover
282
286
  elif isinstance(part, mcp_types.EmbeddedResource):
283
287
  resource = part.resource
284
- if isinstance(resource, mcp_types.TextResourceContents):
285
- return resource.text
286
- elif isinstance(resource, mcp_types.BlobResourceContents):
287
- return messages.BinaryContent(
288
- data=base64.b64decode(resource.blob),
289
- media_type=resource.mimeType or 'application/octet-stream',
290
- )
291
- else:
292
- assert_never(resource)
288
+ return self._get_content(resource)
289
+ elif isinstance(part, mcp_types.ResourceLink):
290
+ resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
291
+ return (
292
+ self._get_content(resource_result.contents[0])
293
+ if len(resource_result.contents) == 1
294
+ else [self._get_content(resource) for resource in resource_result.contents]
295
+ )
293
296
  else:
294
297
  assert_never(part)
295
298
 
299
+ def _get_content(
300
+ self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
301
+ ) -> str | messages.BinaryContent:
302
+ if isinstance(resource, mcp_types.TextResourceContents):
303
+ return resource.text
304
+ elif isinstance(resource, mcp_types.BlobResourceContents):
305
+ return messages.BinaryContent(
306
+ data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
307
+ )
308
+ else:
309
+ assert_never(resource)
310
+
296
311
 
297
312
  @dataclass
298
313
  class MCPServerStdio(MCPServer):
@@ -401,7 +416,7 @@ class MCPServerStdio(MCPServer):
401
416
  return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
402
417
 
403
418
 
404
- @dataclass
419
+ @dataclass(init=False)
405
420
  class _MCPServerHTTP(MCPServer):
406
421
  url: str
407
422
  """The URL of the endpoint on the MCP server."""
@@ -438,10 +453,10 @@ class _MCPServerHTTP(MCPServer):
438
453
  ```
439
454
  """
440
455
 
441
- sse_read_timeout: float = 5 * 60
442
- """Maximum time in seconds to wait for new SSE messages before timing out.
456
+ read_timeout: float = 5 * 60
457
+ """Maximum time in seconds to wait for new messages before timing out.
443
458
 
444
- This timeout applies to the long-lived SSE connection after it's established.
459
+ This timeout applies to the long-lived connection after it's established.
445
460
  If no new messages are received within this time, the connection will be considered stale
446
461
  and may be closed. Defaults to 5 minutes (300 seconds).
447
462
  """
@@ -485,6 +500,51 @@ class _MCPServerHTTP(MCPServer):
485
500
  sampling_model: models.Model | None = None
486
501
  """The model to use for sampling."""
487
502
 
503
+ def __init__(
504
+ self,
505
+ *,
506
+ url: str,
507
+ headers: dict[str, str] | None = None,
508
+ http_client: httpx.AsyncClient | None = None,
509
+ read_timeout: float | None = None,
510
+ tool_prefix: str | None = None,
511
+ log_level: mcp_types.LoggingLevel | None = None,
512
+ log_handler: LoggingFnT | None = None,
513
+ timeout: float = 5,
514
+ process_tool_call: ProcessToolCallback | None = None,
515
+ allow_sampling: bool = True,
516
+ max_retries: int = 1,
517
+ sampling_model: models.Model | None = None,
518
+ **kwargs: Any,
519
+ ):
520
+ # Handle deprecated sse_read_timeout parameter
521
+ if 'sse_read_timeout' in kwargs:
522
+ if read_timeout is not None:
523
+ raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
524
+
525
+ warnings.warn(
526
+ "'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
527
+ )
528
+ read_timeout = kwargs.pop('sse_read_timeout')
529
+
530
+ _utils.validate_empty_kwargs(kwargs)
531
+
532
+ if read_timeout is None:
533
+ read_timeout = 5 * 60
534
+
535
+ self.url = url
536
+ self.headers = headers
537
+ self.http_client = http_client
538
+ self.tool_prefix = tool_prefix
539
+ self.log_level = log_level
540
+ self.log_handler = log_handler
541
+ self.timeout = timeout
542
+ self.process_tool_call = process_tool_call
543
+ self.allow_sampling = allow_sampling
544
+ self.max_retries = max_retries
545
+ self.sampling_model = sampling_model
546
+ self.read_timeout = read_timeout
547
+
488
548
  @property
489
549
  @abstractmethod
490
550
  def _transport_client(
@@ -522,7 +582,7 @@ class _MCPServerHTTP(MCPServer):
522
582
  self._transport_client,
523
583
  url=self.url,
524
584
  timeout=self.timeout,
525
- sse_read_timeout=self.sse_read_timeout,
585
+ sse_read_timeout=self.read_timeout,
526
586
  )
527
587
 
528
588
  if self.http_client is not None:
@@ -549,7 +609,7 @@ class _MCPServerHTTP(MCPServer):
549
609
  return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
550
610
 
551
611
 
552
- @dataclass
612
+ @dataclass(init=False)
553
613
  class MCPServerSSE(_MCPServerHTTP):
554
614
  """An MCP server that connects over streamable HTTP connections.
555
615
 
pydantic_ai/messages.py CHANGED
@@ -85,7 +85,7 @@ class SystemPromptPart:
85
85
  __repr__ = _utils.dataclasses_no_defaults_repr
86
86
 
87
87
 
88
- @dataclass(repr=False)
88
+ @dataclass(init=False, repr=False)
89
89
  class FileUrl(ABC):
90
90
  """Abstract base class for any URL-based file."""
91
91
 
@@ -106,11 +106,29 @@ class FileUrl(ABC):
106
106
  - `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
107
107
  """
108
108
 
109
- @property
109
+ _media_type: str | None = field(init=False, repr=False)
110
+
111
+ def __init__(
112
+ self,
113
+ url: str,
114
+ force_download: bool = False,
115
+ vendor_metadata: dict[str, Any] | None = None,
116
+ media_type: str | None = None,
117
+ ) -> None:
118
+ self.url = url
119
+ self.vendor_metadata = vendor_metadata
120
+ self.force_download = force_download
121
+ self._media_type = media_type
122
+
110
123
  @abstractmethod
111
- def media_type(self) -> str:
124
+ def _infer_media_type(self) -> str:
112
125
  """Return the media type of the file, based on the url."""
113
126
 
127
+ @property
128
+ def media_type(self) -> str:
129
+ """Return the media type of the file, based on the url or the provided `_media_type`."""
130
+ return self._media_type or self._infer_media_type()
131
+
114
132
  @property
115
133
  @abstractmethod
116
134
  def format(self) -> str:
@@ -119,7 +137,7 @@ class FileUrl(ABC):
119
137
  __repr__ = _utils.dataclasses_no_defaults_repr
120
138
 
121
139
 
122
- @dataclass(repr=False)
140
+ @dataclass(init=False, repr=False)
123
141
  class VideoUrl(FileUrl):
124
142
  """A URL to a video."""
125
143
 
@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
129
147
  kind: Literal['video-url'] = 'video-url'
130
148
  """Type identifier, this is available on all parts as a discriminator."""
131
149
 
132
- @property
133
- def media_type(self) -> VideoMediaType:
150
+ def __init__(
151
+ self,
152
+ url: str,
153
+ force_download: bool = False,
154
+ vendor_metadata: dict[str, Any] | None = None,
155
+ media_type: str | None = None,
156
+ kind: Literal['video-url'] = 'video-url',
157
+ ) -> None:
158
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
159
+ self.kind = kind
160
+
161
+ def _infer_media_type(self) -> VideoMediaType:
134
162
  """Return the media type of the video, based on the url."""
135
163
  if self.url.endswith('.mkv'):
136
164
  return 'video/x-matroska'
@@ -170,7 +198,7 @@ class VideoUrl(FileUrl):
170
198
  return _video_format_lookup[self.media_type]
171
199
 
172
200
 
173
- @dataclass(repr=False)
201
+ @dataclass(init=False, repr=False)
174
202
  class AudioUrl(FileUrl):
175
203
  """A URL to an audio file."""
176
204
 
@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
180
208
  kind: Literal['audio-url'] = 'audio-url'
181
209
  """Type identifier, this is available on all parts as a discriminator."""
182
210
 
183
- @property
184
- def media_type(self) -> AudioMediaType:
211
+ def __init__(
212
+ self,
213
+ url: str,
214
+ force_download: bool = False,
215
+ vendor_metadata: dict[str, Any] | None = None,
216
+ media_type: str | None = None,
217
+ kind: Literal['audio-url'] = 'audio-url',
218
+ ) -> None:
219
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
220
+ self.kind = kind
221
+
222
+ def _infer_media_type(self) -> AudioMediaType:
185
223
  """Return the media type of the audio file, based on the url.
186
224
 
187
225
  References:
@@ -208,7 +246,7 @@ class AudioUrl(FileUrl):
208
246
  return _audio_format_lookup[self.media_type]
209
247
 
210
248
 
211
- @dataclass(repr=False)
249
+ @dataclass(init=False, repr=False)
212
250
  class ImageUrl(FileUrl):
213
251
  """A URL to an image."""
214
252
 
@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
218
256
  kind: Literal['image-url'] = 'image-url'
219
257
  """Type identifier, this is available on all parts as a discriminator."""
220
258
 
221
- @property
222
- def media_type(self) -> ImageMediaType:
259
+ def __init__(
260
+ self,
261
+ url: str,
262
+ force_download: bool = False,
263
+ vendor_metadata: dict[str, Any] | None = None,
264
+ media_type: str | None = None,
265
+ kind: Literal['image-url'] = 'image-url',
266
+ ) -> None:
267
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
268
+ self.kind = kind
269
+
270
+ def _infer_media_type(self) -> ImageMediaType:
223
271
  """Return the media type of the image, based on the url."""
224
272
  if self.url.endswith(('.jpg', '.jpeg')):
225
273
  return 'image/jpeg'
@@ -241,7 +289,7 @@ class ImageUrl(FileUrl):
241
289
  return _image_format_lookup[self.media_type]
242
290
 
243
291
 
244
- @dataclass(repr=False)
292
+ @dataclass(init=False, repr=False)
245
293
  class DocumentUrl(FileUrl):
246
294
  """The URL of the document."""
247
295
 
@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
251
299
  kind: Literal['document-url'] = 'document-url'
252
300
  """Type identifier, this is available on all parts as a discriminator."""
253
301
 
254
- @property
255
- def media_type(self) -> str:
302
+ def __init__(
303
+ self,
304
+ url: str,
305
+ force_download: bool = False,
306
+ vendor_metadata: dict[str, Any] | None = None,
307
+ media_type: str | None = None,
308
+ kind: Literal['document-url'] = 'document-url',
309
+ ) -> None:
310
+ super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
311
+ self.kind = kind
312
+
313
+ def _infer_media_type(self) -> str:
256
314
  """Return the media type of the document, based on the url."""
257
315
  type_, _ = guess_type(self.url)
258
316
  if type_ is None:
@@ -632,7 +690,7 @@ class ThinkingPart:
632
690
 
633
691
  def has_content(self) -> bool:
634
692
  """Return `True` if the thinking content is non-empty."""
635
- return bool(self.content) # pragma: no cover
693
+ return bool(self.content)
636
694
 
637
695
  __repr__ = _utils.dataclasses_no_defaults_repr
638
696