pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.

This version of pydantic-ai-slim has been flagged as potentially problematic.

Files changed (58)
  1. pydantic_ai/_a2a.py +6 -4
  2. pydantic_ai/_agent_graph.py +37 -37
  3. pydantic_ai/_cli.py +3 -3
  4. pydantic_ai/_output.py +8 -0
  5. pydantic_ai/_tool_manager.py +3 -0
  6. pydantic_ai/ag_ui.py +25 -14
  7. pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
  8. pydantic_ai/agent/abstract.py +942 -0
  9. pydantic_ai/agent/wrapper.py +227 -0
  10. pydantic_ai/direct.py +9 -9
  11. pydantic_ai/durable_exec/__init__.py +0 -0
  12. pydantic_ai/durable_exec/temporal/__init__.py +83 -0
  13. pydantic_ai/durable_exec/temporal/_agent.py +699 -0
  14. pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
  15. pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
  16. pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
  17. pydantic_ai/durable_exec/temporal/_model.py +168 -0
  18. pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
  19. pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
  20. pydantic_ai/ext/aci.py +10 -9
  21. pydantic_ai/ext/langchain.py +4 -2
  22. pydantic_ai/mcp.py +203 -75
  23. pydantic_ai/messages.py +2 -2
  24. pydantic_ai/models/__init__.py +93 -9
  25. pydantic_ai/models/anthropic.py +16 -7
  26. pydantic_ai/models/bedrock.py +8 -5
  27. pydantic_ai/models/cohere.py +1 -4
  28. pydantic_ai/models/fallback.py +10 -3
  29. pydantic_ai/models/function.py +9 -4
  30. pydantic_ai/models/gemini.py +15 -9
  31. pydantic_ai/models/google.py +84 -20
  32. pydantic_ai/models/groq.py +17 -14
  33. pydantic_ai/models/huggingface.py +18 -12
  34. pydantic_ai/models/instrumented.py +3 -1
  35. pydantic_ai/models/mcp_sampling.py +3 -1
  36. pydantic_ai/models/mistral.py +12 -18
  37. pydantic_ai/models/openai.py +57 -30
  38. pydantic_ai/models/test.py +3 -0
  39. pydantic_ai/models/wrapper.py +6 -2
  40. pydantic_ai/profiles/openai.py +1 -1
  41. pydantic_ai/providers/google.py +7 -7
  42. pydantic_ai/result.py +21 -55
  43. pydantic_ai/run.py +357 -0
  44. pydantic_ai/tools.py +0 -1
  45. pydantic_ai/toolsets/__init__.py +2 -0
  46. pydantic_ai/toolsets/_dynamic.py +87 -0
  47. pydantic_ai/toolsets/abstract.py +23 -3
  48. pydantic_ai/toolsets/combined.py +19 -4
  49. pydantic_ai/toolsets/deferred.py +10 -2
  50. pydantic_ai/toolsets/function.py +23 -8
  51. pydantic_ai/toolsets/prefixed.py +4 -0
  52. pydantic_ai/toolsets/wrapper.py +14 -1
  53. pydantic_ai/usage.py +17 -1
  54. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/METADATA +7 -5
  55. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/RECORD +58 -45
  56. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/WHEEL +0 -0
  57. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/entry_points.txt +0 -0
  58. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.1.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/anthropic.py

@@ -21,7 +21,9 @@ from typing_extensions import assert_never
 from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
 from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -196,13 +198,14 @@ class AnthropicModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._messages_create(
             messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)
 
     @property
     def model_name(self) -> AnthropicModelName:
@@ -329,7 +332,9 @@ class AnthropicModel(Model):
 
         return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
 
-    async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
         if isinstance(first_chunk, _utils.Unset):
@@ -338,14 +343,14 @@
         # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
         timestamp = datetime.now(tz=timezone.utc)
         return AnthropicStreamedResponse(
-            _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=timestamp,
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
         tools: list[BetaToolUnionParam] = []
@@ -363,6 +368,10 @@
                )
            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                tools.append(BetaCodeExecutionTool20250522Param(name='code_execution', type='code_execution_20250522'))
+           else:  # pragma: no cover
+               raise UserError(
+                   f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
+               )
        return tools
 
     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]:  # noqa: C901
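
Across all of the model classes in this diff, `request_stream` gains an optional `run_context` parameter (defaulting to `None`, so existing callers keep working) and every `StreamedResponse` subclass is now constructed with `model_request_parameters`. A minimal sketch of the new call shape when driving a model directly, assuming `ModelRequestParameters` is default-constructible and Anthropic credentials are configured; the agent machinery that normally supplies a real `RunContext` is elided:

    import asyncio

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.anthropic import AnthropicModel

    async def main() -> None:
        model = AnthropicModel('claude-sonnet-4-0')  # example model name
        messages = [ModelRequest(parts=[UserPromptPart(content='Hello')])]
        params = model.customize_request_parameters(ModelRequestParameters())
        # run_context is optional, so direct callers may pass None
        async with model.request_stream(messages, None, params, run_context=None) as response:
            async for event in response:
                print(event)

    asyncio.run(main())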
pydantic_ai/models/bedrock.py

@@ -15,6 +15,7 @@ import anyio.to_thread
 from typing_extensions import ParamSpec, assert_never
 
 from pydantic_ai import _utils, usage
+from pydantic_ai._run_context import RunContext
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.messages import (
     AudioUrl,
@@ -230,10 +231,7 @@
         super().__init__(settings=settings, profile=profile or provider.model_profile)
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
@@ -269,10 +267,15 @@
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, True, settings, model_request_parameters)
-        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)
+        yield BedrockStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self.model_name,
+            _event_stream=response,
+        )
 
     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
         items: list[ModelResponsePart] = []
pydantic_ai/models/cohere.py

@@ -248,10 +248,7 @@
         return cohere_messages
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
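
The `_get_tools` rewrite repeated in the anthropic, bedrock, cohere, gemini, and groq hunks relies on the new `ModelRequestParameters.tool_defs` (introduced in `pydantic_ai/models/__init__.py`, file 24 above), which evidently folds function tools and output tools into one name-keyed mapping. A rough sketch of the presumed equivalence; the keying detail is an assumption:

    from pydantic_ai.tools import ToolDefinition

    def tool_defs(
        function_tools: list[ToolDefinition],
        output_tools: list[ToolDefinition],
    ) -> dict[str, ToolDefinition]:
        # assumption: one flat mapping keyed by unique tool name, function tools first
        return {t.name: t for t in [*function_tools, *output_tools]}

Iterating `tool_defs.values()` therefore yields the same tool list the old two-step loop built.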
pydantic_ai/models/fallback.py

@@ -3,13 +3,15 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import AsyncExitStack, asynccontextmanager, suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 from opentelemetry.trace import get_current_span
 
+from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel
 
 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
+from ..settings import merge_model_settings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 if TYPE_CHECKING:
@@ -64,8 +66,9 @@
 
         for model in self.models:
             customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
             try:
-                response = await model.request(messages, model_settings, customized_model_request_parameters)
+                response = await model.request(messages, merged_settings, customized_model_request_parameters)
             except Exception as exc:
                 if self._fallback_on(exc):
                     exceptions.append(exc)
@@ -83,16 +86,20 @@
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Try each model in sequence until one succeeds."""
         exceptions: list[Exception] = []
 
         for model in self.models:
             customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
+            merged_settings = merge_model_settings(model.settings, model_settings)
             async with AsyncExitStack() as stack:
                 try:
                     response = await stack.enter_async_context(
-                        model.request_stream(messages, model_settings, customized_model_request_parameters)
+                        model.request_stream(
+                            messages, merged_settings, customized_model_request_parameters, run_context
+                        )
                     )
                 except Exception as exc:
                     if self._fallback_on(exc):
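
`FallbackModel` now merges each candidate's own `settings` with the run-level `model_settings` before every attempt, so per-model defaults survive fallback instead of being silently dropped. A small sketch of the presumed semantics of `merge_model_settings` (a dict-style merge where the second argument wins), matching the argument order used in the hunk:

    from pydantic_ai.settings import ModelSettings, merge_model_settings

    model_defaults: ModelSettings = {'temperature': 0.0, 'max_tokens': 1024}
    run_settings: ModelSettings = {'temperature': 0.7}

    merged = merge_model_settings(model_defaults, run_settings)
    # presumably {'temperature': 0.7, 'max_tokens': 1024}: run-level values
    # override the model's defaults, and untouched keys are preserved
    print(merged)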
pydantic_ai/models/function.py

@@ -7,13 +7,12 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union
+from typing import Any, Callable, Union
 
 from typing_extensions import TypeAlias, assert_never, overload
 
-from pydantic_ai.profiles import ModelProfileSpec
-
 from .. import _utils, usage
+from .._run_context import RunContext
 from .._utils import PeekableAsyncStream
 from ..messages import (
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -147,6 +147,7 @@
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
@@ -165,7 +166,11 @@
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')
 
-        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
+        yield FunctionStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _iter=response_stream,
+        )
 
     @property
     def model_name(self) -> str:
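
`FunctionModel` is the package's test double, so its streaming path is the easiest place to see the new plumbing end to end. A hedged usage sketch, assuming the `stream_function` signature this module already uses (`messages, agent_info`) and that yielded strings are treated as text deltas:

    import asyncio
    from collections.abc import AsyncIterator

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models.function import AgentInfo, FunctionModel

    async def stream_fn(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
        # each yielded string becomes a text delta in the streamed response
        yield 'hello '
        yield 'world'

    agent = Agent(FunctionModel(stream_function=stream_fn))

    async def main() -> None:
        async with agent.run_stream('hi') as result:
            print(await result.get_output())  # hello world

    asyncio.run(main())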
pydantic_ai/models/gemini.py

@@ -13,10 +13,9 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
 
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -38,6 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -167,12 +167,13 @@
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
-            yield await self._process_streamed_response(http_response)
+            yield await self._process_streamed_response(http_response, model_request_parameters)
 
     @property
     def model_name(self) -> GeminiModelName:
@@ -185,9 +186,7 @@
         return self._system
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
-        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None
 
     def _get_tool_config(
@@ -288,7 +287,9 @@
             vendor_details=vendor_details,
         )
 
-    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -309,7 +310,12 @@
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
 
-        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _content=content,
+            _stream=aiter_bytes,
+        )
 
     async def _message_to_gemini_content(
         self, messages: list[ModelMessage]
@@ -872,7 +878,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
         suffix = key.removesuffix('_details')
         for detail in metadata_details:
-            details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+            details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
 
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
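
This `_metadata_as_usage` change (mirrored in google.py below) is a defensive fix: a modality entry in Gemini usage metadata may omit `token_count`, and the old subscript raised `KeyError` mid-stream. A tiny sketch of the difference, using a hypothetical detail dict:

    detail = {'modality': 'AUDIO'}  # hypothetical entry with no 'token_count'
    suffix = 'prompt_tokens_details'.removesuffix('_details')

    # before: detail['token_count'] would raise KeyError
    value = detail.get('token_count', 0)  # after: missing counts default to 0
    print(f'{detail["modality"].lower()}_{suffix}', value)  # audio_prompt_tokens 0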
pydantic_ai/models/google.py

@@ -8,11 +8,11 @@ from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4
 
-from google.genai.types import ExecutableCodeDict
 from typing_extensions import assert_never
 
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..builtin_tools import CodeExecutionTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
@@ -48,16 +48,19 @@ from . import (
 )
 
 try:
-    from google import genai
+    from google.genai import Client
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        CountTokensConfigDict,
+        ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GenerationConfigDict,
         GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
@@ -136,10 +139,10 @@
     Apart from `__init__`, all methods are private or match those of the base class.
     """
 
-    client: genai.Client = field(repr=False)
+    client: Client = field(repr=False)
 
     _model_name: GoogleModelName = field(repr=False)
-    _provider: Provider[genai.Client] = field(repr=False)
+    _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='google', repr=False)
 
@@ -147,7 +150,7 @@
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -187,17 +190,71 @@
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> usage.Usage:
+        check_allow_model_requests()
+        model_settings = cast(GoogleModelSettings, model_settings or {})
+        contents, generation_config = await self._build_content_and_config(
+            messages, model_settings, model_request_parameters
+        )
+
+        # Annoyingly, the type of `GenerateContentConfigDict.get` is "partially `Unknown`" because `response_schema` includes `typing._UnionGenericAlias`,
+        # so without this we'd need `pyright: ignore[reportUnknownMemberType]` on every line and wouldn't get type checking anyway.
+        generation_config = cast(dict[str, Any], generation_config)
+
+        config = CountTokensConfigDict(
+            http_options=generation_config.get('http_options'),
+        )
+        if self.system != 'google-gla':
+            # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
+            config.update(
+                system_instruction=generation_config.get('system_instruction'),
+                tools=cast(list[ToolDict], generation_config.get('tools')),
+                # Annoyingly, GenerationConfigDict has fewer fields than GenerateContentConfigDict, and no extra fields are allowed.
+                generation_config=GenerationConfigDict(
+                    temperature=generation_config.get('temperature'),
+                    top_p=generation_config.get('top_p'),
+                    max_output_tokens=generation_config.get('max_output_tokens'),
+                    stop_sequences=generation_config.get('stop_sequences'),
+                    presence_penalty=generation_config.get('presence_penalty'),
+                    frequency_penalty=generation_config.get('frequency_penalty'),
+                    thinking_config=generation_config.get('thinking_config'),
+                    media_resolution=generation_config.get('media_resolution'),
+                    response_mime_type=generation_config.get('response_mime_type'),
+                    response_schema=generation_config.get('response_schema'),
+                ),
+            )
+
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+            config=config,
+        )
+        if response.total_tokens is None:
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Total tokens missing from Gemini response', str(response)
+            )
+        return usage.Usage(
+            request_tokens=response.total_tokens,
+            total_tokens=response.total_tokens,
+        )
+
     @asynccontextmanager
     async def request_stream(
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
-        yield await self._process_streamed_response(response)  # type: ignore
+        yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore
 
     @property
     def model_name(self) -> GoogleModelName:
@@ -212,18 +269,17 @@
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-            for t in model_request_parameters.function_tools
+            for t in model_request_parameters.tool_defs.values()
         ]
-        if model_request_parameters.output_tools:
-            tools += [
-                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-                for t in model_request_parameters.output_tools
-            ]
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 tools.append(ToolDict(google_search=GoogleSearchDict()))
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
+                )
         return tools or None
 
     def _get_tool_config(
@@ -264,16 +320,23 @@
         model_settings: GoogleModelSettings,
         model_request_parameters: ModelRequestParameters,
     ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
-        tools = self._get_tools(model_request_parameters)
+        contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
+        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
+        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
 
+    async def _build_content_and_config(
+        self,
+        messages: list[ModelMessage],
+        model_settings: GoogleModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[list[ContentUnionDict], GenerateContentConfigDict]:
+        tools = self._get_tools(model_request_parameters)
         response_mime_type = None
         response_schema = None
         if model_request_parameters.output_mode == 'native':
             if tools:
                 raise UserError('Gemini does not support structured output and tools at the same time.')
-
             response_mime_type = 'application/json'
-
             output_object = model_request_parameters.output_object
             assert output_object is not None
             response_schema = self._map_response_schema(output_object)
@@ -310,9 +373,7 @@
             response_mime_type=response_mime_type,
             response_schema=response_schema,
         )
-
-        func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
-        return await func(model=self._model_name, contents=contents, config=config)  # type: ignore
+        return contents, config
 
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
@@ -336,7 +397,9 @@
             parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
         )
 
-    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -344,6 +407,7 @@
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover
 
         return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=first_chunk.create_time or _utils.now_utc(),
@@ -603,7 +667,7 @@
         if key.endswith('_details') and metadata_details:
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
 
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
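
The headline addition here is `GoogleModel.count_tokens`, which reuses the request-building path via the newly extracted `_build_content_and_config` and returns a `usage.Usage` carrying only token counts. A hedged usage sketch, assuming `ModelRequestParameters` is default-constructible and `google-gla` credentials are configured:

    import asyncio

    from pydantic_ai.messages import ModelRequest, UserPromptPart
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.models.google import GoogleModel

    async def main() -> None:
        model = GoogleModel('gemini-1.5-flash')  # example model name
        messages = [ModelRequest(parts=[UserPromptPart(content='Hello there')])]
        counted = await model.count_tokens(messages, None, ModelRequestParameters())
        # per the hunk, request_tokens and total_tokens both carry the API's total
        print(counted.request_tokens, counted.total_tokens)

    asyncio.run(main())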
pydantic_ai/models/groq.py

@@ -5,17 +5,16 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 
 from typing_extensions import assert_never
 
-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-from pydantic_ai.profiles.groq import GroqModelProfile
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
-from ..builtin_tools import CodeExecutionTool, WebSearchTool
+from ..builtin_tools import WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -36,6 +35,7 @@ from ..messages import (
     UserPromptPart,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.groq import GroqModelProfile
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -171,13 +171,14 @@
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)
 
     @property
     def model_name(self) -> GroqModelName:
@@ -287,7 +288,9 @@
             items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
         )
 
-    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -297,6 +300,7 @@
         )
 
         return GroqStreamedResponse(
+            model_request_parameters=model_request_parameters,
            _response=peekable_response,
            _model_name=self._model_name,
            _model_profile=self.profile,
@@ -304,10 +308,7 @@
         )
 
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
 
     def _get_builtin_tools(
         self, model_request_parameters: ModelRequestParameters
@@ -317,8 +318,10 @@
             if isinstance(tool, WebSearchTool):
                 if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
                     raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
-            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
-                raise UserError('`CodeExecutionTool` is not supported by Groq')
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
+                )
         return tools
 
     def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
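
The last hunk swaps the `CodeExecutionTool`-specific branch for a catch-all `else`, so any builtin tool Groq cannot handle now fails with one consistent message naming the tool class (which is why the `CodeExecutionTool` import disappears above). A minimal sketch of the resulting error, with the model wiring that would actually trigger it elided:

    from pydantic_ai.builtin_tools import CodeExecutionTool
    from pydantic_ai.exceptions import UserError

    tool = CodeExecutionTool()
    try:
        # the message below is copied from the hunk; GroqModel._get_builtin_tools raises it
        raise UserError(
            f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
        )
    except UserError as e:
        print(e)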