pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +219 -315
- pydantic_ai/_cli.py +9 -7
- pydantic_ai/_output.py +296 -226
- pydantic_ai/_parts_manager.py +2 -2
- pydantic_ai/_run_context.py +8 -14
- pydantic_ai/_tool_manager.py +190 -0
- pydantic_ai/_utils.py +18 -1
- pydantic_ai/ag_ui.py +675 -0
- pydantic_ai/agent.py +369 -155
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +14 -2
- pydantic_ai/ext/aci.py +12 -3
- pydantic_ai/ext/langchain.py +9 -1
- pydantic_ai/mcp.py +147 -84
- pydantic_ai/messages.py +19 -9
- pydantic_ai/models/__init__.py +43 -19
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/function.py +50 -24
- pydantic_ai/models/gemini.py +3 -11
- pydantic_ai/models/google.py +3 -12
- pydantic_ai/models/groq.py +2 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +3 -3
- pydantic_ai/models/openai.py +5 -5
- pydantic_ai/output.py +21 -7
- pydantic_ai/profiles/google.py +1 -1
- pydantic_ai/profiles/moonshotai.py +8 -0
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/grok.py +13 -1
- pydantic_ai/providers/groq.py +2 -0
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +57 -33
- pydantic_ai/tools.py +26 -119
- pydantic_ai/toolsets/__init__.py +22 -0
- pydantic_ai/toolsets/abstract.py +155 -0
- pydantic_ai/toolsets/combined.py +88 -0
- pydantic_ai/toolsets/deferred.py +38 -0
- pydantic_ai/toolsets/filtered.py +24 -0
- pydantic_ai/toolsets/function.py +238 -0
- pydantic_ai/toolsets/prefixed.py +37 -0
- pydantic_ai/toolsets/prepared.py +36 -0
- pydantic_ai/toolsets/renamed.py +42 -0
- pydantic_ai/toolsets/wrapper.py +37 -0
- pydantic_ai/usage.py +14 -8
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/METADATA +13 -8
- pydantic_ai_slim-0.4.4.dist-info/RECORD +98 -0
- pydantic_ai_slim-0.4.2.dist-info/RECORD +0 -83
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.4.dist-info}/licenses/LICENSE +0 -0
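Beyond the Hugging Face integration, the largest structural addition in this release is the new `pydantic_ai/toolsets/` package (abstract, combined, deferred, filtered, function, prefixed, prepared, renamed, wrapper). A rough sketch of how the function-based toolset is expected to be used, inferred from the new module names and the pydantic-ai docs rather than from this diff, so treat the exact signatures as assumptions:

from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset

toolset = FunctionToolset()


@toolset.tool
def temperature_celsius(city: str) -> float:
    # Stub implementation for the sketch.
    return 21.0


# Toolsets are attached to the agent as a group and can be combined, filtered,
# prefixed, or renamed via the other new toolset wrappers.
agent = Agent('openai:gpt-4o', toolsets=[toolset])
result = agent.run_sync('What is the temperature in Paris?')
print(result.output)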
pydantic_ai/models/huggingface.py
ADDED
@@ -0,0 +1,463 @@
from __future__ import annotations as _annotations

import base64
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Union, cast, overload

from typing_extensions import assert_never

from pydantic_ai._thinking_part import split_content_into_text_and_thinking
from pydantic_ai.providers import Provider, infer_provider

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
from ..messages import (
    AudioUrl,
    BinaryContent,
    DocumentUrl,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
    VideoUrl,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests

try:
    import aiohttp
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputMessage,
        ChatCompletionInputMessageChunk,
        ChatCompletionInputTool,
        ChatCompletionInputToolCall,
        ChatCompletionInputURL,
        ChatCompletionOutput,
        ChatCompletionOutputMessage,
        ChatCompletionStreamOutput,
    )
    from huggingface_hub.errors import HfHubHTTPError

except ImportError as _import_error:
    raise ImportError(
        'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
        'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
    ) from _import_error

__all__ = (
    'HuggingFaceModel',
    'HuggingFaceModelSettings',
)


HFSystemPromptRole = Literal['system', 'user']

LatestHuggingFaceModelNames = Literal[
    'deepseek-ai/DeepSeek-R1',
    'meta-llama/Llama-3.3-70B-Instruct',
    'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
    'meta-llama/Llama-4-Scout-17B-16E-Instruct',
    'Qwen/QwQ-32B',
    'Qwen/Qwen2.5-72B-Instruct',
    'Qwen/Qwen3-235B-A22B',
    'Qwen/Qwen3-32B',
]
"""Latest Hugging Face models."""


HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
"""Possible Hugging Face model names.

You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
"""


class HuggingFaceModelSettings(ModelSettings, total=False):
    """Settings used for a Hugging Face model request."""

    # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    # This class is a placeholder for any future huggingface-specific settings


@dataclass(init=False)
class HuggingFaceModel(Model):
    """A model that uses Hugging Face Inference Providers.

    Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncInferenceClient = field(repr=False)

    _model_name: str = field(repr=False)
    _system: str = field(default='huggingface', repr=False)

    def __init__(
        self,
        model_name: str,
        *,
        provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
    ):
        """Initialize a Hugging Face model.

        Args:
            model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
            provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
                instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
        """
        self._model_name = model_name
        self._provider = provider
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
        )
        model_response = self._process_response(response)
        model_response.usage.requests = 1
        return model_response

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        check_allow_model_requests()
        response = await self._completions_create(
            messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
        )
        yield await self._process_streamed_response(response)

    @property
    def model_name(self) -> HuggingFaceModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: HuggingFaceModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...

    @overload
    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: HuggingFaceModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> ChatCompletionOutput: ...

    async def _completions_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: HuggingFaceModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]:
        tools = self._get_tools(model_request_parameters)

        if not tools:
            tool_choice: Literal['none', 'required', 'auto'] | None = None
        elif not model_request_parameters.allow_text_output:
            tool_choice = 'required'
        else:
            tool_choice = 'auto'

        hf_messages = await self._map_messages(messages)

        try:
            return await self.client.chat.completions.create(  # type: ignore
                model=self._model_name,
                messages=hf_messages,  # type: ignore
                tools=tools,
                tool_choice=tool_choice or None,
                stream=stream,
                stop=model_settings.get('stop_sequences', None),
                temperature=model_settings.get('temperature', None),
                top_p=model_settings.get('top_p', None),
                seed=model_settings.get('seed', None),
                presence_penalty=model_settings.get('presence_penalty', None),
                frequency_penalty=model_settings.get('frequency_penalty', None),
                logit_bias=model_settings.get('logit_bias', None),  # type: ignore
                logprobs=model_settings.get('logprobs', None),
                top_logprobs=model_settings.get('top_logprobs', None),
                extra_body=model_settings.get('extra_body'),  # type: ignore
            )
        except aiohttp.ClientResponseError as e:
            raise ModelHTTPError(
                status_code=e.status,
                model_name=self.model_name,
                body=e.response_error_payload,  # type: ignore
            ) from e
        except HfHubHTTPError as e:
            raise ModelHTTPError(
                status_code=e.response.status_code,
                model_name=self.model_name,
                body=e.response.content,
            ) from e

    def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        if response.created:
            timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        else:
            timestamp = _now_utc()

        choice = response.choices[0]
        content = choice.message.content
        tool_calls = choice.message.tool_calls

        items: list[ModelResponsePart] = []

        if content is not None:
            items.extend(split_content_into_text_and_thinking(content))
        if tool_calls is not None:
            for c in tool_calls:
                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
        return ModelResponse(
            items,
            usage=_map_usage(response),
            model_name=response.model,
            timestamp=timestamp,
            vendor_id=response.id,
        )

    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior(  # pragma: no cover
                'Streamed response ended without content or tool calls'
            )

        return HuggingFaceStreamedResponse(
            _model_name=self._model_name,
            _response=peekable_response,
            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.output_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
        return tools

    async def _map_messages(
        self, messages: list[ModelMessage]
    ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
        """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
        hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                async for item in self._map_user_message(message):
                    hf_messages.append(item)
            elif isinstance(message, ModelResponse):
                texts: list[str] = []
                tool_calls: list[ChatCompletionInputToolCall] = []
                for item in message.parts:
                    if isinstance(item, TextPart):
                        texts.append(item.content)
                    elif isinstance(item, ToolCallPart):
                        tool_calls.append(self._map_tool_call(item))
                    elif isinstance(item, ThinkingPart):
                        # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
                        # please open an issue. The below code is the code to send thinking to the provider.
                        # texts.append(f'<think>\n{item.content}\n</think>')
                        pass
                    else:
                        assert_never(item)
                message_param = ChatCompletionInputMessage(role='assistant')  # type: ignore
                if texts:
                    # Note: model responses from this model should only have one text item, so the following
                    # shouldn't merge multiple texts into one unless you switch models between runs:
                    message_param['content'] = '\n\n'.join(texts)
                if tool_calls:
                    message_param['tool_calls'] = tool_calls
                hf_messages.append(message_param)
            else:
                assert_never(message)
        if instructions := self._get_instructions(messages):
            hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
        return hf_messages

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall:
        return ChatCompletionInputToolCall.parse_obj_as_instance(  # type: ignore
            {
                'id': _guard_tool_call_id(t=t),
                'type': 'function',
                'function': {
                    'name': t.tool_name,
                    'arguments': t.args_as_json_str(),
                },
            }
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
        tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance(  # type: ignore
            {
                'type': 'function',
                'function': {
                    'name': f.name,
                    'description': f.description,
                    'parameters': f.parameters_json_schema,
                },
            }
        )
        if f.strict is not None:
            tool_param['function']['strict'] = f.strict
        return tool_param

    async def _map_user_message(
        self, message: ModelRequest
    ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content})  # type: ignore
            elif isinstance(part, UserPromptPart):
                yield await self._map_user_prompt(part)
            elif isinstance(part, ToolReturnPart):
                yield ChatCompletionOutputMessage.parse_obj_as_instance(  # type: ignore
                    {
                        'role': 'tool',
                        'tool_call_id': _guard_tool_call_id(t=part),
                        'content': part.model_response_str(),
                    }
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
                        {'role': 'user', 'content': part.model_response()}
                    )
                else:
                    yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
                        {
                            'role': 'tool',
                            'tool_call_id': _guard_tool_call_id(t=part),
                            'content': part.model_response(),
                        }
                    )
            else:
                assert_never(part)

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
        content: str | list[ChatCompletionInputMessage]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(ChatCompletionInputMessageChunk(type='text', text=item))  # type: ignore
                elif isinstance(item, ImageUrl):
                    url = ChatCompletionInputURL(url=item.url)  # type: ignore
                    content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    if item.is_image:
                        url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}')  # type: ignore
                        content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
                    else:  # pragma: no cover
                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
                elif isinstance(item, AudioUrl):
                    raise NotImplementedError('AudioUrl is not supported for Hugging Face')
                elif isinstance(item, DocumentUrl):
                    raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
                elif isinstance(item, VideoUrl):
                    raise NotImplementedError('VideoUrl is not supported for Hugging Face')
                else:
                    assert_never(item)
        return ChatCompletionInputMessage(role='user', content=content)  # type: ignore


@dataclass
class HuggingFaceStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Hugging Face models."""

    _model_name: str
    _response: AsyncIterable[ChatCompletionStreamOutput]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            self._usage += _map_usage(chunk)

            try:
                choice = chunk.choices[0]
            except IndexError:
                continue

            # Handle the text part of the response
            content = choice.delta.content
            if content is not None:
                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)

            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
                    tool_name=dtc.function and dtc.function.name,  # type: ignore
                    args=dtc.function and dtc.function.arguments,
                    tool_call_id=dtc.id,
                )
                if maybe_event is not None:
                    yield maybe_event

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp


def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
    response_usage = response.usage
    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
        details=None,
    )
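The module above is the core of the new Hugging Face support; together with the `huggingface` entries added to `pydantic_ai/providers/`, it plugs into an agent like any other model. A minimal sketch of that wiring (not part of this diff; the model name and the `HF_TOKEN` environment variable are illustrative assumptions):

from pydantic_ai import Agent
from pydantic_ai.models.huggingface import HuggingFaceModel

# The default 'huggingface' provider is resolved through infer_provider('huggingface')
# and is expected to pick up a Hugging Face token from the environment (e.g. HF_TOKEN).
model = HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct')
agent = Agent(model, instructions='Answer in one short sentence.')

result = agent.run_sync('Which Qwen models are available on Hugging Face Inference Providers?')
print(result.output)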
pydantic_ai/models/instrumented.py
CHANGED
@@ -138,7 +138,7 @@ class InstrumentationSettings:
                **tokens_histogram_kwargs,
                explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
            )
-       except TypeError:
+       except TypeError:
            # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
            self.tokens_histogram = self.meter.create_histogram(
                **tokens_histogram_kwargs,  # pyright: ignore
pydantic_ai/models/mistral.py
CHANGED
@@ -75,7 +75,7 @@ try:
    from mistralai.models.usermessage import UserMessage as MistralUserMessage
    from mistralai.types.basemodel import Unset as MistralUnset
    from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
-except ImportError as e:  # pragma:
+except ImportError as e:  # pragma: no cover
    raise ImportError(
        'Please install `mistral` to use the Mistral model, '
        'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -217,7 +217,7 @@ class MistralModel(Model):
        except SDKError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-           raise  # pragma:
+           raise  # pragma: no cover

        assert response, 'A unexpected empty response from Mistral.'
        return response
@@ -428,7 +428,7 @@ class MistralModel(Model):
        if value_type == 'object':
            additional_properties = value.get('additionalProperties', {})
            if isinstance(additional_properties, bool):
-               return 'bool'  # pragma: no cover
+               return 'bool'  # pragma: lax no cover
            additional_properties_type = additional_properties.get('type')
            if (
                additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
pydantic_ai/models/openai.py
CHANGED
@@ -50,7 +50,7 @@ from . import (

try:
    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
-   from openai.types import
+   from openai.types import AllModels, chat, responses
    from openai.types.chat import (
        ChatCompletionChunk,
        ChatCompletionContentPartImageParam,
@@ -80,7 +80,7 @@ __all__ = (
    'OpenAIModelName',
)

-OpenAIModelName = Union[str,
+OpenAIModelName = Union[str, AllModels]
"""
Possible OpenAI model names.

@@ -345,7 +345,7 @@ class OpenAIModel(Model):
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-           raise  # pragma:
+           raise  # pragma: no cover

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
@@ -781,7 +781,7 @@ class OpenAIResponsesModel(Model):
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-           raise  # pragma:
+           raise  # pragma: no cover

    def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
        reasoning_effort = model_settings.get('openai_reasoning_effort', None)
@@ -1051,7 +1051,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                    vendor_part_id=chunk.item_id,
                    tool_name=None,
                    args=chunk.delta,
-                   tool_call_id=
+                   tool_call_id=None,
                )
                if maybe_event is not None:  # pragma: no branch
                    yield maybe_event
pydantic_ai/output.py
CHANGED
@@ -10,7 +10,8 @@ from pydantic_core import core_schema
from typing_extensions import TypeAliasType, TypeVar

from . import _utils
-from .
+from .messages import ToolCallPart
+from .tools import RunContext, ToolDefinition

__all__ = (
    # classes
@@ -330,15 +331,17 @@ def StructuredDict(
    return _StructuredDict


+_OutputSpecItem = TypeAliasType(
+    '_OutputSpecItem',
+    Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]],
+    type_params=(T_co,),
+)
+
OutputSpec = TypeAliasType(
    'OutputSpec',
    Union[
-       OutputTypeOrFunction[T_co],
-       ToolOutput[T_co],
-       NativeOutput[T_co],
-       PromptedOutput[T_co],
-       TextOutput[T_co],
-       Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]],
+       _OutputSpecItem[T_co],
+       Sequence['OutputSpec[T_co]'],
    ],
    type_params=(T_co,),
)
@@ -354,3 +357,14 @@ You should not need to import or use this type directly.

See [output docs](../output.md) for more information.
"""
+
+
+@dataclass
+class DeferredToolCalls:
+    """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.
+
+    See [deferred toolset docs](../toolsets.md#deferred-toolset) for more information.
+    """
+
+    tool_calls: list[ToolCallPart]
+    tool_defs: dict[str, ToolDefinition]
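Two changes land in `output.py`: `OutputSpec` becomes recursive, so a sequence of output specs may itself contain sequences, and `DeferredToolCalls` is added as a possible run output when the model calls tools from the new deferred toolsets. A hedged sketch of what the recursive spec allows (the model string and example types are assumptions, and the nested list is presumably flattened by the agent's output machinery rather than by this type alias):

from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):
    city: str
    country: str


class Refusal(BaseModel):
    reason: str


# An existing spec (itself a sequence) can now be nested inside another output_type
# sequence instead of being flattened by hand first.
structured = [CityLocation, Refusal]
agent = Agent('openai:gpt-4o', output_type=[str, structured])

result = agent.run_sync('Where were the 2012 Olympics held?')
print(result.output)

`DeferredToolCalls` slots into the same `output_type` mechanism: when it is included alongside a deferred toolset and the model requests one of those tools, the run ends early and `result.output` carries the pending `tool_calls` plus their matching `tool_defs`.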
pydantic_ai/profiles/google.py
CHANGED
@@ -43,7 +43,7 @@ class GoogleJsonSchemaTransformer(JsonSchemaTransformer):
            f' Full schema: {self.schema}\n\n'
            f'Source of additionalProperties within the full schema: {original_schema}\n\n'
            'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
-           "If Google's APIs are updated to support this properly, please create an issue on the
+           "If Google's APIs are updated to support this properly, please create an issue on the Pydantic AI GitHub"
            ' and we will fix this behavior.',
            UserWarning,
        )
pydantic_ai/providers/__init__.py
CHANGED
@@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
        from .heroku import HerokuProvider

        return HerokuProvider
+   elif provider == 'huggingface':
+       from .huggingface import HuggingFaceProvider
+
+       return HuggingFaceProvider
    elif provider == 'github':
        from .github import GitHubProvider

pydantic_ai/providers/google.py
CHANGED
@@ -86,7 +86,7 @@ class GoogleProvider(Provider[genai.Client]):
        # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
        api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')

-       if vertexai is None:
+       if vertexai is None:
            vertexai = bool(location or project or credentials)

        if not vertexai:
@@ -114,7 +114,7 @@ class GoogleProvider(Provider[genai.Client]):
                http_options={'headers': {'User-Agent': get_user_agent()}},
            )
        else:
-           self._client = client
+           self._client = client


VertexAILocation = Literal[
pydantic_ai/providers/google_vertex.py
CHANGED
@@ -50,7 +50,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
-       return google_model_profile(model_name)
+       return google_model_profile(model_name)

    @overload
    def __init__(
@@ -116,6 +116,8 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
class _VertexAIAuth(httpx.Auth):
    """Auth class for Vertex AI API."""

+   _refresh_lock: anyio.Lock = anyio.Lock()
+
    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
@@ -169,10 +171,13 @@ class _VertexAIAuth(httpx.Auth):
        return creds

    async def _refresh_token(self) -> str:  # pragma: no cover
-
-
-
-
+       async with self._refresh_lock:
+           assert self.credentials is not None
+           await anyio.to_thread.run_sync(self.credentials.refresh, Request())  # type: ignore[reportUnknownMemberType]
+           assert isinstance(self.credentials.token, str), (  # type: ignore[reportUnknownMemberType]
+               f'Expected token to be a string, got {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
+           )
+           return self.credentials.token


async def _async_google_auth() -> tuple[BaseCredentials, str | None]: