pydantic-ai-slim 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.


Files changed (57)
  1. pydantic_ai/_a2a.py +6 -4
  2. pydantic_ai/_agent_graph.py +25 -32
  3. pydantic_ai/_cli.py +3 -3
  4. pydantic_ai/_output.py +8 -0
  5. pydantic_ai/_tool_manager.py +3 -0
  6. pydantic_ai/ag_ui.py +25 -14
  7. pydantic_ai/{agent.py → agent/__init__.py} +209 -1027
  8. pydantic_ai/agent/abstract.py +942 -0
  9. pydantic_ai/agent/wrapper.py +227 -0
  10. pydantic_ai/direct.py +9 -9
  11. pydantic_ai/durable_exec/__init__.py +0 -0
  12. pydantic_ai/durable_exec/temporal/__init__.py +83 -0
  13. pydantic_ai/durable_exec/temporal/_agent.py +699 -0
  14. pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
  15. pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
  16. pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
  17. pydantic_ai/durable_exec/temporal/_model.py +168 -0
  18. pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
  19. pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
  20. pydantic_ai/ext/aci.py +10 -9
  21. pydantic_ai/ext/langchain.py +4 -2
  22. pydantic_ai/mcp.py +203 -75
  23. pydantic_ai/messages.py +2 -2
  24. pydantic_ai/models/__init__.py +65 -9
  25. pydantic_ai/models/anthropic.py +16 -7
  26. pydantic_ai/models/bedrock.py +8 -5
  27. pydantic_ai/models/cohere.py +1 -4
  28. pydantic_ai/models/fallback.py +4 -2
  29. pydantic_ai/models/function.py +9 -4
  30. pydantic_ai/models/gemini.py +15 -9
  31. pydantic_ai/models/google.py +18 -14
  32. pydantic_ai/models/groq.py +17 -14
  33. pydantic_ai/models/huggingface.py +18 -12
  34. pydantic_ai/models/instrumented.py +3 -1
  35. pydantic_ai/models/mcp_sampling.py +3 -1
  36. pydantic_ai/models/mistral.py +12 -18
  37. pydantic_ai/models/openai.py +29 -26
  38. pydantic_ai/models/test.py +3 -0
  39. pydantic_ai/models/wrapper.py +6 -2
  40. pydantic_ai/profiles/openai.py +1 -1
  41. pydantic_ai/providers/google.py +7 -7
  42. pydantic_ai/result.py +21 -55
  43. pydantic_ai/run.py +357 -0
  44. pydantic_ai/tools.py +0 -1
  45. pydantic_ai/toolsets/__init__.py +2 -0
  46. pydantic_ai/toolsets/_dynamic.py +87 -0
  47. pydantic_ai/toolsets/abstract.py +23 -3
  48. pydantic_ai/toolsets/combined.py +19 -4
  49. pydantic_ai/toolsets/deferred.py +10 -2
  50. pydantic_ai/toolsets/function.py +23 -8
  51. pydantic_ai/toolsets/prefixed.py +4 -0
  52. pydantic_ai/toolsets/wrapper.py +14 -1
  53. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +6 -4
  54. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/RECORD +57 -44
  55. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
  56. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
  57. {pydantic_ai_slim-0.6.2.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0

pydantic_ai/models/bedrock.py

@@ -15,6 +15,7 @@ import anyio.to_thread
 from typing_extensions import ParamSpec, assert_never

 from pydantic_ai import _utils, usage
+from pydantic_ai._run_context import RunContext
 from pydantic_ai.exceptions import UserError
 from pydantic_ai.messages import (
     AudioUrl,
@@ -230,10 +231,7 @@ class BedrockConverseModel(Model):
         super().__init__(settings=settings, profile=profile or provider.model_profile)

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
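
This `_get_tools` rewrite recurs across Bedrock, Cohere, Gemini, Groq, Hugging Face, and Mistral in this release: the old two-pass walk over `function_tools` and `output_tools` becomes a single pass over `model_request_parameters.tool_defs`. Below is a minimal sketch of the idea using simplified stand-ins; the `tool_defs` property shown here is an assumption (a name-keyed merge of both lists), since the diff only shows its call sites.

    from dataclasses import dataclass, field

    @dataclass
    class ToolDefinition:  # simplified stand-in for pydantic_ai.tools.ToolDefinition
        name: str
        description: str = ''
        parameters_json_schema: dict = field(default_factory=dict)

    @dataclass
    class ModelRequestParameters:  # simplified stand-in for the real class in pydantic_ai.models
        function_tools: list[ToolDefinition] = field(default_factory=list)
        output_tools: list[ToolDefinition] = field(default_factory=list)

        @property
        def tool_defs(self) -> dict[str, ToolDefinition]:
            # Assumed shape: one name-keyed mapping covering both kinds of tool,
            # so providers no longer need to special-case output_tools.
            return {t.name: t for t in (*self.function_tools, *self.output_tools)}

    def _map_tool_definition(f: ToolDefinition) -> dict:
        # Provider-specific mapping, shaped loosely like Bedrock's ToolTypeDef.
        return {'toolSpec': {'name': f.name, 'inputSchema': {'json': f.parameters_json_schema}}}

    params = ModelRequestParameters(
        function_tools=[ToolDefinition('get_weather')],
        output_tools=[ToolDefinition('final_result')],
    )
    tools = [_map_tool_definition(r) for r in params.tool_defs.values()]  # the 0.7.0 single-pass form
    assert [t['toolSpec']['name'] for t in tools] == ['get_weather', 'final_result']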

@@ -269,10 +267,15 @@ class BedrockConverseModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         settings = cast(BedrockModelSettings, model_settings or {})
         response = await self._messages_create(messages, True, settings, model_request_parameters)
-        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)
+        yield BedrockStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self.model_name,
+            _event_stream=response,
+        )

     async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
         items: list[ModelResponsePart] = []

pydantic_ai/models/cohere.py

@@ -248,10 +248,7 @@ class CohereModel(Model):
         return cohere_messages

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> ToolCallV2:

pydantic_ai/models/fallback.py

@@ -3,10 +3,11 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import AsyncExitStack, asynccontextmanager, suppress
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable

 from opentelemetry.trace import get_current_span

+from pydantic_ai._run_context import RunContext
 from pydantic_ai.models.instrumented import InstrumentedModel

 from ..exceptions import FallbackExceptionGroup, ModelHTTPError
@@ -83,6 +84,7 @@ class FallbackModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Try each model in sequence until one succeeds."""
         exceptions: list[Exception] = []
@@ -92,7 +94,7 @@ class FallbackModel(Model):
         async with AsyncExitStack() as stack:
             try:
                 response = await stack.enter_async_context(
-                    model.request_stream(messages, model_settings, customized_model_request_parameters)
+                    model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
                 )
             except Exception as exc:
                 if self._fallback_on(exc):
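
The optional trailing `run_context` argument threads through every `request_stream` implementation in this release, and wrapper models forward it positionally, as `FallbackModel` does above. A sketch of the forwarding pattern in isolation, with untyped stand-ins rather than the real `Model` and `RunContext` classes:

    from __future__ import annotations

    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager
    from typing import Any

    class ForwardingModel:  # hypothetical wrapper, not a real pydantic_ai class
        def __init__(self, wrapped: Any) -> None:
            self.wrapped = wrapped

        @asynccontextmanager
        async def request_stream(
            self,
            messages: list[Any],
            model_settings: Any,
            model_request_parameters: Any,
            run_context: Any | None = None,  # new in 0.7.0; defaulting to None keeps old callers working
        ) -> AsyncIterator[Any]:
            # A wrapper must forward run_context explicitly, or the inner
            # model silently falls back to its own None default.
            async with self.wrapped.request_stream(
                messages, model_settings, model_request_parameters, run_context
            ) as response_stream:
                yield response_stream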

pydantic_ai/models/function.py

@@ -7,13 +7,12 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
-from typing import Callable, Union
+from typing import Any, Callable, Union

 from typing_extensions import TypeAlias, assert_never, overload

-from pydantic_ai.profiles import ModelProfileSpec
-
 from .. import _utils, usage
+from .._run_context import RunContext
 from .._utils import PeekableAsyncStream
 from ..messages import (
     BinaryContent,
@@ -32,6 +31,7 @@ from ..messages import (
     UserContent,
     UserPromptPart,
 )
+from ..profiles import ModelProfileSpec
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -147,6 +147,7 @@ class FunctionModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         agent_info = AgentInfo(
             model_request_parameters.function_tools,
@@ -165,7 +166,11 @@ class FunctionModel(Model):
         if isinstance(first, _utils.Unset):
             raise ValueError('Stream function must return at least one item')

-        yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream)
+        yield FunctionStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _iter=response_stream,
+        )

     @property
     def model_name(self) -> str:

pydantic_ai/models/gemini.py

@@ -13,10 +13,9 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never, deprecated

-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
@@ -38,6 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -167,12 +167,13 @@ class GeminiModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
-            yield await self._process_streamed_response(http_response)
+            yield await self._process_streamed_response(http_response, model_request_parameters)

     @property
     def model_name(self) -> GeminiModelName:
@@ -185,9 +186,7 @@ class GeminiModel(Model):
         return self._system

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
-        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None

     def _get_tool_config(
@@ -288,7 +287,9 @@ class GeminiModel(Model):
             vendor_details=vendor_details,
         )

-    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -309,7 +310,12 @@ class GeminiModel(Model):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _content=content,
+            _stream=aiter_bytes,
+        )

     async def _message_to_gemini_content(
         self, messages: list[ModelMessage]
@@ -872,7 +878,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
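
This usage-mapping change (mirrored in google.py below) is a one-line hardening: `detail['token_count']` becomes `detail.get('token_count', 0)`, so a modality entry that omits `token_count` contributes 0 instead of raising `KeyError`. A tiny illustration with made-up metadata:

    details: dict[str, int] = {}
    modality_details = [
        {'modality': 'TEXT', 'token_count': 42},
        {'modality': 'AUDIO'},  # hypothetical entry with no token_count
    ]
    suffix = 'tokens_details'.removesuffix('_details')  # -> 'tokens'
    for detail in modality_details:
        # .get() tolerates the missing key; subscripting would raise KeyError on the AUDIO entry
        details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)
    assert details == {'text_tokens': 42, 'audio_tokens': 0}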

pydantic_ai/models/google.py

@@ -8,11 +8,11 @@ from datetime import datetime
 from typing import Any, Literal, Union, cast, overload
 from uuid import uuid4

-from google.genai.types import ExecutableCodeDict
 from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..builtin_tools import CodeExecutionTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
@@ -48,10 +48,11 @@ from . import (
 )

 try:
-    from google import genai
+    from google.genai import Client
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
@@ -136,10 +137,10 @@ class GoogleModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    client: genai.Client = field(repr=False)
+    client: Client = field(repr=False)

     _model_name: GoogleModelName = field(repr=False)
-    _provider: Provider[genai.Client] = field(repr=False)
+    _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='google', repr=False)

@@ -147,7 +148,7 @@ class GoogleModel(Model):
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -193,11 +194,12 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
-        yield await self._process_streamed_response(response)  # type: ignore
+        yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

     @property
     def model_name(self) -> GoogleModelName:
@@ -212,18 +214,17 @@ class GoogleModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-            for t in model_request_parameters.function_tools
+            for t in model_request_parameters.tool_defs.values()
         ]
-        if model_request_parameters.output_tools:
-            tools += [
-                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-                for t in model_request_parameters.output_tools
-            ]
         for tool in model_request_parameters.builtin_tools:
             if isinstance(tool, WebSearchTool):
                 tools.append(ToolDict(google_search=GoogleSearchDict()))
             elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
                 tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
+                )
         return tools or None

     def _get_tool_config(
@@ -336,7 +337,9 @@ class GoogleModel(Model):
             parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
         )

-    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -344,6 +347,7 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover

         return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
@@ -603,7 +607,7 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
         if key.endswith('_details') and metadata_details:
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),

pydantic_ai/models/groq.py

@@ -5,17 +5,16 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-from pydantic_ai.profiles.groq import GroqModelProfile
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
-from ..builtin_tools import CodeExecutionTool, WebSearchTool
+from ..builtin_tools import WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -36,6 +35,7 @@ from ..messages import (
     UserPromptPart,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.groq import GroqModelProfile
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -171,13 +171,14 @@ class GroqModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> GroqModelName:
@@ -287,7 +288,9 @@ class GroqModel(Model):
             items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
         )

-    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -297,6 +300,7 @@ class GroqModel(Model):
         )

         return GroqStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _model_profile=self.profile,
@@ -304,10 +308,7 @@ class GroqModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     def _get_builtin_tools(
         self, model_request_parameters: ModelRequestParameters
@@ -317,8 +318,10 @@ class GroqModel(Model):
             if isinstance(tool, WebSearchTool):
                 if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
                     raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
-            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
-                raise UserError('`CodeExecutionTool` is not supported by Groq')
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
+                )
         return tools

     def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
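
Groq's builtin-tool validation (like Google's above) moves from rejecting one known-unsupported class to a catch-all `else` that names whatever class it received, so builtin tools added in future releases fail loudly instead of slipping through an exhaustive-looking `elif` chain. The shape of the change, with stand-in classes and `ValueError` standing in for `UserError`:

    class WebSearchTool: ...          # supported on Groq, profile permitting
    class CodeExecutionTool: ...      # previously the only explicit rejection
    class SomeFutureBuiltinTool: ...  # hypothetical: anything the old elif chain didn't know about

    def validate_builtin_tool(tool: object) -> None:
        if isinstance(tool, WebSearchTool):
            return
        # 0.6.2 raised only for CodeExecutionTool; 0.7.0 rejects every
        # unrecognized class by name.
        raise ValueError(f'`{tool.__class__.__name__}` is not supported by `GroqModel`.')

    try:
        validate_builtin_tool(SomeFutureBuiltinTool())
    except ValueError as exc:
        assert 'SomeFutureBuiltinTool' in str(exc)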

pydantic_ai/models/huggingface.py

@@ -5,16 +5,15 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -37,9 +36,15 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+)

 try:
     import aiohttp
@@ -150,12 +155,13 @@ class HuggingFaceModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
-        yield await self._process_streamed_response(response)
+        yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> HuggingFaceModelName:
@@ -263,7 +269,9 @@ class HuggingFaceModel(Model):
             vendor_id=response.id,
         )

-    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterable[ChatCompletionStreamOutput], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -273,6 +281,7 @@ class HuggingFaceModel(Model):
         )

         return HuggingFaceStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _model_profile=self.profile,
             _response=peekable_response,
@@ -280,10 +289,7 @@ class HuggingFaceModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     async def _map_messages(
         self, messages: list[ModelMessage]

pydantic_ai/models/instrumented.py

@@ -18,6 +18,7 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provide
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter

+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelRequest, ModelResponse
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
@@ -218,12 +219,13 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         with self._instrument(messages, model_settings, model_request_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
                 async with super().request_stream(
-                    messages, model_settings, model_request_parameters
+                    messages, model_settings, model_request_parameters, run_context
                 ) as response_stream:
                     yield response_stream
             finally:

pydantic_ai/models/mcp_sampling.py

@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast

 from .. import _mcp, exceptions, usage
+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -76,6 +77,7 @@ class MCPSamplingModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         raise NotImplementedError('MCP Sampling does not support streaming')
         yield
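
The unreachable `yield` after the `raise` here is deliberate: its presence makes `request_stream` an async generator, so it still satisfies the streaming interface while raising as soon as the stream is entered. The idiom in isolation:

    import asyncio
    from collections.abc import AsyncIterator

    async def request_stream() -> AsyncIterator[str]:
        raise NotImplementedError('MCP Sampling does not support streaming')
        yield  # never reached, but turns this coroutine into an async generator

    async def main() -> None:
        try:
            async for _ in request_stream():
                pass
        except NotImplementedError as exc:
            print(exc)  # -> MCP Sampling does not support streaming

    asyncio.run(main())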

pydantic_ai/models/mistral.py

@@ -11,11 +11,11 @@ import pydantic_core
 from httpx import Timeout
 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.exceptions import UserError
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,
@@ -176,6 +176,7 @@ class MistralModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
         check_allow_model_requests()
@@ -183,7 +184,7 @@ class MistralModel(Model):
             messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(model_request_parameters.output_tools, response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> MistralModelName:
@@ -246,11 +247,7 @@ class MistralModel(Model):
         if model_request_parameters.builtin_tools:
             raise UserError('Mistral does not support built-in tools')

-        if (
-            model_request_parameters.output_tools
-            and model_request_parameters.function_tools
-            or model_request_parameters.function_tools
-        ):
+        if model_request_parameters.function_tools:
             # Function Calling
             response = await self.client.chat.stream_async(
                 model=str(self._model_name),
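
The dropped four-line condition was a tautology: by operator precedence it parsed as `(output_tools and function_tools) or function_tools`, and `(a and b) or b` is truthy exactly when `b` is, so the simplified `if model_request_parameters.function_tools:` is behavior-preserving. A brute-force check of the equivalence:

    for a in (False, True):
        for b in (False, True):
            assert bool((a and b) or b) == bool(b)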

@@ -318,16 +315,13 @@ class MistralModel(Model):

        Returns None if both function_tools and output_tools are empty.
        """
-        all_tools: list[ToolDefinition] = (
-            model_request_parameters.function_tools + model_request_parameters.output_tools
-        )
         tools = [
             MistralTool(
                 function=MistralFunction(
                     name=r.name, parameters=r.parameters_json_schema, description=r.description or ''
                 )
             )
-            for r in all_tools
+            for r in model_request_parameters.tool_defs.values()
         ]
         return tools if tools else None

@@ -359,8 +353,8 @@ class MistralModel(Model):

     async def _process_streamed_response(
         self,
-        output_tools: list[ToolDefinition],
         response: MistralEventStreamAsync[MistralCompletionEvent],
+        model_request_parameters: ModelRequestParameters,
     ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
@@ -376,10 +370,10 @@ class MistralModel(Model):
         timestamp = _now_utc()

         return MistralStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _timestamp=timestamp,
-            _output_tools={c.name: c for c in output_tools},
         )

     @staticmethod
@@ -586,7 +580,6 @@ class MistralStreamedResponse(StreamedResponse):
     _model_name: MistralModelName
     _response: AsyncIterable[MistralCompletionEvent]
     _timestamp: datetime
-    _output_tools: dict[str, ToolDefinition]

     _delta_content: str = field(default='', init=False)

@@ -605,10 +598,11 @@ class MistralStreamedResponse(StreamedResponse):
                 text = _map_content(content)
                 if text:
                     # Attempt to produce an output tool call from the received text
-                    if self._output_tools:
+                    output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
+                    if output_tools:
                         self._delta_content += text
                         # TODO: Port to native "manual JSON" mode
-                        maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools)
+                        maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, output_tools)
                         if maybe_tool_call_part:
                             yield self._parts_manager.handle_tool_call_part(
                                 vendor_part_id='output',
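
With `model_request_parameters` now carried on every `StreamedResponse`, Mistral's bespoke `_output_tools` field becomes redundant: the streamed response rebuilds the name-keyed mapping on demand, as the hunk above shows. A sketch of that pattern with simplified stand-ins (`output_tools_by_name` is illustrative, not a library method):

    from dataclasses import dataclass, field

    @dataclass
    class ToolDefinition:  # simplified stand-in
        name: str

    @dataclass
    class ModelRequestParameters:  # simplified stand-in
        output_tools: list[ToolDefinition] = field(default_factory=list)

    @dataclass
    class StreamedResponse:  # simplified stand-in for the pydantic_ai base class
        # 0.7.0 passes this field to every provider's streamed response,
        # so per-provider copies of request state can be dropped.
        model_request_parameters: ModelRequestParameters

        def output_tools_by_name(self) -> dict[str, ToolDefinition]:
            return {t.name: t for t in self.model_request_parameters.output_tools}

    resp = StreamedResponse(ModelRequestParameters(output_tools=[ToolDefinition('final_result')]))
    assert 'final_result' in resp.output_tools_by_name()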