pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +5 -2
- pydantic_ai/_output.py +120 -14
- pydantic_ai/agent.py +1 -0
- pydantic_ai/common_tools/duckduckgo.py +5 -2
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/messages.py +6 -4
- pydantic_ai/models/__init__.py +13 -1
- pydantic_ai/models/anthropic.py +1 -1
- pydantic_ai/models/bedrock.py +1 -1
- pydantic_ai/models/cohere.py +1 -1
- pydantic_ai/models/gemini.py +2 -2
- pydantic_ai/models/google.py +1 -1
- pydantic_ai/models/groq.py +1 -1
- pydantic_ai/models/huggingface.py +463 -0
- pydantic_ai/models/instrumented.py +1 -1
- pydantic_ai/models/mistral.py +2 -2
- pydantic_ai/models/openai.py +2 -2
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/google.py +2 -2
- pydantic_ai/providers/google_vertex.py +10 -5
- pydantic_ai/providers/huggingface.py +88 -0
- pydantic_ai/result.py +16 -5
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.3.dist-info}/METADATA +6 -4
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.3.dist-info}/RECORD +27 -25
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.3.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.3.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-0.4.2.dist-info → pydantic_ai_slim-0.4.3.dist-info}/licenses/LICENSE +0 -0
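The headline change in 0.4.3 is first-class Hugging Face support: a new `HuggingFaceModel` in `pydantic_ai/models/huggingface.py` and a new `HuggingFaceProvider` in `pydantic_ai/providers/huggingface.py`, both shown in full below. The following is a minimal usage sketch, not taken from the diff: it assumes the usual `pydantic_ai` `Agent` API, that an API key is available, and the model name, prompt, and token are purely illustrative.

from pydantic_ai import Agent
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider

# provider='huggingface' (the default) reads HF_TOKEN from the environment;
# passing a HuggingFaceProvider lets you supply the key explicitly.
model = HuggingFaceModel(
    'Qwen/Qwen2.5-72B-Instruct',
    provider=HuggingFaceProvider(api_key='hf_...'),  # placeholder token
)
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.output)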
pydantic_ai/models/huggingface.py
ADDED

@@ -0,0 +1,463 @@
+from __future__ import annotations as _annotations
+
+import base64
+from collections.abc import AsyncIterable, AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Literal, Union, cast, overload
+
+from typing_extensions import assert_never
+
+from pydantic_ai._thinking_part import split_content_into_text_and_thinking
+from pydantic_ai.providers import Provider, infer_provider
+
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
+from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
+from ..messages import (
+    AudioUrl,
+    BinaryContent,
+    DocumentUrl,
+    ImageUrl,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    ModelResponseStreamEvent,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ThinkingPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+    VideoUrl,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
+
+try:
+    import aiohttp
+    from huggingface_hub import (
+        AsyncInferenceClient,
+        ChatCompletionInputMessage,
+        ChatCompletionInputMessageChunk,
+        ChatCompletionInputTool,
+        ChatCompletionInputToolCall,
+        ChatCompletionInputURL,
+        ChatCompletionOutput,
+        ChatCompletionOutputMessage,
+        ChatCompletionStreamOutput,
+    )
+    from huggingface_hub.errors import HfHubHTTPError
+
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
+        'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
+    ) from _import_error
+
+__all__ = (
+    'HuggingFaceModel',
+    'HuggingFaceModelSettings',
+)
+
+
+HFSystemPromptRole = Literal['system', 'user']
+
+LatestHuggingFaceModelNames = Literal[
+    'deepseek-ai/DeepSeek-R1',
+    'meta-llama/Llama-3.3-70B-Instruct',
+    'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
+    'meta-llama/Llama-4-Scout-17B-16E-Instruct',
+    'Qwen/QwQ-32B',
+    'Qwen/Qwen2.5-72B-Instruct',
+    'Qwen/Qwen3-235B-A22B',
+    'Qwen/Qwen3-32B',
+]
+"""Latest Hugging Face models."""
+
+
+HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
+"""Possible Hugging Face model names.
+
+You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+"""
+
+
+class HuggingFaceModelSettings(ModelSettings, total=False):
+    """Settings used for a Hugging Face model request."""
+
+    # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    # This class is a placeholder for any future huggingface-specific settings
+
+
+@dataclass(init=False)
+class HuggingFaceModel(Model):
+    """A model that uses Hugging Face Inference Providers.
+
+    Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+    """
+
+    client: AsyncInferenceClient = field(repr=False)
+
+    _model_name: str = field(repr=False)
+    _system: str = field(default='huggingface', repr=False)
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
+    ):
+        """Initialize a Hugging Face model.
+
+        Args:
+            model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
+            provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
+                instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used.
+        """
+        self._model_name = model_name
+        self._provider = provider
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+        )
+        model_response = self._process_response(response)
+        model_response.usage.requests = 1
+        return model_response
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        check_allow_model_requests()
+        response = await self._completions_create(
+            messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
+        )
+        yield await self._process_streamed_response(response)
+
+    @property
+    def model_name(self) -> HuggingFaceModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The system / model provider."""
+        return self._system
+
+    @overload
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
+
+    @overload
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ChatCompletionOutput: ...
+
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: HuggingFaceModelSettings,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]:
+        tools = self._get_tools(model_request_parameters)
+
+        if not tools:
+            tool_choice: Literal['none', 'required', 'auto'] | None = None
+        elif not model_request_parameters.allow_text_output:
+            tool_choice = 'required'
+        else:
+            tool_choice = 'auto'
+
+        hf_messages = await self._map_messages(messages)
+
+        try:
+            return await self.client.chat.completions.create(  # type: ignore
+                model=self._model_name,
+                messages=hf_messages,  # type: ignore
+                tools=tools,
+                tool_choice=tool_choice or None,
+                stream=stream,
+                stop=model_settings.get('stop_sequences', None),
+                temperature=model_settings.get('temperature', None),
+                top_p=model_settings.get('top_p', None),
+                seed=model_settings.get('seed', None),
+                presence_penalty=model_settings.get('presence_penalty', None),
+                frequency_penalty=model_settings.get('frequency_penalty', None),
+                logit_bias=model_settings.get('logit_bias', None),  # type: ignore
+                logprobs=model_settings.get('logprobs', None),
+                top_logprobs=model_settings.get('top_logprobs', None),
+                extra_body=model_settings.get('extra_body'),  # type: ignore
+            )
+        except aiohttp.ClientResponseError as e:
+            raise ModelHTTPError(
+                status_code=e.status,
+                model_name=self.model_name,
+                body=e.response_error_payload,  # type: ignore
+            ) from e
+        except HfHubHTTPError as e:
+            raise ModelHTTPError(
+                status_code=e.response.status_code,
+                model_name=self.model_name,
+                body=e.response.content,
+            ) from e
+
+    def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        if response.created:
+            timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
+        else:
+            timestamp = _now_utc()
+
+        choice = response.choices[0]
+        content = choice.message.content
+        tool_calls = choice.message.tool_calls
+
+        items: list[ModelResponsePart] = []
+
+        if content is not None:
+            items.extend(split_content_into_text_and_thinking(content))
+        if tool_calls is not None:
+            for c in tool_calls:
+                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            vendor_id=response.id,
+        )
+
+    async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
+        """Process a streamed response, and prepare a streaming response to return."""
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior(  # pragma: no cover
+                'Streamed response ended without content or tool calls'
+            )
+
+        return HuggingFaceStreamedResponse(
+            _model_name=self._model_name,
+            _response=peekable_response,
+            _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
+        )
+
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.output_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
+        return tools
+
+    async def _map_messages(
+        self, messages: list[ModelMessage]
+    ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
+        """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
+        hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                async for item in self._map_user_message(message):
+                    hf_messages.append(item)
+            elif isinstance(message, ModelResponse):
+                texts: list[str] = []
+                tool_calls: list[ChatCompletionInputToolCall] = []
+                for item in message.parts:
+                    if isinstance(item, TextPart):
+                        texts.append(item.content)
+                    elif isinstance(item, ToolCallPart):
+                        tool_calls.append(self._map_tool_call(item))
+                    elif isinstance(item, ThinkingPart):
+                        # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
+                        # please open an issue. The below code is the code to send thinking to the provider.
+                        # texts.append(f'<think>\n{item.content}\n</think>')
+                        pass
+                    else:
+                        assert_never(item)
+                message_param = ChatCompletionInputMessage(role='assistant')  # type: ignore
+                if texts:
+                    # Note: model responses from this model should only have one text item, so the following
+                    # shouldn't merge multiple texts into one unless you switch models between runs:
+                    message_param['content'] = '\n\n'.join(texts)
+                if tool_calls:
+                    message_param['tool_calls'] = tool_calls
+                hf_messages.append(message_param)
+            else:
+                assert_never(message)
+        if instructions := self._get_instructions(messages):
+            hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system'))  # type: ignore
+        return hf_messages
+
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall:
+        return ChatCompletionInputToolCall.parse_obj_as_instance(  # type: ignore
+            {
+                'id': _guard_tool_call_id(t=t),
+                'type': 'function',
+                'function': {
+                    'name': t.tool_name,
+                    'arguments': t.args_as_json_str(),
+                },
+            }
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
+        tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance(  # type: ignore
+            {
+                'type': 'function',
+                'function': {
+                    'name': f.name,
+                    'description': f.description,
+                    'parameters': f.parameters_json_schema,
+                },
+            }
+        )
+        if f.strict is not None:
+            tool_param['function']['strict'] = f.strict
+        return tool_param
+
+    async def _map_user_message(
+        self, message: ModelRequest
+    ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content})  # type: ignore
+            elif isinstance(part, UserPromptPart):
+                yield await self._map_user_prompt(part)
+            elif isinstance(part, ToolReturnPart):
+                yield ChatCompletionOutputMessage.parse_obj_as_instance(  # type: ignore
+                    {
+                        'role': 'tool',
+                        'tool_call_id': _guard_tool_call_id(t=part),
+                        'content': part.model_response_str(),
+                    }
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
+                        {'role': 'user', 'content': part.model_response()}
+                    )
+                else:
+                    yield ChatCompletionInputMessage.parse_obj_as_instance(  # type: ignore
+                        {
+                            'role': 'tool',
+                            'tool_call_id': _guard_tool_call_id(t=part),
+                            'content': part.model_response(),
+                        }
+                    )
+            else:
+                assert_never(part)
+
+    @staticmethod
+    async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
+        content: str | list[ChatCompletionInputMessage]
+        if isinstance(part.content, str):
+            content = part.content
+        else:
+            content = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content.append(ChatCompletionInputMessageChunk(type='text', text=item))  # type: ignore
+                elif isinstance(item, ImageUrl):
+                    url = ChatCompletionInputURL(url=item.url)  # type: ignore
+                    content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
+                elif isinstance(item, BinaryContent):
+                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
+                    if item.is_image:
+                        url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}')  # type: ignore
+                        content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url))  # type: ignore
+                    else:  # pragma: no cover
+                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
+                elif isinstance(item, AudioUrl):
+                    raise NotImplementedError('AudioUrl is not supported for Hugging Face')
+                elif isinstance(item, DocumentUrl):
+                    raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
+                elif isinstance(item, VideoUrl):
+                    raise NotImplementedError('VideoUrl is not supported for Hugging Face')
+                else:
+                    assert_never(item)
+        return ChatCompletionInputMessage(role='user', content=content)  # type: ignore
+
+
+@dataclass
+class HuggingFaceStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Hugging Face models."""
+
+    _model_name: str
+    _response: AsyncIterable[ChatCompletionStreamOutput]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk)
+
+            try:
+                choice = chunk.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content is not None:
+                yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
+
+            for dtc in choice.delta.tool_calls or []:
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=dtc.index,
+                    tool_name=dtc.function and dtc.function.name,  # type: ignore
+                    args=dtc.function and dtc.function.arguments,
+                    tool_call_id=dtc.id,
+                )
+                if maybe_event is not None:
+                    yield maybe_event

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp


def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
    response_usage = response.usage
    if response_usage is None:
        return usage.Usage()

    return usage.Usage(
        request_tokens=response_usage.prompt_tokens,
        response_tokens=response_usage.completion_tokens,
        total_tokens=response_usage.total_tokens,
        details=None,
    )
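As `_completions_create` above shows, the standard `ModelSettings` fields (`temperature`, `top_p`, `seed`, `stop_sequences`, `presence_penalty`, `frequency_penalty`, and so on) are forwarded directly to `client.chat.completions.create`. A short sketch of passing them per run; this assumes the usual `model_settings` keyword on `Agent.run_sync`, and the model name and prompt are illustrative:

from pydantic_ai import Agent
from pydantic_ai.models.huggingface import HuggingFaceModel

# With no explicit provider, the default provider='huggingface' string is used,
# which requires HF_TOKEN to be set in the environment.
agent = Agent(HuggingFaceModel('meta-llama/Llama-3.3-70B-Instruct'))
result = agent.run_sync(
    'Summarize this release in one sentence.',
    model_settings={'temperature': 0.2, 'top_p': 0.9, 'seed': 42},
)
print(result.output)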
pydantic_ai/models/instrumented.py
CHANGED

@@ -138,7 +138,7 @@ class InstrumentationSettings:
                **tokens_histogram_kwargs,
                explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
            )
-        except TypeError:
+        except TypeError:
            # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
            self.tokens_histogram = self.meter.create_histogram(
                **tokens_histogram_kwargs,  # pyright: ignore
pydantic_ai/models/mistral.py
CHANGED
@@ -75,7 +75,7 @@ try:
    from mistralai.models.usermessage import UserMessage as MistralUserMessage
    from mistralai.types.basemodel import Unset as MistralUnset
    from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
-except ImportError as e:  # pragma:
+except ImportError as e:  # pragma: no cover
    raise ImportError(
        'Please install `mistral` to use the Mistral model, '
        'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -217,7 +217,7 @@ class MistralModel(Model):
        except SDKError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma:
+            raise  # pragma: no cover

        assert response, 'A unexpected empty response from Mistral.'
        return response
pydantic_ai/models/openai.py
CHANGED
@@ -345,7 +345,7 @@ class OpenAIModel(Model):
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma:
+            raise  # pragma: no cover

    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
@@ -781,7 +781,7 @@ class OpenAIResponsesModel(Model):
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
-            raise  # pragma:
+            raise  # pragma: no cover

    def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
        reasoning_effort = model_settings.get('openai_reasoning_effort', None)
pydantic_ai/providers/__init__.py
CHANGED

@@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
        from .heroku import HerokuProvider

        return HerokuProvider
+    elif provider == 'huggingface':
+        from .huggingface import HuggingFaceProvider
+
+        return HuggingFaceProvider
    elif provider == 'github':
        from .github import GitHubProvider

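With the branch added above, the default `provider='huggingface'` string accepted by `HuggingFaceModel.__init__` resolves to `HuggingFaceProvider` via `infer_provider`. A tiny sketch; it assumes `HF_TOKEN` is set, since `HuggingFaceProvider` raises `UserError` without a key:

from pydantic_ai.providers import infer_provider

provider = infer_provider('huggingface')
print(type(provider).__name__)  # HuggingFaceProvider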
pydantic_ai/providers/google.py
CHANGED
@@ -86,7 +86,7 @@ class GoogleProvider(Provider[genai.Client]):
            # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
            api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')

-            if vertexai is None:
+            if vertexai is None:
                vertexai = bool(location or project or credentials)

            if not vertexai:
@@ -114,7 +114,7 @@ class GoogleProvider(Provider[genai.Client]):
                http_options={'headers': {'User-Agent': get_user_agent()}},
            )
        else:
-            self._client = client
+            self._client = client


VertexAILocation = Literal[
pydantic_ai/providers/google_vertex.py
CHANGED

@@ -50,7 +50,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
-        return google_model_profile(model_name)
+        return google_model_profile(model_name)

    @overload
    def __init__(
@@ -116,6 +116,8 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
class _VertexAIAuth(httpx.Auth):
    """Auth class for Vertex AI API."""

+    _refresh_lock: anyio.Lock = anyio.Lock()
+
    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
@@ -169,10 +171,13 @@ class _VertexAIAuth(httpx.Auth):
        return creds

    async def _refresh_token(self) -> str:  # pragma: no cover
-
-
-
-
+        async with self._refresh_lock:
+            assert self.credentials is not None
+            await anyio.to_thread.run_sync(self.credentials.refresh, Request())  # type: ignore[reportUnknownMemberType]
+            assert isinstance(self.credentials.token, str), (  # type: ignore[reportUnknownMemberType]
+                f'Expected token to be a string, got {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
+            )
+            return self.credentials.token


async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
pydantic_ai/providers/huggingface.py
ADDED

@@ -0,0 +1,88 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient
+
+from pydantic_ai.exceptions import UserError
+
+try:
+    from huggingface_hub import AsyncInferenceClient
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for Hugging Face."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    @overload
+    def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, api_key: str | None = None) -> None: ...
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: AsyncClient | None = None,
+        provider_name: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base url for the Hugging Face requests.
+            api_key: The API key to use for authentication, if not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
+            provider_name: Name of the provider to use for inference. Available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which will select the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, then `provider_name` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise UserError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
+                'to use the HuggingFace provider.'
+            )
+
+        if http_client is not None:
+            raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.')
+
+        if base_url is not None and provider_name is not None:
+            raise ValueError('Cannot provide both `base_url` and `provider_name`.')
+
+        if hf_client is None:
+            self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url)  # type: ignore
+        else:
+            self._client = hf_client