pydantic-ai-slim 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

@@ -0,0 +1,463 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import base64
4
+ from collections.abc import AsyncIterable, AsyncIterator
5
+ from contextlib import asynccontextmanager
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime, timezone
8
+ from typing import Literal, Union, cast, overload
9
+
10
+ from typing_extensions import assert_never
11
+
12
+ from pydantic_ai._thinking_part import split_content_into_text_and_thinking
13
+ from pydantic_ai.providers import Provider, infer_provider
14
+
15
+ from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
16
+ from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
17
+ from ..messages import (
18
+ AudioUrl,
19
+ BinaryContent,
20
+ DocumentUrl,
21
+ ImageUrl,
22
+ ModelMessage,
23
+ ModelRequest,
24
+ ModelResponse,
25
+ ModelResponsePart,
26
+ ModelResponseStreamEvent,
27
+ RetryPromptPart,
28
+ SystemPromptPart,
29
+ TextPart,
30
+ ThinkingPart,
31
+ ToolCallPart,
32
+ ToolReturnPart,
33
+ UserPromptPart,
34
+ VideoUrl,
35
+ )
36
+ from ..settings import ModelSettings
37
+ from ..tools import ToolDefinition
38
+ from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
39
+
40
+ try:
41
+ import aiohttp
42
+ from huggingface_hub import (
43
+ AsyncInferenceClient,
44
+ ChatCompletionInputMessage,
45
+ ChatCompletionInputMessageChunk,
46
+ ChatCompletionInputTool,
47
+ ChatCompletionInputToolCall,
48
+ ChatCompletionInputURL,
49
+ ChatCompletionOutput,
50
+ ChatCompletionOutputMessage,
51
+ ChatCompletionStreamOutput,
52
+ )
53
+ from huggingface_hub.errors import HfHubHTTPError
54
+
55
+ except ImportError as _import_error:
56
+ raise ImportError(
57
+ 'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
58
+ 'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
59
+ ) from _import_error
60
+
61
+ __all__ = (
62
+ 'HuggingFaceModel',
63
+ 'HuggingFaceModelSettings',
64
+ )
65
+
66
+
67
+ HFSystemPromptRole = Literal['system', 'user']
68
+
69
+ LatestHuggingFaceModelNames = Literal[
70
+ 'deepseek-ai/DeepSeek-R1',
71
+ 'meta-llama/Llama-3.3-70B-Instruct',
72
+ 'meta-llama/Llama-4-Maverick-17B-128E-Instruct',
73
+ 'meta-llama/Llama-4-Scout-17B-16E-Instruct',
74
+ 'Qwen/QwQ-32B',
75
+ 'Qwen/Qwen2.5-72B-Instruct',
76
+ 'Qwen/Qwen3-235B-A22B',
77
+ 'Qwen/Qwen3-32B',
78
+ ]
79
+ """Latest Hugging Face models."""
80
+
81
+
82
+ HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames]
83
+ """Possible Hugging Face model names.
84
+
85
+ You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
86
+ """
87
+
88
+
89
+ class HuggingFaceModelSettings(ModelSettings, total=False):
90
+ """Settings used for a Hugging Face model request."""
91
+
92
+ # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
93
+ # This class is a placeholder for any future huggingface-specific settings
94
+
95
+
96
+ @dataclass(init=False)
97
+ class HuggingFaceModel(Model):
98
+ """A model that uses Hugging Face Inference Providers.
99
+
100
+ Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API.
101
+
102
+ Apart from `__init__`, all methods are private or match those of the base class.
103
+ """
104
+
105
+ client: AsyncInferenceClient = field(repr=False)
106
+
107
+ _model_name: str = field(repr=False)
108
+ _system: str = field(default='huggingface', repr=False)
109
+
110
+ def __init__(
111
+ self,
112
+ model_name: str,
113
+ *,
114
+ provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
115
+ ):
116
+ """Initialize a Hugging Face model.
117
+
118
+ Args:
119
+ model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending).
120
+ provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an
121
+ instance of `Provider[AsyncInferenceClient]`. Defaults to 'huggingface'.
122
+ """
123
+ self._model_name = model_name
124
+ self._provider = provider
125
+ if isinstance(provider, str):
126
+ provider = infer_provider(provider)
127
+ self.client = provider.client
128
+
129
+ async def request(
130
+ self,
131
+ messages: list[ModelMessage],
132
+ model_settings: ModelSettings | None,
133
+ model_request_parameters: ModelRequestParameters,
134
+ ) -> ModelResponse:
135
+ check_allow_model_requests()
136
+ response = await self._completions_create(
137
+ messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
138
+ )
139
+ model_response = self._process_response(response)
140
+ model_response.usage.requests = 1
141
+ return model_response
142
+
143
+ @asynccontextmanager
144
+ async def request_stream(
145
+ self,
146
+ messages: list[ModelMessage],
147
+ model_settings: ModelSettings | None,
148
+ model_request_parameters: ModelRequestParameters,
149
+ ) -> AsyncIterator[StreamedResponse]:
150
+ check_allow_model_requests()
151
+ response = await self._completions_create(
152
+ messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters
153
+ )
154
+ yield await self._process_streamed_response(response)
155
+
156
+ @property
157
+ def model_name(self) -> HuggingFaceModelName:
158
+ """The model name."""
159
+ return self._model_name
160
+
161
+ @property
162
+ def system(self) -> str:
163
+ """The system / model provider."""
164
+ return self._system
165
+
166
+ @overload
167
+ async def _completions_create(
168
+ self,
169
+ messages: list[ModelMessage],
170
+ stream: Literal[True],
171
+ model_settings: HuggingFaceModelSettings,
172
+ model_request_parameters: ModelRequestParameters,
173
+ ) -> AsyncIterable[ChatCompletionStreamOutput]: ...
174
+
175
+ @overload
176
+ async def _completions_create(
177
+ self,
178
+ messages: list[ModelMessage],
179
+ stream: Literal[False],
180
+ model_settings: HuggingFaceModelSettings,
181
+ model_request_parameters: ModelRequestParameters,
182
+ ) -> ChatCompletionOutput: ...
183
+
184
+ async def _completions_create(
185
+ self,
186
+ messages: list[ModelMessage],
187
+ stream: bool,
188
+ model_settings: HuggingFaceModelSettings,
189
+ model_request_parameters: ModelRequestParameters,
190
+ ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]:
191
+ tools = self._get_tools(model_request_parameters)
192
+
193
+ if not tools:
194
+ tool_choice: Literal['none', 'required', 'auto'] | None = None
195
+ elif not model_request_parameters.allow_text_output:
196
+ tool_choice = 'required'
197
+ else:
198
+ tool_choice = 'auto'
199
+
200
+ hf_messages = await self._map_messages(messages)
201
+
202
+ try:
203
+ return await self.client.chat.completions.create( # type: ignore
204
+ model=self._model_name,
205
+ messages=hf_messages, # type: ignore
206
+ tools=tools,
207
+ tool_choice=tool_choice or None,
208
+ stream=stream,
209
+ stop=model_settings.get('stop_sequences', None),
210
+ temperature=model_settings.get('temperature', None),
211
+ top_p=model_settings.get('top_p', None),
212
+ seed=model_settings.get('seed', None),
213
+ presence_penalty=model_settings.get('presence_penalty', None),
214
+ frequency_penalty=model_settings.get('frequency_penalty', None),
215
+ logit_bias=model_settings.get('logit_bias', None), # type: ignore
216
+ logprobs=model_settings.get('logprobs', None),
217
+ top_logprobs=model_settings.get('top_logprobs', None),
218
+ extra_body=model_settings.get('extra_body'), # type: ignore
219
+ )
220
+ except aiohttp.ClientResponseError as e:
221
+ raise ModelHTTPError(
222
+ status_code=e.status,
223
+ model_name=self.model_name,
224
+ body=e.response_error_payload, # type: ignore
225
+ ) from e
226
+ except HfHubHTTPError as e:
227
+ raise ModelHTTPError(
228
+ status_code=e.response.status_code,
229
+ model_name=self.model_name,
230
+ body=e.response.content,
231
+ ) from e
232
+
233
+ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
234
+ """Process a non-streamed response, and prepare a message to return."""
235
+ if response.created:
236
+ timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
237
+ else:
238
+ timestamp = _now_utc()
239
+
240
+ choice = response.choices[0]
241
+ content = choice.message.content
242
+ tool_calls = choice.message.tool_calls
243
+
244
+ items: list[ModelResponsePart] = []
245
+
246
+ if content is not None:
247
+ items.extend(split_content_into_text_and_thinking(content))
248
+ if tool_calls is not None:
249
+ for c in tool_calls:
250
+ items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
251
+ return ModelResponse(
252
+ items,
253
+ usage=_map_usage(response),
254
+ model_name=response.model,
255
+ timestamp=timestamp,
256
+ vendor_id=response.id,
257
+ )
258
+
259
+ async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
260
+ """Process a streamed response, and prepare a streaming response to return."""
261
+ peekable_response = _utils.PeekableAsyncStream(response)
262
+ first_chunk = await peekable_response.peek()
263
+ if isinstance(first_chunk, _utils.Unset):
264
+ raise UnexpectedModelBehavior( # pragma: no cover
265
+ 'Streamed response ended without content or tool calls'
266
+ )
267
+
268
+ return HuggingFaceStreamedResponse(
269
+ _model_name=self._model_name,
270
+ _response=peekable_response,
271
+ _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
272
+ )
273
+
274
+ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]:
275
+ tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
276
+ if model_request_parameters.output_tools:
277
+ tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
278
+ return tools
279
+
280
+ async def _map_messages(
281
+ self, messages: list[ModelMessage]
282
+ ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
283
+ """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`."""
284
+ hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = []
285
+ for message in messages:
286
+ if isinstance(message, ModelRequest):
287
+ async for item in self._map_user_message(message):
288
+ hf_messages.append(item)
289
+ elif isinstance(message, ModelResponse):
290
+ texts: list[str] = []
291
+ tool_calls: list[ChatCompletionInputToolCall] = []
292
+ for item in message.parts:
293
+ if isinstance(item, TextPart):
294
+ texts.append(item.content)
295
+ elif isinstance(item, ToolCallPart):
296
+ tool_calls.append(self._map_tool_call(item))
297
+ elif isinstance(item, ThinkingPart):
298
+ # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
299
+ # please open an issue. The below code is the code to send thinking to the provider.
300
+ # texts.append(f'<think>\n{item.content}\n</think>')
301
+ pass
302
+ else:
303
+ assert_never(item)
304
+ message_param = ChatCompletionInputMessage(role='assistant') # type: ignore
305
+ if texts:
306
+ # Note: model responses from this model should only have one text item, so the following
307
+ # shouldn't merge multiple texts into one unless you switch models between runs:
308
+ message_param['content'] = '\n\n'.join(texts)
309
+ if tool_calls:
310
+ message_param['tool_calls'] = tool_calls
311
+ hf_messages.append(message_param)
312
+ else:
313
+ assert_never(message)
314
+ if instructions := self._get_instructions(messages):
315
+ hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore
316
+ return hf_messages
317
+
318
+ @staticmethod
319
+ def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall:
320
+ return ChatCompletionInputToolCall.parse_obj_as_instance( # type: ignore
321
+ {
322
+ 'id': _guard_tool_call_id(t=t),
323
+ 'type': 'function',
324
+ 'function': {
325
+ 'name': t.tool_name,
326
+ 'arguments': t.args_as_json_str(),
327
+ },
328
+ }
329
+ )
330
+
331
+ @staticmethod
332
+ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
333
+ tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
334
+ {
335
+ 'type': 'function',
336
+ 'function': {
337
+ 'name': f.name,
338
+ 'description': f.description,
339
+ 'parameters': f.parameters_json_schema,
340
+ },
341
+ }
342
+ )
343
+ if f.strict is not None:
344
+ tool_param['function']['strict'] = f.strict
345
+ return tool_param
346
+
347
+ async def _map_user_message(
348
+ self, message: ModelRequest
349
+ ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]:
350
+ for part in message.parts:
351
+ if isinstance(part, SystemPromptPart):
352
+ yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore
353
+ elif isinstance(part, UserPromptPart):
354
+ yield await self._map_user_prompt(part)
355
+ elif isinstance(part, ToolReturnPart):
356
+ yield ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore
357
+ {
358
+ 'role': 'tool',
359
+ 'tool_call_id': _guard_tool_call_id(t=part),
360
+ 'content': part.model_response_str(),
361
+ }
362
+ )
363
+ elif isinstance(part, RetryPromptPart):
364
+ if part.tool_name is None:
365
+ yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore
366
+ {'role': 'user', 'content': part.model_response()}
367
+ )
368
+ else:
369
+ yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore
370
+ {
371
+ 'role': 'tool',
372
+ 'tool_call_id': _guard_tool_call_id(t=part),
373
+ 'content': part.model_response(),
374
+ }
375
+ )
376
+ else:
377
+ assert_never(part)
378
+
379
+ @staticmethod
380
+ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage:
381
+ content: str | list[ChatCompletionInputMessage]
382
+ if isinstance(part.content, str):
383
+ content = part.content
384
+ else:
385
+ content = []
386
+ for item in part.content:
387
+ if isinstance(item, str):
388
+ content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore
389
+ elif isinstance(item, ImageUrl):
390
+ url = ChatCompletionInputURL(url=item.url) # type: ignore
391
+ content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore
392
+ elif isinstance(item, BinaryContent):
393
+ base64_encoded = base64.b64encode(item.data).decode('utf-8')
394
+ if item.is_image:
395
+ url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}') # type: ignore
396
+ content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore
397
+ else: # pragma: no cover
398
+ raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
399
+ elif isinstance(item, AudioUrl):
400
+ raise NotImplementedError('AudioUrl is not supported for Hugging Face')
401
+ elif isinstance(item, DocumentUrl):
402
+ raise NotImplementedError('DocumentUrl is not supported for Hugging Face')
403
+ elif isinstance(item, VideoUrl):
404
+ raise NotImplementedError('VideoUrl is not supported for Hugging Face')
405
+ else:
406
+ assert_never(item)
407
+ return ChatCompletionInputMessage(role='user', content=content) # type: ignore
408
+
409
+
410
+ @dataclass
411
+ class HuggingFaceStreamedResponse(StreamedResponse):
412
+ """Implementation of `StreamedResponse` for Hugging Face models."""
413
+
414
+ _model_name: str
415
+ _response: AsyncIterable[ChatCompletionStreamOutput]
416
+ _timestamp: datetime
417
+
418
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
419
+ async for chunk in self._response:
420
+ self._usage += _map_usage(chunk)
421
+
422
+ try:
423
+ choice = chunk.choices[0]
424
+ except IndexError:
425
+ continue
426
+
427
+ # Handle the text part of the response
428
+ content = choice.delta.content
429
+ if content is not None:
430
+ yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
431
+
432
+ for dtc in choice.delta.tool_calls or []:
433
+ maybe_event = self._parts_manager.handle_tool_call_delta(
434
+ vendor_part_id=dtc.index,
435
+ tool_name=dtc.function and dtc.function.name, # type: ignore
436
+ args=dtc.function and dtc.function.arguments,
437
+ tool_call_id=dtc.id,
438
+ )
439
+ if maybe_event is not None:
440
+ yield maybe_event
441
+
442
+ @property
443
+ def model_name(self) -> str:
444
+ """Get the model name of the response."""
445
+ return self._model_name
446
+
447
+ @property
448
+ def timestamp(self) -> datetime:
449
+ """Get the timestamp of the response."""
450
+ return self._timestamp
451
+
452
+
453
+ def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage:
454
+ response_usage = response.usage
455
+ if response_usage is None:
456
+ return usage.Usage()
457
+
458
+ return usage.Usage(
459
+ request_tokens=response_usage.prompt_tokens,
460
+ response_tokens=response_usage.completion_tokens,
461
+ total_tokens=response_usage.total_tokens,
462
+ details=None,
463
+ )
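
For context, a minimal sketch of how the new HuggingFaceModel class might be used. This example is not part of the diff; it assumes the standard pydantic_ai Agent API and an HF_TOKEN environment variable for authentication, and the model name is only illustrative.

    from pydantic_ai import Agent
    from pydantic_ai.models.huggingface import HuggingFaceModel

    # The provider argument defaults to 'huggingface', which is resolved via
    # infer_provider() and reads HF_TOKEN from the environment.
    model = HuggingFaceModel('Qwen/Qwen2.5-72B-Instruct')
    agent = Agent(model)

    result = agent.run_sync('What is the capital of France?')
    print(result.output)
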
@@ -138,7 +138,7 @@ class InstrumentationSettings:
138
138
  **tokens_histogram_kwargs,
139
139
  explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
140
140
  )
141
- except TypeError: # pragma: lax no cover
141
+ except TypeError:
142
142
  # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory
143
143
  self.tokens_histogram = self.meter.create_histogram(
144
144
  **tokens_histogram_kwargs, # pyright: ignore
@@ -75,7 +75,7 @@ try:
75
75
  from mistralai.models.usermessage import UserMessage as MistralUserMessage
76
76
  from mistralai.types.basemodel import Unset as MistralUnset
77
77
  from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
78
- except ImportError as e: # pragma: lax no cover
78
+ except ImportError as e: # pragma: no cover
79
79
  raise ImportError(
80
80
  'Please install `mistral` to use the Mistral model, '
81
81
  'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
@@ -217,7 +217,7 @@ class MistralModel(Model):
217
217
  except SDKError as e:
218
218
  if (status_code := e.status_code) >= 400:
219
219
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
220
- raise # pragma: lax no cover
220
+ raise # pragma: no cover
221
221
 
222
222
  assert response, 'A unexpected empty response from Mistral.'
223
223
  return response
@@ -345,7 +345,7 @@ class OpenAIModel(Model):
345
345
  except APIStatusError as e:
346
346
  if (status_code := e.status_code) >= 400:
347
347
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
348
- raise # pragma: lax no cover
348
+ raise # pragma: no cover
349
349
 
350
350
  def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
351
351
  """Process a non-streamed response, and prepare a message to return."""
@@ -781,7 +781,7 @@ class OpenAIResponsesModel(Model):
781
781
  except APIStatusError as e:
782
782
  if (status_code := e.status_code) >= 400:
783
783
  raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
784
- raise # pragma: lax no cover
784
+ raise # pragma: no cover
785
785
 
786
786
  def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven:
787
787
  reasoning_effort = model_settings.get('openai_reasoning_effort', None)
@@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
111
111
  from .heroku import HerokuProvider
112
112
 
113
113
  return HerokuProvider
114
+ elif provider == 'huggingface':
115
+ from .huggingface import HuggingFaceProvider
116
+
117
+ return HuggingFaceProvider
114
118
  elif provider == 'github':
115
119
  from .github import GitHubProvider
116
120
 
@@ -86,7 +86,7 @@ class GoogleProvider(Provider[genai.Client]):
86
86
  # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
87
87
  api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
88
88
 
89
- if vertexai is None: # pragma: lax no cover
89
+ if vertexai is None:
90
90
  vertexai = bool(location or project or credentials)
91
91
 
92
92
  if not vertexai:
@@ -114,7 +114,7 @@ class GoogleProvider(Provider[genai.Client]):
114
114
  http_options={'headers': {'User-Agent': get_user_agent()}},
115
115
  )
116
116
  else:
117
- self._client = client # pragma: lax no cover
117
+ self._client = client
118
118
 
119
119
 
120
120
  VertexAILocation = Literal[
@@ -50,7 +50,7 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
50
50
  return self._client
51
51
 
52
52
  def model_profile(self, model_name: str) -> ModelProfile | None:
53
- return google_model_profile(model_name) # pragma: lax no cover
53
+ return google_model_profile(model_name)
54
54
 
55
55
  @overload
56
56
  def __init__(
@@ -116,6 +116,8 @@ class GoogleVertexProvider(Provider[httpx.AsyncClient]):
116
116
  class _VertexAIAuth(httpx.Auth):
117
117
  """Auth class for Vertex AI API."""
118
118
 
119
+ _refresh_lock: anyio.Lock = anyio.Lock()
120
+
119
121
  credentials: BaseCredentials | ServiceAccountCredentials | None
120
122
 
121
123
  def __init__(
@@ -169,10 +171,13 @@ class _VertexAIAuth(httpx.Auth):
169
171
  return creds
170
172
 
171
173
  async def _refresh_token(self) -> str: # pragma: no cover
172
- assert self.credentials is not None
173
- await anyio.to_thread.run_sync(self.credentials.refresh, Request()) # type: ignore[reportUnknownMemberType]
174
- assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType]
175
- return self.credentials.token
174
+ async with self._refresh_lock:
175
+ assert self.credentials is not None
176
+ await anyio.to_thread.run_sync(self.credentials.refresh, Request()) # type: ignore[reportUnknownMemberType]
177
+ assert isinstance(self.credentials.token, str), ( # type: ignore[reportUnknownMemberType]
178
+ f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType]
179
+ )
180
+ return self.credentials.token
176
181
 
177
182
 
178
183
  async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ from httpx import AsyncClient
7
+
8
+ from pydantic_ai.exceptions import UserError
9
+
10
+ try:
11
+ from huggingface_hub import AsyncInferenceClient
12
+ except ImportError as _import_error: # pragma: no cover
13
+ raise ImportError(
14
+ 'Please install the `huggingface_hub` package to use the HuggingFace provider, '
15
+ "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
16
+ ) from _import_error
17
+
18
+ from . import Provider
19
+
20
+
21
+ class HuggingFaceProvider(Provider[AsyncInferenceClient]):
22
+ """Provider for Hugging Face."""
23
+
24
+ @property
25
+ def name(self) -> str:
26
+ return 'huggingface'
27
+
28
+ @property
29
+ def base_url(self) -> str:
30
+ return self.client.model # type: ignore
31
+
32
+ @property
33
+ def client(self) -> AsyncInferenceClient:
34
+ return self._client
35
+
36
+ @overload
37
+ def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
38
+ @overload
39
+ def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ...
40
+ @overload
41
+ def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ...
42
+ @overload
43
+ def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ...
44
+ @overload
45
+ def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ...
46
+ @overload
47
+ def __init__(self, *, api_key: str | None = None) -> None: ...
48
+
49
+ def __init__(
50
+ self,
51
+ base_url: str | None = None,
52
+ api_key: str | None = None,
53
+ hf_client: AsyncInferenceClient | None = None,
54
+ http_client: AsyncClient | None = None,
55
+ provider_name: str | None = None,
56
+ ) -> None:
57
+ """Create a new Hugging Face provider.
58
+
59
+ Args:
60
+ base_url: The base url for the Hugging Face requests.
61
+ api_key: The API key to use for authentication, if not provided, the `HF_TOKEN` environment variable
62
+ will be used if available.
63
+ hf_client: An existing
64
+ [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
65
+ client to use. If not provided, a new instance will be created.
66
+ http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
67
+ provider_name: Name of the provider to use for inference. Available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
68
+ Defaults to "auto", which selects the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
69
+ If `base_url` is passed, then `provider_name` is not used.
70
+ """
71
+ api_key = api_key or os.environ.get('HF_TOKEN')
72
+
73
+ if api_key is None:
74
+ raise UserError(
75
+ 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
76
+ 'to use the HuggingFace provider.'
77
+ )
78
+
79
+ if http_client is not None:
80
+ raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.')
81
+
82
+ if base_url is not None and provider_name is not None:
83
+ raise ValueError('Cannot provide both `base_url` and `provider_name`.')
84
+
85
+ if hf_client is None:
86
+ self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore
87
+ else:
88
+ self._client = hf_client
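
For context, a minimal sketch of constructing the new provider explicitly instead of relying on the 'huggingface' provider string. This example is not part of the diff; the API key value and the 'nebius' provider name are placeholders, and the model name is only illustrative.

    from pydantic_ai.models.huggingface import HuggingFaceModel
    from pydantic_ai.providers.huggingface import HuggingFaceProvider

    # Pin a specific Hugging Face Inference Provider and pass the token directly
    # rather than relying on the HF_TOKEN environment variable.
    provider = HuggingFaceProvider(api_key='hf_xxx', provider_name='nebius')
    model = HuggingFaceModel('deepseek-ai/DeepSeek-R1', provider=provider)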