pydantic-ai-slim 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_parts_manager.py +31 -5
- pydantic_ai/ag_ui.py +68 -78
- pydantic_ai/mcp.py +79 -19
- pydantic_ai/messages.py +74 -16
- pydantic_ai/models/__init__.py +11 -0
- pydantic_ai/models/anthropic.py +11 -3
- pydantic_ai/models/bedrock.py +4 -2
- pydantic_ai/models/cohere.py +6 -6
- pydantic_ai/models/function.py +4 -2
- pydantic_ai/models/gemini.py +5 -1
- pydantic_ai/models/google.py +9 -2
- pydantic_ai/models/groq.py +6 -2
- pydantic_ai/models/huggingface.py +6 -2
- pydantic_ai/models/mistral.py +3 -1
- pydantic_ai/models/openai.py +34 -7
- pydantic_ai/models/test.py +6 -2
- pydantic_ai/profiles/openai.py +8 -0
- pydantic_ai/providers/__init__.py +8 -0
- pydantic_ai/providers/moonshotai.py +97 -0
- pydantic_ai/providers/vercel.py +107 -0
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.7.dist-info}/METADATA +6 -6
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.7.dist-info}/RECORD +25 -23
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.7.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.7.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.6.dist-info → pydantic_ai_slim-0.4.7.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/_parts_manager.py
CHANGED
|
@@ -17,6 +17,7 @@ from collections.abc import Hashable
|
|
|
17
17
|
from dataclasses import dataclass, field, replace
|
|
18
18
|
from typing import Any, Union
|
|
19
19
|
|
|
20
|
+
from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
|
|
20
21
|
from pydantic_ai.exceptions import UnexpectedModelBehavior
|
|
21
22
|
from pydantic_ai.messages import (
|
|
22
23
|
ModelResponsePart,
|
|
@@ -69,9 +70,10 @@ class ModelResponsePartsManager:
|
|
|
69
70
|
def handle_text_delta(
|
|
70
71
|
self,
|
|
71
72
|
*,
|
|
72
|
-
vendor_part_id:
|
|
73
|
+
vendor_part_id: VendorId | None,
|
|
73
74
|
content: str,
|
|
74
|
-
|
|
75
|
+
extract_think_tags: bool = False,
|
|
76
|
+
) -> ModelResponseStreamEvent | None:
|
|
75
77
|
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.
|
|
76
78
|
|
|
77
79
|
When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
|
|
@@ -83,9 +85,12 @@ class ModelResponsePartsManager:
|
|
|
83
85
|
of text. If None, a new part will be created unless the latest part is already
|
|
84
86
|
a TextPart.
|
|
85
87
|
content: The text content to append to the appropriate TextPart.
|
|
88
|
+
extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
|
|
86
89
|
|
|
87
90
|
Returns:
|
|
88
|
-
A `PartStartEvent` if a new part was created
|
|
91
|
+
- A `PartStartEvent` if a new part was created.
|
|
92
|
+
- A `PartDeltaEvent` if an existing part was updated.
|
|
93
|
+
- `None` if no new event is emitted (e.g., the first text part was all whitespace).
|
|
89
94
|
|
|
90
95
|
Raises:
|
|
91
96
|
UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
|
|
@@ -104,11 +109,32 @@ class ModelResponsePartsManager:
|
|
|
104
109
|
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
|
|
105
110
|
if part_index is not None:
|
|
106
111
|
existing_part = self._parts[part_index]
|
|
107
|
-
|
|
112
|
+
|
|
113
|
+
if extract_think_tags and isinstance(existing_part, ThinkingPart):
|
|
114
|
+
# We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
|
|
115
|
+
if content == END_THINK_TAG:
|
|
116
|
+
# When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
|
|
117
|
+
self._vendor_id_to_part_index.pop(vendor_part_id)
|
|
118
|
+
return None
|
|
119
|
+
else:
|
|
120
|
+
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
|
|
121
|
+
elif isinstance(existing_part, TextPart):
|
|
122
|
+
existing_text_part_and_index = existing_part, part_index
|
|
123
|
+
else:
|
|
108
124
|
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
|
|
109
|
-
|
|
125
|
+
|
|
126
|
+
if extract_think_tags and content == START_THINK_TAG:
|
|
127
|
+
# When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
|
|
128
|
+
self._vendor_id_to_part_index.pop(vendor_part_id, None)
|
|
129
|
+
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
|
|
110
130
|
|
|
111
131
|
if existing_text_part_and_index is None:
|
|
132
|
+
# If the first text delta is all whitespace, don't emit a new part yet.
|
|
133
|
+
# This is a workaround for models that emit `<think>\n</think>\n\n` ahead of tool calls (e.g. Ollama + Qwen3),
|
|
134
|
+
# which we don't want to end up treating as a final result.
|
|
135
|
+
if content.isspace():
|
|
136
|
+
return None
|
|
137
|
+
|
|
112
138
|
# There is no existing text part that should be updated, so create a new one
|
|
113
139
|
new_part_index = len(self._parts)
|
|
114
140
|
part = TextPart(content=content)
|
pydantic_ai/ag_ui.py
CHANGED
|
@@ -291,12 +291,12 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
291
291
|
if isinstance(deps, StateHandler):
|
|
292
292
|
deps.state = run_input.state
|
|
293
293
|
|
|
294
|
-
|
|
294
|
+
messages = _messages_from_ag_ui(run_input.messages)
|
|
295
295
|
|
|
296
296
|
async with self.agent.iter(
|
|
297
297
|
user_prompt=None,
|
|
298
298
|
output_type=[output_type or self.agent.output_type, DeferredToolCalls],
|
|
299
|
-
message_history=
|
|
299
|
+
message_history=messages,
|
|
300
300
|
model=model,
|
|
301
301
|
deps=deps,
|
|
302
302
|
model_settings=model_settings,
|
|
@@ -305,7 +305,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
305
305
|
infer_name=infer_name,
|
|
306
306
|
toolsets=toolsets,
|
|
307
307
|
) as run:
|
|
308
|
-
async for event in self._agent_stream(run
|
|
308
|
+
async for event in self._agent_stream(run):
|
|
309
309
|
yield encoder.encode(event)
|
|
310
310
|
except _RunError as e:
|
|
311
311
|
yield encoder.encode(
|
|
@@ -327,20 +327,18 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
327
327
|
async def _agent_stream(
|
|
328
328
|
self,
|
|
329
329
|
run: AgentRun[AgentDepsT, Any],
|
|
330
|
-
history: _History,
|
|
331
330
|
) -> AsyncGenerator[BaseEvent, None]:
|
|
332
331
|
"""Run the agent streaming responses using AG-UI protocol events.
|
|
333
332
|
|
|
334
333
|
Args:
|
|
335
334
|
run: The agent run to process.
|
|
336
|
-
history: The history of messages and tool calls to use for the run.
|
|
337
335
|
|
|
338
336
|
Yields:
|
|
339
337
|
AG-UI Server-Sent Events (SSE).
|
|
340
338
|
"""
|
|
341
339
|
async for node in run:
|
|
340
|
+
stream_ctx = _RequestStreamContext()
|
|
342
341
|
if isinstance(node, ModelRequestNode):
|
|
343
|
-
stream_ctx = _RequestStreamContext()
|
|
344
342
|
async with node.stream(run.ctx) as request_stream:
|
|
345
343
|
async for agent_event in request_stream:
|
|
346
344
|
async for msg in self._handle_model_request_event(stream_ctx, agent_event):
|
|
@@ -352,8 +350,8 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
352
350
|
elif isinstance(node, CallToolsNode):
|
|
353
351
|
async with node.stream(run.ctx) as handle_stream:
|
|
354
352
|
async for event in handle_stream:
|
|
355
|
-
if isinstance(event, FunctionToolResultEvent)
|
|
356
|
-
async for msg in self._handle_tool_result_event(
|
|
353
|
+
if isinstance(event, FunctionToolResultEvent):
|
|
354
|
+
async for msg in self._handle_tool_result_event(stream_ctx, event):
|
|
357
355
|
yield msg
|
|
358
356
|
|
|
359
357
|
async def _handle_model_request_event(
|
|
@@ -382,19 +380,26 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
382
380
|
yield TextMessageStartEvent(
|
|
383
381
|
message_id=message_id,
|
|
384
382
|
)
|
|
385
|
-
stream_ctx.part_end = TextMessageEndEvent(
|
|
386
|
-
message_id=message_id,
|
|
387
|
-
)
|
|
388
383
|
if part.content: # pragma: no branch
|
|
389
384
|
yield TextMessageContentEvent(
|
|
390
385
|
message_id=message_id,
|
|
391
386
|
delta=part.content,
|
|
392
387
|
)
|
|
388
|
+
stream_ctx.part_end = TextMessageEndEvent(
|
|
389
|
+
message_id=message_id,
|
|
390
|
+
)
|
|
393
391
|
elif isinstance(part, ToolCallPart): # pragma: no branch
|
|
392
|
+
message_id = stream_ctx.message_id or stream_ctx.new_message_id()
|
|
394
393
|
yield ToolCallStartEvent(
|
|
395
394
|
tool_call_id=part.tool_call_id,
|
|
396
395
|
tool_call_name=part.tool_name,
|
|
396
|
+
parent_message_id=message_id,
|
|
397
397
|
)
|
|
398
|
+
if part.args:
|
|
399
|
+
yield ToolCallArgsEvent(
|
|
400
|
+
tool_call_id=part.tool_call_id,
|
|
401
|
+
delta=part.args if isinstance(part.args, str) else json.dumps(part.args),
|
|
402
|
+
)
|
|
398
403
|
stream_ctx.part_end = ToolCallEndEvent(
|
|
399
404
|
tool_call_id=part.tool_call_id,
|
|
400
405
|
)
|
|
@@ -407,7 +412,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
407
412
|
# used to indicate the start of thinking.
|
|
408
413
|
yield ThinkingTextMessageContentEvent(
|
|
409
414
|
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
|
|
410
|
-
delta=part.content
|
|
415
|
+
delta=part.content,
|
|
411
416
|
)
|
|
412
417
|
stream_ctx.part_end = ThinkingTextMessageEndEvent(
|
|
413
418
|
type=EventType.THINKING_TEXT_MESSAGE_END,
|
|
@@ -435,20 +440,25 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
435
440
|
|
|
436
441
|
async def _handle_tool_result_event(
|
|
437
442
|
self,
|
|
438
|
-
|
|
439
|
-
|
|
443
|
+
stream_ctx: _RequestStreamContext,
|
|
444
|
+
event: FunctionToolResultEvent,
|
|
440
445
|
) -> AsyncGenerator[BaseEvent, None]:
|
|
441
446
|
"""Convert a tool call result to AG-UI events.
|
|
442
447
|
|
|
443
448
|
Args:
|
|
444
|
-
|
|
445
|
-
|
|
449
|
+
stream_ctx: The request stream context to manage state.
|
|
450
|
+
event: The tool call result event to process.
|
|
446
451
|
|
|
447
452
|
Yields:
|
|
448
453
|
AG-UI Server-Sent Events (SSE).
|
|
449
454
|
"""
|
|
455
|
+
result = event.result
|
|
456
|
+
if not isinstance(result, ToolReturnPart):
|
|
457
|
+
return
|
|
458
|
+
|
|
459
|
+
message_id = stream_ctx.new_message_id()
|
|
450
460
|
yield ToolCallResultEvent(
|
|
451
|
-
message_id=
|
|
461
|
+
message_id=message_id,
|
|
452
462
|
type=EventType.TOOL_CALL_RESULT,
|
|
453
463
|
role='tool',
|
|
454
464
|
tool_call_id=result.tool_call_id,
|
|
@@ -468,75 +478,55 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
|
|
|
468
478
|
yield item
|
|
469
479
|
|
|
470
480
|
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
Args:
|
|
483
|
-
messages: List of AG-UI messages to convert.
|
|
484
|
-
|
|
485
|
-
Returns:
|
|
486
|
-
List of Pydantic AI model messages.
|
|
487
|
-
"""
|
|
488
|
-
prompt_message_id = ''
|
|
489
|
-
result: list[ModelMessage] = []
|
|
490
|
-
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
|
|
491
|
-
for msg in messages:
|
|
492
|
-
if isinstance(msg, UserMessage):
|
|
493
|
-
prompt_message_id = msg.id
|
|
494
|
-
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
|
|
495
|
-
elif isinstance(msg, AssistantMessage):
|
|
496
|
-
if msg.tool_calls:
|
|
497
|
-
for tool_call in msg.tool_calls:
|
|
498
|
-
tool_calls[tool_call.id] = tool_call.function.name
|
|
499
|
-
|
|
500
|
-
result.append(
|
|
501
|
-
ModelResponse(
|
|
502
|
-
parts=[
|
|
503
|
-
ToolCallPart(
|
|
504
|
-
tool_name=tool_call.function.name,
|
|
505
|
-
tool_call_id=tool_call.id,
|
|
506
|
-
args=tool_call.function.arguments,
|
|
507
|
-
)
|
|
508
|
-
for tool_call in msg.tool_calls
|
|
509
|
-
]
|
|
510
|
-
)
|
|
511
|
-
)
|
|
512
|
-
|
|
513
|
-
if msg.content:
|
|
514
|
-
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
|
|
515
|
-
elif isinstance(msg, SystemMessage):
|
|
516
|
-
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
|
|
517
|
-
elif isinstance(msg, ToolMessage):
|
|
518
|
-
tool_name = tool_calls.get(msg.tool_call_id)
|
|
519
|
-
if tool_name is None: # pragma: no cover
|
|
520
|
-
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
|
|
481
|
+
def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
|
|
482
|
+
"""Convert a AG-UI history to a Pydantic AI one."""
|
|
483
|
+
result: list[ModelMessage] = []
|
|
484
|
+
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
|
|
485
|
+
for msg in messages:
|
|
486
|
+
if isinstance(msg, UserMessage):
|
|
487
|
+
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
|
|
488
|
+
elif isinstance(msg, AssistantMessage):
|
|
489
|
+
if msg.tool_calls:
|
|
490
|
+
for tool_call in msg.tool_calls:
|
|
491
|
+
tool_calls[tool_call.id] = tool_call.function.name
|
|
521
492
|
|
|
522
493
|
result.append(
|
|
523
|
-
|
|
494
|
+
ModelResponse(
|
|
524
495
|
parts=[
|
|
525
|
-
|
|
526
|
-
tool_name=
|
|
527
|
-
|
|
528
|
-
|
|
496
|
+
ToolCallPart(
|
|
497
|
+
tool_name=tool_call.function.name,
|
|
498
|
+
tool_call_id=tool_call.id,
|
|
499
|
+
args=tool_call.function.arguments,
|
|
529
500
|
)
|
|
501
|
+
for tool_call in msg.tool_calls
|
|
530
502
|
]
|
|
531
503
|
)
|
|
532
504
|
)
|
|
533
|
-
elif isinstance(msg, DeveloperMessage): # pragma: no branch
|
|
534
|
-
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
|
|
535
505
|
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
506
|
+
if msg.content:
|
|
507
|
+
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
|
|
508
|
+
elif isinstance(msg, SystemMessage):
|
|
509
|
+
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
|
|
510
|
+
elif isinstance(msg, ToolMessage):
|
|
511
|
+
tool_name = tool_calls.get(msg.tool_call_id)
|
|
512
|
+
if tool_name is None: # pragma: no cover
|
|
513
|
+
raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
|
|
514
|
+
|
|
515
|
+
result.append(
|
|
516
|
+
ModelRequest(
|
|
517
|
+
parts=[
|
|
518
|
+
ToolReturnPart(
|
|
519
|
+
tool_name=tool_name,
|
|
520
|
+
content=msg.content,
|
|
521
|
+
tool_call_id=msg.tool_call_id,
|
|
522
|
+
)
|
|
523
|
+
]
|
|
524
|
+
)
|
|
525
|
+
)
|
|
526
|
+
elif isinstance(msg, DeveloperMessage): # pragma: no branch
|
|
527
|
+
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
|
|
528
|
+
|
|
529
|
+
return result
|
|
540
530
|
|
|
541
531
|
|
|
542
532
|
@runtime_checkable
|
pydantic_ai/mcp.py
CHANGED
|
@@ -2,11 +2,13 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
4
|
import functools
|
|
5
|
+
import warnings
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
6
7
|
from asyncio import Lock
|
|
7
8
|
from collections.abc import AsyncIterator, Awaitable, Sequence
|
|
8
9
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
9
10
|
from dataclasses import dataclass, field, replace
|
|
11
|
+
from datetime import timedelta
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Any, Callable
|
|
12
14
|
|
|
@@ -37,7 +39,7 @@ except ImportError as _import_error:
|
|
|
37
39
|
) from _import_error
|
|
38
40
|
|
|
39
41
|
# after mcp imports so any import error maps to this file, not _mcp.py
|
|
40
|
-
from . import _mcp, exceptions, messages, models
|
|
42
|
+
from . import _mcp, _utils, exceptions, messages, models
|
|
41
43
|
|
|
42
44
|
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
|
|
43
45
|
|
|
@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
59
61
|
log_level: mcp_types.LoggingLevel | None = None
|
|
60
62
|
log_handler: LoggingFnT | None = None
|
|
61
63
|
timeout: float = 5
|
|
64
|
+
read_timeout: float = 5 * 60
|
|
62
65
|
process_tool_call: ProcessToolCallback | None = None
|
|
63
66
|
allow_sampling: bool = True
|
|
64
67
|
max_retries: int = 1
|
|
@@ -148,7 +151,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
148
151
|
except McpError as e:
|
|
149
152
|
raise exceptions.ModelRetry(e.error.message)
|
|
150
153
|
|
|
151
|
-
content = [self._map_tool_result_part(part) for part in result.content]
|
|
154
|
+
content = [await self._map_tool_result_part(part) for part in result.content]
|
|
152
155
|
|
|
153
156
|
if result.isError:
|
|
154
157
|
text = '\n'.join(str(part) for part in content)
|
|
@@ -208,6 +211,7 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
208
211
|
write_stream=self._write_stream,
|
|
209
212
|
sampling_callback=self._sampling_callback if self.allow_sampling else None,
|
|
210
213
|
logging_callback=self.log_handler,
|
|
214
|
+
read_timeout_seconds=timedelta(seconds=self.read_timeout),
|
|
211
215
|
)
|
|
212
216
|
self._client = await self._exit_stack.enter_async_context(client)
|
|
213
217
|
|
|
@@ -258,8 +262,8 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
258
262
|
model=self.sampling_model.model_name,
|
|
259
263
|
)
|
|
260
264
|
|
|
261
|
-
def _map_tool_result_part(
|
|
262
|
-
self, part: mcp_types.
|
|
265
|
+
async def _map_tool_result_part(
|
|
266
|
+
self, part: mcp_types.ContentBlock
|
|
263
267
|
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
|
|
264
268
|
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
|
|
265
269
|
|
|
@@ -281,18 +285,29 @@ class MCPServer(AbstractToolset[Any], ABC):
|
|
|
281
285
|
) # pragma: no cover
|
|
282
286
|
elif isinstance(part, mcp_types.EmbeddedResource):
|
|
283
287
|
resource = part.resource
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
)
|
|
291
|
-
|
|
292
|
-
assert_never(resource)
|
|
288
|
+
return self._get_content(resource)
|
|
289
|
+
elif isinstance(part, mcp_types.ResourceLink):
|
|
290
|
+
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
|
|
291
|
+
return (
|
|
292
|
+
self._get_content(resource_result.contents[0])
|
|
293
|
+
if len(resource_result.contents) == 1
|
|
294
|
+
else [self._get_content(resource) for resource in resource_result.contents]
|
|
295
|
+
)
|
|
293
296
|
else:
|
|
294
297
|
assert_never(part)
|
|
295
298
|
|
|
299
|
+
def _get_content(
|
|
300
|
+
self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
|
|
301
|
+
) -> str | messages.BinaryContent:
|
|
302
|
+
if isinstance(resource, mcp_types.TextResourceContents):
|
|
303
|
+
return resource.text
|
|
304
|
+
elif isinstance(resource, mcp_types.BlobResourceContents):
|
|
305
|
+
return messages.BinaryContent(
|
|
306
|
+
data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
|
|
307
|
+
)
|
|
308
|
+
else:
|
|
309
|
+
assert_never(resource)
|
|
310
|
+
|
|
296
311
|
|
|
297
312
|
@dataclass
|
|
298
313
|
class MCPServerStdio(MCPServer):
|
|
@@ -401,7 +416,7 @@ class MCPServerStdio(MCPServer):
|
|
|
401
416
|
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
|
|
402
417
|
|
|
403
418
|
|
|
404
|
-
@dataclass
|
|
419
|
+
@dataclass(init=False)
|
|
405
420
|
class _MCPServerHTTP(MCPServer):
|
|
406
421
|
url: str
|
|
407
422
|
"""The URL of the endpoint on the MCP server."""
|
|
@@ -438,10 +453,10 @@ class _MCPServerHTTP(MCPServer):
|
|
|
438
453
|
```
|
|
439
454
|
"""
|
|
440
455
|
|
|
441
|
-
|
|
442
|
-
"""Maximum time in seconds to wait for new
|
|
456
|
+
read_timeout: float = 5 * 60
|
|
457
|
+
"""Maximum time in seconds to wait for new messages before timing out.
|
|
443
458
|
|
|
444
|
-
This timeout applies to the long-lived
|
|
459
|
+
This timeout applies to the long-lived connection after it's established.
|
|
445
460
|
If no new messages are received within this time, the connection will be considered stale
|
|
446
461
|
and may be closed. Defaults to 5 minutes (300 seconds).
|
|
447
462
|
"""
|
|
@@ -485,6 +500,51 @@ class _MCPServerHTTP(MCPServer):
|
|
|
485
500
|
sampling_model: models.Model | None = None
|
|
486
501
|
"""The model to use for sampling."""
|
|
487
502
|
|
|
503
|
+
def __init__(
|
|
504
|
+
self,
|
|
505
|
+
*,
|
|
506
|
+
url: str,
|
|
507
|
+
headers: dict[str, str] | None = None,
|
|
508
|
+
http_client: httpx.AsyncClient | None = None,
|
|
509
|
+
read_timeout: float | None = None,
|
|
510
|
+
tool_prefix: str | None = None,
|
|
511
|
+
log_level: mcp_types.LoggingLevel | None = None,
|
|
512
|
+
log_handler: LoggingFnT | None = None,
|
|
513
|
+
timeout: float = 5,
|
|
514
|
+
process_tool_call: ProcessToolCallback | None = None,
|
|
515
|
+
allow_sampling: bool = True,
|
|
516
|
+
max_retries: int = 1,
|
|
517
|
+
sampling_model: models.Model | None = None,
|
|
518
|
+
**kwargs: Any,
|
|
519
|
+
):
|
|
520
|
+
# Handle deprecated sse_read_timeout parameter
|
|
521
|
+
if 'sse_read_timeout' in kwargs:
|
|
522
|
+
if read_timeout is not None:
|
|
523
|
+
raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.")
|
|
524
|
+
|
|
525
|
+
warnings.warn(
|
|
526
|
+
"'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2
|
|
527
|
+
)
|
|
528
|
+
read_timeout = kwargs.pop('sse_read_timeout')
|
|
529
|
+
|
|
530
|
+
_utils.validate_empty_kwargs(kwargs)
|
|
531
|
+
|
|
532
|
+
if read_timeout is None:
|
|
533
|
+
read_timeout = 5 * 60
|
|
534
|
+
|
|
535
|
+
self.url = url
|
|
536
|
+
self.headers = headers
|
|
537
|
+
self.http_client = http_client
|
|
538
|
+
self.tool_prefix = tool_prefix
|
|
539
|
+
self.log_level = log_level
|
|
540
|
+
self.log_handler = log_handler
|
|
541
|
+
self.timeout = timeout
|
|
542
|
+
self.process_tool_call = process_tool_call
|
|
543
|
+
self.allow_sampling = allow_sampling
|
|
544
|
+
self.max_retries = max_retries
|
|
545
|
+
self.sampling_model = sampling_model
|
|
546
|
+
self.read_timeout = read_timeout
|
|
547
|
+
|
|
488
548
|
@property
|
|
489
549
|
@abstractmethod
|
|
490
550
|
def _transport_client(
|
|
@@ -522,7 +582,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
522
582
|
self._transport_client,
|
|
523
583
|
url=self.url,
|
|
524
584
|
timeout=self.timeout,
|
|
525
|
-
sse_read_timeout=self.
|
|
585
|
+
sse_read_timeout=self.read_timeout,
|
|
526
586
|
)
|
|
527
587
|
|
|
528
588
|
if self.http_client is not None:
|
|
@@ -549,7 +609,7 @@ class _MCPServerHTTP(MCPServer):
|
|
|
549
609
|
return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
|
|
550
610
|
|
|
551
611
|
|
|
552
|
-
@dataclass
|
|
612
|
+
@dataclass(init=False)
|
|
553
613
|
class MCPServerSSE(_MCPServerHTTP):
|
|
554
614
|
"""An MCP server that connects over streamable HTTP connections.
|
|
555
615
|
|
pydantic_ai/messages.py
CHANGED
|
@@ -85,7 +85,7 @@ class SystemPromptPart:
|
|
|
85
85
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
86
86
|
|
|
87
87
|
|
|
88
|
-
@dataclass(repr=False)
|
|
88
|
+
@dataclass(init=False, repr=False)
|
|
89
89
|
class FileUrl(ABC):
|
|
90
90
|
"""Abstract base class for any URL-based file."""
|
|
91
91
|
|
|
@@ -106,11 +106,29 @@ class FileUrl(ABC):
|
|
|
106
106
|
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
|
|
107
107
|
"""
|
|
108
108
|
|
|
109
|
-
|
|
109
|
+
_media_type: str | None = field(init=False, repr=False)
|
|
110
|
+
|
|
111
|
+
def __init__(
|
|
112
|
+
self,
|
|
113
|
+
url: str,
|
|
114
|
+
force_download: bool = False,
|
|
115
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
116
|
+
media_type: str | None = None,
|
|
117
|
+
) -> None:
|
|
118
|
+
self.url = url
|
|
119
|
+
self.vendor_metadata = vendor_metadata
|
|
120
|
+
self.force_download = force_download
|
|
121
|
+
self._media_type = media_type
|
|
122
|
+
|
|
110
123
|
@abstractmethod
|
|
111
|
-
def
|
|
124
|
+
def _infer_media_type(self) -> str:
|
|
112
125
|
"""Return the media type of the file, based on the url."""
|
|
113
126
|
|
|
127
|
+
@property
|
|
128
|
+
def media_type(self) -> str:
|
|
129
|
+
"""Return the media type of the file, based on the url or the provided `_media_type`."""
|
|
130
|
+
return self._media_type or self._infer_media_type()
|
|
131
|
+
|
|
114
132
|
@property
|
|
115
133
|
@abstractmethod
|
|
116
134
|
def format(self) -> str:
|
|
@@ -119,7 +137,7 @@ class FileUrl(ABC):
|
|
|
119
137
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
120
138
|
|
|
121
139
|
|
|
122
|
-
@dataclass(repr=False)
|
|
140
|
+
@dataclass(init=False, repr=False)
|
|
123
141
|
class VideoUrl(FileUrl):
|
|
124
142
|
"""A URL to a video."""
|
|
125
143
|
|
|
@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
|
|
|
129
147
|
kind: Literal['video-url'] = 'video-url'
|
|
130
148
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
131
149
|
|
|
132
|
-
|
|
133
|
-
|
|
150
|
+
def __init__(
|
|
151
|
+
self,
|
|
152
|
+
url: str,
|
|
153
|
+
force_download: bool = False,
|
|
154
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
155
|
+
media_type: str | None = None,
|
|
156
|
+
kind: Literal['video-url'] = 'video-url',
|
|
157
|
+
) -> None:
|
|
158
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
159
|
+
self.kind = kind
|
|
160
|
+
|
|
161
|
+
def _infer_media_type(self) -> VideoMediaType:
|
|
134
162
|
"""Return the media type of the video, based on the url."""
|
|
135
163
|
if self.url.endswith('.mkv'):
|
|
136
164
|
return 'video/x-matroska'
|
|
@@ -170,7 +198,7 @@ class VideoUrl(FileUrl):
|
|
|
170
198
|
return _video_format_lookup[self.media_type]
|
|
171
199
|
|
|
172
200
|
|
|
173
|
-
@dataclass(repr=False)
|
|
201
|
+
@dataclass(init=False, repr=False)
|
|
174
202
|
class AudioUrl(FileUrl):
|
|
175
203
|
"""A URL to an audio file."""
|
|
176
204
|
|
|
@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
|
|
|
180
208
|
kind: Literal['audio-url'] = 'audio-url'
|
|
181
209
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
182
210
|
|
|
183
|
-
|
|
184
|
-
|
|
211
|
+
def __init__(
|
|
212
|
+
self,
|
|
213
|
+
url: str,
|
|
214
|
+
force_download: bool = False,
|
|
215
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
216
|
+
media_type: str | None = None,
|
|
217
|
+
kind: Literal['audio-url'] = 'audio-url',
|
|
218
|
+
) -> None:
|
|
219
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
220
|
+
self.kind = kind
|
|
221
|
+
|
|
222
|
+
def _infer_media_type(self) -> AudioMediaType:
|
|
185
223
|
"""Return the media type of the audio file, based on the url.
|
|
186
224
|
|
|
187
225
|
References:
|
|
@@ -208,7 +246,7 @@ class AudioUrl(FileUrl):
|
|
|
208
246
|
return _audio_format_lookup[self.media_type]
|
|
209
247
|
|
|
210
248
|
|
|
211
|
-
@dataclass(repr=False)
|
|
249
|
+
@dataclass(init=False, repr=False)
|
|
212
250
|
class ImageUrl(FileUrl):
|
|
213
251
|
"""A URL to an image."""
|
|
214
252
|
|
|
@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
|
|
|
218
256
|
kind: Literal['image-url'] = 'image-url'
|
|
219
257
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
220
258
|
|
|
221
|
-
|
|
222
|
-
|
|
259
|
+
def __init__(
|
|
260
|
+
self,
|
|
261
|
+
url: str,
|
|
262
|
+
force_download: bool = False,
|
|
263
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
264
|
+
media_type: str | None = None,
|
|
265
|
+
kind: Literal['image-url'] = 'image-url',
|
|
266
|
+
) -> None:
|
|
267
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
268
|
+
self.kind = kind
|
|
269
|
+
|
|
270
|
+
def _infer_media_type(self) -> ImageMediaType:
|
|
223
271
|
"""Return the media type of the image, based on the url."""
|
|
224
272
|
if self.url.endswith(('.jpg', '.jpeg')):
|
|
225
273
|
return 'image/jpeg'
|
|
@@ -241,7 +289,7 @@ class ImageUrl(FileUrl):
|
|
|
241
289
|
return _image_format_lookup[self.media_type]
|
|
242
290
|
|
|
243
291
|
|
|
244
|
-
@dataclass(repr=False)
|
|
292
|
+
@dataclass(init=False, repr=False)
|
|
245
293
|
class DocumentUrl(FileUrl):
|
|
246
294
|
"""The URL of the document."""
|
|
247
295
|
|
|
@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
|
|
|
251
299
|
kind: Literal['document-url'] = 'document-url'
|
|
252
300
|
"""Type identifier, this is available on all parts as a discriminator."""
|
|
253
301
|
|
|
254
|
-
|
|
255
|
-
|
|
302
|
+
def __init__(
|
|
303
|
+
self,
|
|
304
|
+
url: str,
|
|
305
|
+
force_download: bool = False,
|
|
306
|
+
vendor_metadata: dict[str, Any] | None = None,
|
|
307
|
+
media_type: str | None = None,
|
|
308
|
+
kind: Literal['document-url'] = 'document-url',
|
|
309
|
+
) -> None:
|
|
310
|
+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
|
|
311
|
+
self.kind = kind
|
|
312
|
+
|
|
313
|
+
def _infer_media_type(self) -> str:
|
|
256
314
|
"""Return the media type of the document, based on the url."""
|
|
257
315
|
type_, _ = guess_type(self.url)
|
|
258
316
|
if type_ is None:
|
|
@@ -632,7 +690,7 @@ class ThinkingPart:
|
|
|
632
690
|
|
|
633
691
|
def has_content(self) -> bool:
|
|
634
692
|
"""Return `True` if the thinking content is non-empty."""
|
|
635
|
-
return bool(self.content)
|
|
693
|
+
return bool(self.content)
|
|
636
694
|
|
|
637
695
|
__repr__ = _utils.dataclasses_no_defaults_repr
|
|
638
696
|
|