pydantic-ai-slim: 0.3.1 (py3-none-any.whl) → 0.3.3 (py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/mcp.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import base64
4
4
  import functools
5
- import json
6
5
  from abc import ABC, abstractmethod
7
6
  from collections.abc import AsyncIterator, Awaitable, Sequence
8
7
  from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
@@ -13,41 +12,28 @@ from typing import Any, Callable
13
12
 
14
13
  import anyio
15
14
  import httpx
15
+ import pydantic_core
16
16
  from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
17
- from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
18
- from mcp.shared.exceptions import McpError
19
- from mcp.shared.message import SessionMessage
20
- from mcp.types import (
21
- AudioContent,
22
- BlobResourceContents,
23
- CallToolRequest,
24
- CallToolRequestParams,
25
- CallToolResult,
26
- ClientRequest,
27
- Content,
28
- EmbeddedResource,
29
- ImageContent,
30
- LoggingLevel,
31
- RequestParams,
32
- TextContent,
33
- TextResourceContents,
34
- )
35
17
  from typing_extensions import Self, assert_never, deprecated
36
18
 
37
- from pydantic_ai.exceptions import ModelRetry
38
- from pydantic_ai.messages import BinaryContent
39
- from pydantic_ai.tools import RunContext, ToolDefinition
40
-
41
19
  try:
42
- from mcp.client.session import ClientSession
20
+ from mcp import types as mcp_types
21
+ from mcp.client.session import ClientSession, LoggingFnT
43
22
  from mcp.client.sse import sse_client
44
23
  from mcp.client.stdio import StdioServerParameters, stdio_client
24
+ from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
25
+ from mcp.shared.context import RequestContext
26
+ from mcp.shared.exceptions import McpError
27
+ from mcp.shared.message import SessionMessage
45
28
  except ImportError as _import_error:
46
29
  raise ImportError(
47
30
  'Please install the `mcp` package to use the MCP server, '
48
31
  'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
49
32
  ) from _import_error
50
33
 
34
+ # after mcp imports so any import error maps to this file, not _mcp.py
35
+ from . import _mcp, exceptions, messages, models, tools
36
+
51
37
  __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
52
38
 
53
39
 
@@ -57,22 +43,22 @@ class MCPServer(ABC):
57
43
  See <https://modelcontextprotocol.io> for more information.
58
44
  """
59
45
 
60
- is_running: bool = False
46
+ # these fields should be re-defined by dataclass subclasses so they appear as fields {
61
47
  tool_prefix: str | None = None
62
- """A prefix to add to all tools that are registered with the server.
63
-
64
- If not empty, will include a trailing underscore(`_`).
65
-
66
- e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
67
- """
68
-
48
+ log_level: mcp_types.LoggingLevel | None = None
49
+ log_handler: LoggingFnT | None = None
50
+ timeout: float = 5
69
51
  process_tool_call: ProcessToolCallback | None = None
70
- """Hook to customize tool calling and optionally pass extra metadata."""
52
+ allow_sampling: bool = True
53
+ # } end of "abstract fields"
54
+
55
+ _running_count: int = 0
71
56
 
72
57
  _client: ClientSession
73
58
  _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
74
59
  _write_stream: MemoryObjectSendStream[SessionMessage]
75
60
  _exit_stack: AsyncExitStack
61
+ sampling_model: models.Model | None = None
76
62
 
77
63
  @abstractmethod
78
64
  @asynccontextmanager
@@ -88,14 +74,6 @@ class MCPServer(ABC):
88
74
  raise NotImplementedError('MCP Server subclasses must implement this method.')
89
75
  yield
90
76
 
91
- @abstractmethod
92
- def _get_log_level(self) -> LoggingLevel | None:
93
- """Get the log level for the MCP server."""
94
- raise NotImplementedError('MCP Server subclasses must implement this method.')
95
-
96
- def _get_client_initialize_timeout(self) -> float:
97
- return 5 # pragma: no cover
98
-
99
77
  def get_prefixed_tool_name(self, tool_name: str) -> str:
100
78
  """Get the tool name with prefix if `tool_prefix` is set."""
101
79
  return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
@@ -104,21 +82,26 @@ class MCPServer(ABC):
104
82
  """Get original tool name without prefix for calling tools."""
105
83
  return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name
106
84
 
107
- async def list_tools(self) -> list[ToolDefinition]:
85
+ @property
86
+ def is_running(self) -> bool:
87
+ """Check if the MCP server is running."""
88
+ return bool(self._running_count)
89
+
90
+ async def list_tools(self) -> list[tools.ToolDefinition]:
108
91
  """Retrieve tools that are currently active on the server.
109
92
 
110
93
  Note:
111
94
  - We don't cache tools as they might change.
112
95
  - We also don't subscribe to the server to avoid complexity.
113
96
  """
114
- tools = await self._client.list_tools()
97
+ mcp_tools = await self._client.list_tools()
115
98
  return [
116
- ToolDefinition(
99
+ tools.ToolDefinition(
117
100
  name=self.get_prefixed_tool_name(tool.name),
118
101
  description=tool.description or '',
119
102
  parameters_json_schema=tool.inputSchema,
120
103
  )
121
- for tool in tools.tools
104
+ for tool in mcp_tools.tools
122
105
  ]
123
106
 
124
107
  async def call_tool(
@@ -143,44 +126,48 @@ class MCPServer(ABC):
143
126
  try:
144
127
  # meta param is not provided by session yet, so build and can send_request directly.
145
128
  result = await self._client.send_request(
146
- ClientRequest(
147
- CallToolRequest(
129
+ mcp_types.ClientRequest(
130
+ mcp_types.CallToolRequest(
148
131
  method='tools/call',
149
- params=CallToolRequestParams(
132
+ params=mcp_types.CallToolRequestParams(
150
133
  name=self.get_unprefixed_tool_name(tool_name),
151
134
  arguments=arguments,
152
- _meta=RequestParams.Meta(**metadata) if metadata else None,
135
+ _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
153
136
  ),
154
137
  )
155
138
  ),
156
- CallToolResult,
139
+ mcp_types.CallToolResult,
157
140
  )
158
141
  except McpError as e:
159
- raise ModelRetry(e.error.message)
142
+ raise exceptions.ModelRetry(e.error.message)
160
143
 
161
144
  content = [self._map_tool_result_part(part) for part in result.content]
162
145
 
163
146
  if result.isError:
164
147
  text = '\n'.join(str(part) for part in content)
165
- raise ModelRetry(text)
166
-
167
- if len(content) == 1:
168
- return content[0]
169
- return content
148
+ raise exceptions.ModelRetry(text)
149
+ else:
150
+ return content[0] if len(content) == 1 else content
170
151
 
171
152
  async def __aenter__(self) -> Self:
172
- self._exit_stack = AsyncExitStack()
173
-
174
- self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
175
- client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
176
- self._client = await self._exit_stack.enter_async_context(client)
153
+ if self._running_count == 0:
154
+ self._exit_stack = AsyncExitStack()
155
+
156
+ self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
157
+ client = ClientSession(
158
+ read_stream=self._read_stream,
159
+ write_stream=self._write_stream,
160
+ sampling_callback=self._sampling_callback if self.allow_sampling else None,
161
+ logging_callback=self.log_handler,
162
+ )
163
+ self._client = await self._exit_stack.enter_async_context(client)
177
164
 
178
- with anyio.fail_after(self._get_client_initialize_timeout()):
179
- await self._client.initialize()
165
+ with anyio.fail_after(self.timeout):
166
+ await self._client.initialize()
180
167
 
181
- if log_level := self._get_log_level():
182
- await self._client.set_logging_level(log_level)
183
- self.is_running = True
168
+ if log_level := self.log_level:
169
+ await self._client.set_logging_level(log_level)
170
+ self._running_count += 1
184
171
  return self
185
172
 
186
173
  async def __aexit__(
@@ -189,32 +176,64 @@ class MCPServer(ABC):
189
176
  exc_value: BaseException | None,
190
177
  traceback: TracebackType | None,
191
178
  ) -> bool | None:
192
- await self._exit_stack.aclose()
193
- self.is_running = False
179
+ self._running_count -= 1
180
+ if self._running_count <= 0:
181
+ await self._exit_stack.aclose()
182
+
183
+ async def _sampling_callback(
184
+ self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
185
+ ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
186
+ """MCP sampling callback."""
187
+ if self.sampling_model is None:
188
+ raise ValueError('Sampling model is not set') # pragma: no cover
189
+
190
+ pai_messages = _mcp.map_from_mcp_params(params)
191
+ model_settings = models.ModelSettings()
192
+ if max_tokens := params.maxTokens: # pragma: no branch
193
+ model_settings['max_tokens'] = max_tokens
194
+ if temperature := params.temperature: # pragma: no branch
195
+ model_settings['temperature'] = temperature
196
+ if stop_sequences := params.stopSequences: # pragma: no branch
197
+ model_settings['stop_sequences'] = stop_sequences
198
+
199
+ model_response = await self.sampling_model.request(
200
+ pai_messages,
201
+ model_settings,
202
+ models.ModelRequestParameters(),
203
+ )
204
+ return mcp_types.CreateMessageResult(
205
+ role='assistant',
206
+ content=_mcp.map_from_model_response(model_response),
207
+ model=self.sampling_model.model_name,
208
+ )
194
209
 
195
- def _map_tool_result_part(self, part: Content) -> str | BinaryContent | dict[str, Any] | list[Any]:
210
+ def _map_tool_result_part(
211
+ self, part: mcp_types.Content
212
+ ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
196
213
  # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
197
214
 
198
- if isinstance(part, TextContent):
215
+ if isinstance(part, mcp_types.TextContent):
199
216
  text = part.text
200
217
  if text.startswith(('[', '{')):
201
218
  try:
202
- return json.loads(text)
219
+ return pydantic_core.from_json(text)
203
220
  except ValueError:
204
221
  pass
205
222
  return text
206
- elif isinstance(part, ImageContent):
207
- return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
208
- elif isinstance(part, AudioContent):
223
+ elif isinstance(part, mcp_types.ImageContent):
224
+ return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
225
+ elif isinstance(part, mcp_types.AudioContent):
209
226
  # NOTE: The FastMCP server doesn't support audio content.
210
227
  # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
211
- return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) # pragma: no cover
212
- elif isinstance(part, EmbeddedResource):
228
+ return messages.BinaryContent(
229
+ data=base64.b64decode(part.data), media_type=part.mimeType
230
+ ) # pragma: no cover
231
+ elif isinstance(part, mcp_types.EmbeddedResource):
213
232
  resource = part.resource
214
- if isinstance(resource, TextResourceContents):
233
+ if isinstance(resource, mcp_types.TextResourceContents):
215
234
  return resource.text
216
- elif isinstance(resource, BlobResourceContents):
217
- return BinaryContent(
235
+ elif isinstance(resource, mcp_types.BlobResourceContents):
236
+ return messages.BinaryContent(
218
237
  data=base64.b64decode(resource.blob),
219
238
  media_type=resource.mimeType or 'application/octet-stream',
220
239
  )
@@ -275,17 +294,11 @@ class MCPServerStdio(MCPServer):
275
294
  By default the subprocess will not inherit any environment variables from the parent process.
276
295
  If you want to inherit the environment variables from the parent process, use `env=os.environ`.
277
296
  """
278
- log_level: LoggingLevel | None = None
279
- """The log level to set when connecting to the server, if any.
280
-
281
- See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
282
-
283
- If `None`, no log level will be set.
284
- """
285
297
 
286
298
  cwd: str | Path | None = None
287
299
  """The working directory to use when spawning the process."""
288
300
 
301
+ # last fields are re-defined from the parent class so they appear as fields
289
302
  tool_prefix: str | None = None
290
303
  """A prefix to add to all tools that are registered with the server.
291
304
 
@@ -294,11 +307,25 @@ class MCPServerStdio(MCPServer):
294
307
  e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
295
308
  """
296
309
 
310
+ log_level: mcp_types.LoggingLevel | None = None
311
+ """The log level to set when connecting to the server, if any.
312
+
313
+ See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
314
+
315
+ If `None`, no log level will be set.
316
+ """
317
+
318
+ log_handler: LoggingFnT | None = None
319
+ """A handler for logging messages from the server."""
320
+
321
+ timeout: float = 5
322
+ """The timeout in seconds to wait for the client to initialize."""
323
+
297
324
  process_tool_call: ProcessToolCallback | None = None
298
325
  """Hook to customize tool calling and optionally pass extra metadata."""
299
326
 
300
- timeout: float = 5
301
- """ The timeout in seconds to wait for the client to initialize."""
327
+ allow_sampling: bool = True
328
+ """Whether to allow MCP sampling through this client."""
302
329
 
303
330
  @asynccontextmanager
304
331
  async def client_streams(
@@ -313,15 +340,9 @@ class MCPServerStdio(MCPServer):
313
340
  async with stdio_client(server=server) as (read_stream, write_stream):
314
341
  yield read_stream, write_stream
315
342
 
316
- def _get_log_level(self) -> LoggingLevel | None:
317
- return self.log_level
318
-
319
343
  def __repr__(self) -> str:
320
344
  return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
321
345
 
322
- def _get_client_initialize_timeout(self) -> float:
323
- return self.timeout
324
-
325
346
 
326
347
  @dataclass
327
348
  class _MCPServerHTTP(MCPServer):
@@ -360,13 +381,6 @@ class _MCPServerHTTP(MCPServer):
360
381
  ```
361
382
  """
362
383
 
363
- timeout: float = 5
364
- """Initial connection timeout in seconds for establishing the connection.
365
-
366
- This timeout applies to the initial connection setup and handshake.
367
- If the connection cannot be established within this time, the operation will fail.
368
- """
369
-
370
384
  sse_read_timeout: float = 5 * 60
371
385
  """Maximum time in seconds to wait for new SSE messages before timing out.
372
386
 
@@ -375,7 +389,16 @@ class _MCPServerHTTP(MCPServer):
375
389
  and may be closed. Defaults to 5 minutes (300 seconds).
376
390
  """
377
391
 
378
- log_level: LoggingLevel | None = None
392
+ # last fields are re-defined from the parent class so they appear as fields
393
+ tool_prefix: str | None = None
394
+ """A prefix to add to all tools that are registered with the server.
395
+
396
+ If not empty, will include a trailing underscore (`_`).
397
+
398
+ For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
399
+ """
400
+
401
+ log_level: mcp_types.LoggingLevel | None = None
379
402
  """The log level to set when connecting to the server, if any.
380
403
 
381
404
  See <https://modelcontextprotocol.io/introduction#logging> for more details.
@@ -383,17 +406,22 @@ class _MCPServerHTTP(MCPServer):
383
406
  If `None`, no log level will be set.
384
407
  """
385
408
 
386
- tool_prefix: str | None = None
387
- """A prefix to add to all tools that are registered with the server.
409
+ log_handler: LoggingFnT | None = None
410
+ """A handler for logging messages from the server."""
388
411
 
389
- If not empty, will include a trailing underscore (`_`).
412
+ timeout: float = 5
413
+ """Initial connection timeout in seconds for establishing the connection.
390
414
 
391
- For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
415
+ This timeout applies to the initial connection setup and handshake.
416
+ If the connection cannot be established within this time, the operation will fail.
392
417
  """
393
418
 
394
419
  process_tool_call: ProcessToolCallback | None = None
395
420
  """Hook to customize tool calling and optionally pass extra metadata."""
396
421
 
422
+ allow_sampling: bool = True
423
+ """Whether to allow MCP sampling through this client."""
424
+
397
425
  @property
398
426
  @abstractmethod
399
427
  def _transport_client(
@@ -419,7 +447,10 @@ class _MCPServerHTTP(MCPServer):
419
447
  async def client_streams(
420
448
  self,
421
449
  ) -> AsyncIterator[
422
- tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
450
+ tuple[
451
+ MemoryObjectReceiveStream[SessionMessage | Exception],
452
+ MemoryObjectSendStream[SessionMessage],
453
+ ]
423
454
  ]: # pragma: no cover
424
455
  if self.http_client and self.headers:
425
456
  raise ValueError('`http_client` is mutually exclusive with `headers`.')
@@ -451,15 +482,9 @@ class _MCPServerHTTP(MCPServer):
451
482
  async with transport_client_partial(headers=self.headers) as (read_stream, write_stream, *_):
452
483
  yield read_stream, write_stream
453
484
 
454
- def _get_log_level(self) -> LoggingLevel | None:
455
- return self.log_level
456
-
457
485
  def __repr__(self) -> str: # pragma: no cover
458
486
  return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
459
487
 
460
- def _get_client_initialize_timeout(self) -> float: # pragma: no cover
461
- return self.timeout
462
-
463
488
 
464
489
  @dataclass
465
490
  class MCPServerSSE(_MCPServerHTTP):
@@ -555,7 +580,11 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
555
580
 
556
581
 
557
582
  ToolResult = (
558
- str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
583
+ str
584
+ | messages.BinaryContent
585
+ | dict[str, Any]
586
+ | list[Any]
587
+ | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
559
588
  )
560
589
  """The result type of a tool call."""
561
590
 
@@ -564,7 +593,7 @@ CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[
564
593
 
565
594
  ProcessToolCallback = Callable[
566
595
  [
567
- RunContext[Any],
596
+ tools.RunContext[Any],
568
597
  CallToolFunc,
569
598
  str,
570
599
  dict[str, Any],
@@ -20,9 +20,12 @@ from typing_extensions import Literal, TypeAliasType, TypedDict
20
20
 
21
21
  from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
22
22
 
23
+ from .. import _utils
24
+ from .._output import OutputObjectDefinition
23
25
  from .._parts_manager import ModelResponsePartsManager
24
26
  from ..exceptions import UserError
25
27
  from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
28
+ from ..output import OutputMode
26
29
  from ..profiles._json_schema import JsonSchemaTransformer
27
30
  from ..settings import ModelSettings
28
31
  from ..tools import ToolDefinition
@@ -300,13 +303,18 @@ KnownModelName = TypeAliasType(
300
303
  """
301
304
 
302
305
 
303
- @dataclass
306
+ @dataclass(repr=False)
304
307
  class ModelRequestParameters:
305
308
  """Configuration for an agent's request to a model, specifically related to tools and output handling."""
306
309
 
307
310
  function_tools: list[ToolDefinition] = field(default_factory=list)
308
- allow_text_output: bool = True
311
+
312
+ output_mode: OutputMode = 'text'
313
+ output_object: OutputObjectDefinition | None = None
309
314
  output_tools: list[ToolDefinition] = field(default_factory=list)
315
+ allow_text_output: bool = True
316
+
317
+ __repr__ = _utils.dataclasses_no_defaults_repr
310
318
 
311
319
 
312
320
  class Model(ABC):
@@ -351,6 +359,11 @@ class Model(ABC):
351
359
  function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools],
352
360
  output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools],
353
361
  )
362
+ if output_object := model_request_parameters.output_object:
363
+ model_request_parameters = replace(
364
+ model_request_parameters,
365
+ output_object=_customize_output_object(transformer, output_object),
366
+ )
354
367
 
355
368
  return model_request_parameters
356
369
 
@@ -718,3 +731,9 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit
718
731
  if t.strict is None:
719
732
  t = replace(t, strict=schema_transformer.is_strict_compatible)
720
733
  return replace(t, parameters_json_schema=parameters_json_schema)
734
+
735
+
736
+ def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition):
737
+ schema_transformer = transformer(o.json_schema, strict=True)
738
+ son_schema = schema_transformer.walk()
739
+ return replace(o, json_schema=son_schema)
@@ -11,6 +11,8 @@ from typing import Callable, Union
11
11
 
12
12
  from typing_extensions import TypeAlias, assert_never, overload
13
13
 
14
+ from pydantic_ai.profiles import ModelProfileSpec
15
+
14
16
  from .. import _utils, usage
15
17
  from .._utils import PeekableAsyncStream
16
18
  from ..messages import (
@@ -49,14 +51,27 @@ class FunctionModel(Model):
49
51
  _system: str = field(default='function', repr=False)
50
52
 
51
53
  @overload
52
- def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
54
+ def __init__(
55
+ self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None
56
+ ) -> None: ...
53
57
 
54
58
  @overload
55
- def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
59
+ def __init__(
60
+ self,
61
+ *,
62
+ stream_function: StreamFunctionDef,
63
+ model_name: str | None = None,
64
+ profile: ModelProfileSpec | None = None,
65
+ ) -> None: ...
56
66
 
57
67
  @overload
58
68
  def __init__(
59
- self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
69
+ self,
70
+ function: FunctionDef,
71
+ *,
72
+ stream_function: StreamFunctionDef,
73
+ model_name: str | None = None,
74
+ profile: ModelProfileSpec | None = None,
60
75
  ) -> None: ...
61
76
 
62
77
  def __init__(
@@ -65,6 +80,7 @@ class FunctionModel(Model):
65
80
  *,
66
81
  stream_function: StreamFunctionDef | None = None,
67
82
  model_name: str | None = None,
83
+ profile: ModelProfileSpec | None = None,
68
84
  ):
69
85
  """Initialize a `FunctionModel`.
70
86
 
@@ -74,6 +90,7 @@ class FunctionModel(Model):
74
90
  function: The function to call for non-streamed requests.
75
91
  stream_function: The function to call for streamed requests.
76
92
  model_name: The name of the model. If not provided, a name is generated from the function names.
93
+ profile: The model profile to use.
77
94
  """
78
95
  if function is None and stream_function is None:
79
96
  raise TypeError('Either `function` or `stream_function` must be provided')
@@ -83,6 +100,7 @@ class FunctionModel(Model):
83
100
  function_name = self.function.__name__ if self.function is not None else ''
84
101
  stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
85
102
  self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
103
+ self._profile = profile
86
104
 
87
105
  async def request(
88
106
  self,
@@ -16,6 +16,8 @@ from typing_extensions import NotRequired, TypedDict, assert_never
16
16
  from pydantic_ai.providers import Provider, infer_provider
17
17
 
18
18
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
19
+ from .._output import OutputObjectDefinition
20
+ from ..exceptions import UserError
19
21
  from ..messages import (
20
22
  BinaryContent,
21
23
  FileUrl,
@@ -203,12 +205,10 @@ class GeminiModel(Model):
203
205
  def _get_tool_config(
204
206
  self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
205
207
  ) -> _GeminiToolConfig | None:
206
- if model_request_parameters.allow_text_output:
207
- return None
208
- elif tools:
208
+ if not model_request_parameters.allow_text_output and tools:
209
209
  return _tool_config([t['name'] for t in tools['function_declarations']])
210
210
  else:
211
- return _tool_config([]) # pragma: no cover
211
+ return None
212
212
 
213
213
  @asynccontextmanager
214
214
  async def _make_request(
@@ -231,6 +231,18 @@ class GeminiModel(Model):
231
231
  request_data['toolConfig'] = tool_config
232
232
 
233
233
  generation_config = _settings_to_generation_config(model_settings)
234
+ if model_request_parameters.output_mode == 'native':
235
+ if tools:
236
+ raise UserError('Gemini does not support structured output and tools at the same time.')
237
+
238
+ generation_config['response_mime_type'] = 'application/json'
239
+
240
+ output_object = model_request_parameters.output_object
241
+ assert output_object is not None
242
+ generation_config['response_schema'] = self._map_response_schema(output_object)
243
+ elif model_request_parameters.output_mode == 'prompted' and not tools:
244
+ generation_config['response_mime_type'] = 'application/json'
245
+
234
246
  if generation_config:
235
247
  request_data['generationConfig'] = generation_config
236
248
 
@@ -376,6 +388,15 @@ class GeminiModel(Model):
376
388
  assert_never(item)
377
389
  return content
378
390
 
391
+ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]:
392
+ response_schema = o.json_schema.copy()
393
+ if o.name:
394
+ response_schema['title'] = o.name
395
+ if o.description:
396
+ response_schema['description'] = o.description
397
+
398
+ return response_schema
399
+
379
400
 
380
401
  def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
381
402
  config: _GeminiGenerationConfig = {}
@@ -577,6 +598,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
577
598
  frequency_penalty: float
578
599
  stop_sequences: list[str]
579
600
  thinking_config: ThinkingConfig
601
+ response_mime_type: str
602
+ response_schema: dict[str, Any]
580
603
 
581
604
 
582
605
  class _GeminiContent(TypedDict):