pydantic-ai-slim 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in the public registry.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.
Files changed (63)
  1. pydantic_ai/__init__.py +5 -0
  2. pydantic_ai/_a2a.py +6 -4
  3. pydantic_ai/_agent_graph.py +32 -32
  4. pydantic_ai/_cli.py +3 -3
  5. pydantic_ai/_output.py +8 -0
  6. pydantic_ai/_tool_manager.py +3 -0
  7. pydantic_ai/_utils.py +7 -1
  8. pydantic_ai/ag_ui.py +25 -14
  9. pydantic_ai/{agent.py → agent/__init__.py} +217 -1026
  10. pydantic_ai/agent/abstract.py +942 -0
  11. pydantic_ai/agent/wrapper.py +227 -0
  12. pydantic_ai/builtin_tools.py +105 -0
  13. pydantic_ai/direct.py +9 -9
  14. pydantic_ai/durable_exec/__init__.py +0 -0
  15. pydantic_ai/durable_exec/temporal/__init__.py +83 -0
  16. pydantic_ai/durable_exec/temporal/_agent.py +699 -0
  17. pydantic_ai/durable_exec/temporal/_function_toolset.py +92 -0
  18. pydantic_ai/durable_exec/temporal/_logfire.py +48 -0
  19. pydantic_ai/durable_exec/temporal/_mcp_server.py +145 -0
  20. pydantic_ai/durable_exec/temporal/_model.py +168 -0
  21. pydantic_ai/durable_exec/temporal/_run_context.py +50 -0
  22. pydantic_ai/durable_exec/temporal/_toolset.py +77 -0
  23. pydantic_ai/ext/aci.py +10 -9
  24. pydantic_ai/ext/langchain.py +4 -2
  25. pydantic_ai/mcp.py +203 -75
  26. pydantic_ai/messages.py +75 -13
  27. pydantic_ai/models/__init__.py +66 -8
  28. pydantic_ai/models/anthropic.py +135 -18
  29. pydantic_ai/models/bedrock.py +16 -5
  30. pydantic_ai/models/cohere.py +11 -4
  31. pydantic_ai/models/fallback.py +4 -2
  32. pydantic_ai/models/function.py +18 -4
  33. pydantic_ai/models/gemini.py +20 -9
  34. pydantic_ai/models/google.py +53 -15
  35. pydantic_ai/models/groq.py +47 -11
  36. pydantic_ai/models/huggingface.py +26 -11
  37. pydantic_ai/models/instrumented.py +3 -1
  38. pydantic_ai/models/mcp_sampling.py +3 -1
  39. pydantic_ai/models/mistral.py +27 -17
  40. pydantic_ai/models/openai.py +97 -33
  41. pydantic_ai/models/test.py +12 -0
  42. pydantic_ai/models/wrapper.py +6 -2
  43. pydantic_ai/profiles/groq.py +23 -0
  44. pydantic_ai/profiles/openai.py +1 -1
  45. pydantic_ai/providers/google.py +7 -7
  46. pydantic_ai/providers/groq.py +2 -0
  47. pydantic_ai/result.py +21 -55
  48. pydantic_ai/run.py +357 -0
  49. pydantic_ai/tools.py +0 -1
  50. pydantic_ai/toolsets/__init__.py +2 -0
  51. pydantic_ai/toolsets/_dynamic.py +87 -0
  52. pydantic_ai/toolsets/abstract.py +23 -3
  53. pydantic_ai/toolsets/combined.py +19 -4
  54. pydantic_ai/toolsets/deferred.py +10 -2
  55. pydantic_ai/toolsets/function.py +23 -8
  56. pydantic_ai/toolsets/prefixed.py +4 -0
  57. pydantic_ai/toolsets/wrapper.py +14 -1
  58. {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/METADATA +7 -5
  59. pydantic_ai_slim-0.7.0.dist-info/RECORD +115 -0
  60. pydantic_ai_slim-0.6.1.dist-info/RECORD +0 -100
  61. {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/WHEEL +0 -0
  62. {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/entry_points.txt +0 -0
  63. {pydantic_ai_slim-0.6.1.dist-info → pydantic_ai_slim-0.7.0.dist-info}/licenses/LICENSE +0 -0

pydantic_ai/models/gemini.py
@@ -13,13 +13,14 @@ import pydantic
 from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
 from typing_extensions import NotRequired, TypedDict, assert_never, deprecated

-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     FileUrl,
     ModelMessage,
     ModelRequest,
@@ -36,6 +37,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -165,12 +167,13 @@ class GeminiModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         async with self._make_request(
             messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
         ) as http_response:
-            yield await self._process_streamed_response(http_response)
+            yield await self._process_streamed_response(http_response, model_request_parameters)

     @property
     def model_name(self) -> GeminiModelName:
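
Every model's `request_stream` gains an optional `run_context` argument in 0.7.0 (the same change repeats below for GoogleModel, GroqModel, HuggingFaceModel, InstrumentedModel and MCPSamplingModel). A minimal sketch of how a custom `Model` subclass might mirror the new signature; the class name and body are illustrative only, and the top-level `RunContext` import is assumed:

    # Sketch only: mirrors the 0.7.0 streaming signature; not part of this diff.
    from collections.abc import AsyncIterator
    from contextlib import asynccontextmanager
    from typing import Any

    from pydantic_ai import RunContext
    from pydantic_ai.messages import ModelMessage
    from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
    from pydantic_ai.settings import ModelSettings

    class MyStreamingModel(Model):
        @asynccontextmanager
        async def request_stream(
            self,
            messages: list[ModelMessage],
            model_settings: ModelSettings | None,
            model_request_parameters: ModelRequestParameters,
            run_context: RunContext[Any] | None = None,  # new, optional in 0.7.0
        ) -> AsyncIterator[StreamedResponse]:
            # A real implementation would yield a StreamedResponse here; the
            # raise-then-yield pattern below matches MCPSamplingModel in this diff.
            raise NotImplementedError('streaming is out of scope for this sketch')
            yield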
@@ -183,9 +186,7 @@ class GeminiModel(Model):
         return self._system

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
-        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
+        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None

     def _get_tool_config(
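
The separate `function_tools` / `output_tools` iteration above is replaced by a single `tool_defs` mapping on `ModelRequestParameters`; the same substitution appears in the Google, Groq and Hugging Face hunks below. A hedged sketch of the assumed semantics (a helper written here for illustration, not the actual property):

    # Assumption: tool_defs behaves like a name -> ToolDefinition mapping that
    # merges function tools and output tools, matching the code being removed.
    from pydantic_ai.models import ModelRequestParameters
    from pydantic_ai.tools import ToolDefinition

    def merged_tool_defs(params: ModelRequestParameters) -> dict[str, ToolDefinition]:
        return {td.name: td for td in [*params.function_tools, *params.output_tools]}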
@@ -286,7 +287,9 @@ class GeminiModel(Model):
             vendor_details=vendor_details,
         )

-    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         aiter_bytes = http_response.aiter_bytes()
         start_response: _GeminiResponse | None = None
@@ -307,7 +310,12 @@ class GeminiModel(Model):
         if start_response is None:
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

-        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
+        return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
+            _model_name=self._model_name,
+            _content=content,
+            _stream=aiter_bytes,
+        )

     async def _message_to_gemini_content(
         self, messages: list[ModelMessage]
@@ -610,6 +618,9 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
         elif isinstance(item, TextPart):
             if item.content:
                 parts.append(_GeminiTextPart(text=item.content))
+        elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+            # This is currently never returned from gemini
+            pass
         else:
             assert_never(item)
     return _GeminiContent(role='model', parts=parts)
@@ -867,7 +878,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
             metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),

pydantic_ai/models/google.py
@@ -12,9 +12,13 @@ from typing_extensions import assert_never

 from .. import UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
+from .._run_context import RunContext
+from ..builtin_tools import CodeExecutionTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     FileUrl,
     ModelMessage,
     ModelRequest,
@@ -44,22 +48,25 @@ from . import (
 )

 try:
-    from google import genai
+    from google.genai import Client
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        ExecutableCodeDict,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        GoogleSearchDict,
         HttpOptionsDict,
         MediaResolution,
         Part,
         PartDict,
         SafetySettingDict,
         ThinkingConfigDict,
+        ToolCodeExecutionDict,
         ToolConfigDict,
         ToolDict,
         ToolListUnionDict,
@@ -130,10 +137,10 @@ class GoogleModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    client: genai.Client = field(repr=False)
+    client: Client = field(repr=False)

     _model_name: GoogleModelName = field(repr=False)
-    _provider: Provider[genai.Client] = field(repr=False)
+    _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='google', repr=False)

@@ -141,7 +148,7 @@ class GoogleModel(Model):
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -187,11 +194,12 @@ class GoogleModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         model_settings = cast(GoogleModelSettings, model_settings or {})
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
-        yield await self._process_streamed_response(response)  # type: ignore
+        yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

     @property
     def model_name(self) -> GoogleModelName:
@@ -206,13 +214,17 @@ class GoogleModel(Model):
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-            for t in model_request_parameters.function_tools
+            for t in model_request_parameters.tool_defs.values()
         ]
-        if model_request_parameters.output_tools:
-            tools += [
-                ToolDict(function_declarations=[_function_declaration_from_tool(t)])
-                for t in model_request_parameters.output_tools
-            ]
+        for tool in model_request_parameters.builtin_tools:
+            if isinstance(tool, WebSearchTool):
+                tools.append(ToolDict(google_search=GoogleSearchDict()))
+            elif isinstance(tool, CodeExecutionTool):  # pragma: no branch
+                tools.append(ToolDict(code_execution=ToolCodeExecutionDict()))
+            else:  # pragma: no cover
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GoogleModel`. If it should be, please file an issue.'
+                )
         return tools or None

     def _get_tool_config(
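
With the hunk above, `GoogleModel` maps the new built-in tools from `pydantic_ai/builtin_tools.py` onto Gemini's native tools: `WebSearchTool` becomes `google_search` and `CodeExecutionTool` becomes `code_execution`. A hedged usage sketch, assuming built-in tools are passed to the agent via a `builtin_tools` argument:

    # Sketch only: the Agent `builtin_tools` parameter is assumed from this
    # release's new builtin_tools module; requires Gemini API credentials to run.
    from pydantic_ai import Agent
    from pydantic_ai.builtin_tools import CodeExecutionTool, WebSearchTool
    from pydantic_ai.models.google import GoogleModel

    agent = Agent(
        GoogleModel('gemini-2.5-flash'),
        builtin_tools=[WebSearchTool(), CodeExecutionTool()],
    )
    # result = agent.run_sync('Compute 2**64 with code and report the number of digits.')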
@@ -325,7 +337,9 @@ class GoogleModel(Model):
             parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
         )

-    async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -333,6 +347,7 @@ class GoogleModel(Model):
             raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')  # pragma: no cover

         return GeminiStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _response=peekable_response,
             _timestamp=first_chunk.create_time or _utils.now_utc(),
@@ -499,6 +514,14 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
             # please open an issue. The below code is the code to send thinking to the provider.
             # parts.append({'text': item.content, 'thought': True})
             pass
+        elif isinstance(item, BuiltinToolCallPart):
+            if item.provider_name == 'google':
+                if item.tool_name == 'code_execution':  # pragma: no branch
+                    parts.append({'executable_code': cast(ExecutableCodeDict, item.args)})
+        elif isinstance(item, BuiltinToolReturnPart):
+            if item.provider_name == 'google':
+                if item.tool_name == 'code_execution':  # pragma: no branch
+                    parts.append({'code_execution_result': item.content})
         else:
             assert_never(item)
     return ContentDict(role='model', parts=parts)
@@ -513,7 +536,22 @@ def _process_response_from_parts(
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
-        if part.text is not None:
+        if part.executable_code is not None:
+            items.append(
+                BuiltinToolCallPart(
+                    provider_name='google', args=part.executable_code.model_dump(), tool_name='code_execution'
+                )
+            )
+        elif part.code_execution_result is not None:
+            items.append(
+                BuiltinToolReturnPart(
+                    provider_name='google',
+                    tool_name='code_execution',
+                    content=part.code_execution_result,
+                    tool_call_id='not_provided',
+                )
+            )
+        elif part.text is not None:
             if part.thought:
                 items.append(ThinkingPart(content=part.text))
             else:
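
Together with the `_content_model_response` hunk above, Gemini code-execution activity now round-trips through the message history as `BuiltinToolCallPart` / `BuiltinToolReturnPart` pairs with `provider_name='google'`. A hedged sketch of inspecting those parts after a run (assumes a `result` from an agent run such as the one sketched earlier):

    # Sketch only: walks run messages looking for built-in tool activity.
    from pydantic_ai.messages import BuiltinToolCallPart, BuiltinToolReturnPart, ModelResponse

    for message in result.all_messages():  # `result` is an AgentRunResult from a prior run
        if isinstance(message, ModelResponse):
            for part in message.parts:
                if isinstance(part, BuiltinToolCallPart):
                    print('builtin call:', part.tool_name, part.args)
                elif isinstance(part, BuiltinToolReturnPart):
                    print('builtin return:', part.tool_name, part.content)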
@@ -563,13 +601,13 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
         details['thoughts_tokens'] = thoughts_token_count

     if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
-        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count

     for key, metadata_details in metadata.items():
         if key.endswith('_details') and metadata_details:
             suffix = key.removesuffix('_details')
             for detail in metadata_details:
-                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail.get('token_count', 0)

     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),

pydantic_ai/models/groq.py
@@ -5,16 +5,20 @@ from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
-from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
+from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
+from ..builtin_tools import WebSearchTool
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -31,6 +35,7 @@ from ..messages import (
     UserPromptPart,
 )
 from ..profiles import ModelProfile, ModelProfileSpec
+from ..profiles.groq import GroqModelProfile
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
@@ -166,13 +171,14 @@ class GroqModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters
         )
         async with response:
-            yield await self._process_streamed_response(response)
+            yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> GroqModelName:
@@ -212,7 +218,7 @@ class GroqModel(Model):
         model_request_parameters: ModelRequestParameters,
     ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
         tools = self._get_tools(model_request_parameters)
-        # standalone function to make it easier to override
+        tools += self._get_builtin_tools(model_request_parameters)
         if not tools:
             tool_choice: Literal['none', 'required', 'auto'] | None = None
         elif not model_request_parameters.allow_text_output:
@@ -226,7 +232,7 @@ class GroqModel(Model):
         extra_headers = model_settings.get('extra_headers', {})
         extra_headers.setdefault('User-Agent', get_user_agent())
         return await self.client.chat.completions.create(
-            model=str(self._model_name),
+            model=self._model_name,
             messages=groq_messages,
             n=1,
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
@@ -256,6 +262,19 @@ class GroqModel(Model):
         timestamp = number_to_datetime(response.created)
         choice = response.choices[0]
         items: list[ModelResponsePart] = []
+        if choice.message.executed_tools:
+            for tool in choice.message.executed_tools:
+                tool_call_id = generate_tool_call_id()
+                items.append(
+                    BuiltinToolCallPart(
+                        tool_name=tool.type, args=tool.arguments, provider_name='groq', tool_call_id=tool_call_id
+                    )
+                )
+                items.append(
+                    BuiltinToolReturnPart(
+                        provider_name='groq', tool_name=tool.type, content=tool.output, tool_call_id=tool_call_id
+                    )
+                )
         # NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
         if choice.message.reasoning is not None:
             items.append(ThinkingPart(content=choice.message.reasoning))
@@ -269,7 +288,9 @@ class GroqModel(Model):
             items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
         )

-    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> GroqStreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -279,6 +300,7 @@ class GroqModel(Model):
             )

         return GroqStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _response=peekable_response,
             _model_name=self._model_name,
             _model_profile=self.profile,
@@ -286,9 +308,20 @@ class GroqModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
+
+    def _get_builtin_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[chat.ChatCompletionToolParam]:
+        tools: list[chat.ChatCompletionToolParam] = []
+        for tool in model_request_parameters.builtin_tools:
+            if isinstance(tool, WebSearchTool):
+                if not GroqModelProfile.from_profile(self.profile).groq_always_has_web_search_builtin_tool:
+                    raise UserError('`WebSearchTool` is not supported by Groq')  # pragma: no cover
+            else:
+                raise UserError(
+                    f'`{tool.__class__.__name__}` is not supported by `GroqModel`. If it should be, please file an issue.'
+                )
         return tools

     def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
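
Per the hunk above, Groq only accepts `WebSearchTool` when the model's `GroqModelProfile` sets `groq_always_has_web_search_builtin_tool` (added in `pydantic_ai/profiles/groq.py` in this release); any other built-in tool raises `UserError` at request time. A hedged sketch of the resulting behaviour; the model ID is an assumption, check `profiles/groq.py` for which models set the flag:

    # Sketch only: construction succeeds, but the request will raise UserError
    # unless the chosen model's profile declares the built-in web search tool.
    from pydantic_ai import Agent
    from pydantic_ai.builtin_tools import WebSearchTool
    from pydantic_ai.models.groq import GroqModel

    agent = Agent(GroqModel('compound-beta'), builtin_tools=[WebSearchTool()])
    # result = agent.run_sync('What changed in the latest pydantic release?')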
@@ -308,6 +341,9 @@ class GroqModel(Model):
                     elif isinstance(item, ThinkingPart):
                         # Skip thinking parts when mapping to Groq messages
                         continue
+                    elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                        # This is currently never returned from groq
+                        pass
                     else:
                         assert_never(item)
                 message_param = chat.ChatCompletionAssistantMessageParam(role='assistant')

pydantic_ai/models/huggingface.py
@@ -5,18 +5,20 @@ from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

-from pydantic_ai._thinking_part import split_content_into_text_and_thinking
-from pydantic_ai.providers import Provider, infer_provider
-
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._run_context import RunContext
+from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
     DocumentUrl,
     ImageUrl,
     ModelMessage,
@@ -34,9 +36,15 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfile
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    check_allow_model_requests,
+)

 try:
     import aiohttp
@@ -147,12 +155,13 @@ class HuggingFaceModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         check_allow_model_requests()
         response = await self._completions_create(
             messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
         )
-        yield await self._process_streamed_response(response)
+        yield await self._process_streamed_response(response, model_request_parameters)

     @property
     def model_name(self) -> HuggingFaceModelName:
@@ -198,6 +207,9 @@ class HuggingFaceModel(Model):
         else:
             tool_choice = 'auto'

+        if model_request_parameters.builtin_tools:
+            raise UserError('HuggingFace does not support built-in tools')
+
         hf_messages = await self._map_messages(messages)

         try:
@@ -257,7 +269,9 @@ class HuggingFaceModel(Model):
             vendor_id=response.id,
         )

-    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+    async def _process_streamed_response(
+        self, response: AsyncIterable[ChatCompletionStreamOutput], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
         """Process a streamed response, and prepare a streaming response to return."""
         peekable_response = _utils.PeekableAsyncStream(response)
         first_chunk = await peekable_response.peek()
@@ -267,6 +281,7 @@ class HuggingFaceModel(Model):
             )

         return HuggingFaceStreamedResponse(
+            model_request_parameters=model_request_parameters,
             _model_name=self._model_name,
             _model_profile=self.profile,
             _response=peekable_response,
@@ -274,10 +289,7 @@ class HuggingFaceModel(Model):
         )

     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
-        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
-        if model_request_parameters.output_tools:
-            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
-        return tools
+        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

     async def _map_messages(
         self, messages: list[ModelMessage]
@@ -301,6 +313,9 @@ class HuggingFaceModel(Model):
                         # please open an issue. The below code is the code to send thinking to the provider.
                         # texts.append(f'<think>\n{item.content}\n</think>')
                         pass
+                    elif isinstance(item, (BuiltinToolCallPart, BuiltinToolReturnPart)):  # pragma: no cover
+                        # This is currently never returned from huggingface
+                        pass
                     else:
                         assert_never(item)
                 message_param = ChatCompletionInputMessage(role='assistant')  # type: ignore

pydantic_ai/models/instrumented.py
@@ -18,6 +18,7 @@ from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
 from opentelemetry.util.types import AttributeValue
 from pydantic import TypeAdapter

+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelRequest, ModelResponse
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
@@ -218,12 +219,13 @@ class InstrumentedModel(WrapperModel):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         with self._instrument(messages, model_settings, model_request_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
                 async with super().request_stream(
-                    messages, model_settings, model_request_parameters
+                    messages, model_settings, model_request_parameters, run_context
                 ) as response_stream:
                     yield response_stream
             finally:

pydantic_ai/models/mcp_sampling.py
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast

 from .. import _mcp, exceptions, usage
+from .._run_context import RunContext
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -76,6 +77,7 @@ class MCPSamplingModel(Model):
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         raise NotImplementedError('MCP Sampling does not support streaming')
         yield