huggingface-hub 0.28.1__py3-none-any.whl → 0.29.0rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they were published to their public registry. It is provided for informational purposes only.

Files changed (63)
  1. huggingface_hub/__init__.py +1 -4
  2. huggingface_hub/constants.py +16 -10
  3. huggingface_hub/file_download.py +10 -6
  4. huggingface_hub/hf_api.py +53 -23
  5. huggingface_hub/inference/_client.py +151 -84
  6. huggingface_hub/inference/_common.py +3 -27
  7. huggingface_hub/inference/_generated/_async_client.py +147 -83
  8. huggingface_hub/inference/_generated/types/__init__.py +1 -1
  9. huggingface_hub/inference/_generated/types/audio_classification.py +4 -5
  10. huggingface_hub/inference/_generated/types/audio_to_audio.py +3 -4
  11. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +7 -8
  12. huggingface_hub/inference/_generated/types/base.py +21 -0
  13. huggingface_hub/inference/_generated/types/chat_completion.py +29 -30
  14. huggingface_hub/inference/_generated/types/depth_estimation.py +3 -4
  15. huggingface_hub/inference/_generated/types/document_question_answering.py +5 -6
  16. huggingface_hub/inference/_generated/types/feature_extraction.py +5 -6
  17. huggingface_hub/inference/_generated/types/fill_mask.py +4 -5
  18. huggingface_hub/inference/_generated/types/image_classification.py +4 -5
  19. huggingface_hub/inference/_generated/types/image_segmentation.py +4 -5
  20. huggingface_hub/inference/_generated/types/image_to_image.py +5 -6
  21. huggingface_hub/inference/_generated/types/image_to_text.py +5 -6
  22. huggingface_hub/inference/_generated/types/object_detection.py +5 -6
  23. huggingface_hub/inference/_generated/types/question_answering.py +5 -6
  24. huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -4
  25. huggingface_hub/inference/_generated/types/summarization.py +4 -5
  26. huggingface_hub/inference/_generated/types/table_question_answering.py +5 -6
  27. huggingface_hub/inference/_generated/types/text2text_generation.py +4 -5
  28. huggingface_hub/inference/_generated/types/text_classification.py +4 -5
  29. huggingface_hub/inference/_generated/types/text_generation.py +12 -13
  30. huggingface_hub/inference/_generated/types/text_to_audio.py +5 -6
  31. huggingface_hub/inference/_generated/types/text_to_image.py +8 -15
  32. huggingface_hub/inference/_generated/types/text_to_speech.py +5 -6
  33. huggingface_hub/inference/_generated/types/text_to_video.py +4 -5
  34. huggingface_hub/inference/_generated/types/token_classification.py +4 -5
  35. huggingface_hub/inference/_generated/types/translation.py +4 -5
  36. huggingface_hub/inference/_generated/types/video_classification.py +4 -5
  37. huggingface_hub/inference/_generated/types/visual_question_answering.py +5 -6
  38. huggingface_hub/inference/_generated/types/zero_shot_classification.py +4 -5
  39. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +4 -5
  40. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +5 -6
  41. huggingface_hub/inference/_providers/__init__.py +44 -8
  42. huggingface_hub/inference/_providers/_common.py +239 -0
  43. huggingface_hub/inference/_providers/black_forest_labs.py +66 -0
  44. huggingface_hub/inference/_providers/fal_ai.py +31 -100
  45. huggingface_hub/inference/_providers/fireworks_ai.py +6 -0
  46. huggingface_hub/inference/_providers/hf_inference.py +58 -142
  47. huggingface_hub/inference/_providers/hyperbolic.py +43 -0
  48. huggingface_hub/inference/_providers/nebius.py +41 -0
  49. huggingface_hub/inference/_providers/novita.py +26 -0
  50. huggingface_hub/inference/_providers/replicate.py +24 -119
  51. huggingface_hub/inference/_providers/sambanova.py +3 -86
  52. huggingface_hub/inference/_providers/together.py +36 -130
  53. huggingface_hub/utils/_headers.py +5 -0
  54. huggingface_hub/utils/_hf_folder.py +4 -32
  55. huggingface_hub/utils/_http.py +85 -2
  56. huggingface_hub/utils/_typing.py +1 -1
  57. huggingface_hub/utils/logging.py +6 -0
  58. {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/METADATA +1 -1
  59. {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/RECORD +63 -57
  60. {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/LICENSE +0 -0
  61. {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/WHEEL +0 -0
  62. {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/entry_points.txt +0 -0
  63. {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/types/zero_shot_classification.py
@@ -3,13 +3,12 @@
 # See:
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
-from dataclasses import dataclass
 from typing import List, Optional

-from .base import BaseInferenceType
+from .base import BaseInferenceType, dataclass_with_extra


-@dataclass
+@dataclass_with_extra
 class ZeroShotClassificationParameters(BaseInferenceType):
     """Additional inference parameters for Zero Shot Classification"""

@@ -26,7 +25,7 @@ class ZeroShotClassificationParameters(BaseInferenceType):
     """


-@dataclass
+@dataclass_with_extra
 class ZeroShotClassificationInput(BaseInferenceType):
     """Inputs for Zero Shot Classification inference"""

@@ -36,7 +35,7 @@ class ZeroShotClassificationInput(BaseInferenceType):
     """Additional inference parameters for Zero Shot Classification"""


-@dataclass
+@dataclass_with_extra
 class ZeroShotClassificationOutputElement(BaseInferenceType):
     """Outputs of inference for the Zero Shot Classification task"""

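The generated task types (here and in the two files that follow) all swap `@dataclass` for the new `@dataclass_with_extra` decorator exported from `.base` (added in `base.py`, +21 lines, not shown in this excerpt). Judging from how it is used, the intent appears to be a dataclass that tolerates extra fields returned by a provider instead of failing on unexpected keys. A minimal illustrative sketch of such a decorator, not the actual `base.py` implementation:

from dataclasses import dataclass, fields
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")

def dataclass_with_extra(cls: Type[T]) -> Type[T]:
    """Sketch only: behave like @dataclass but keep unknown keyword
    arguments as attributes instead of raising TypeError (keyword
    construction only in this sketch)."""
    cls = dataclass(cls)  # type: ignore[assignment]
    known = {f.name for f in fields(cls)}
    original_init = cls.__init__

    def __init__(self: Any, **kwargs: Any) -> None:
        # Split provider-specific extras from the declared dataclass fields.
        extra: Dict[str, Any] = {k: v for k, v in kwargs.items() if k not in known}
        original_init(self, **{k: v for k, v in kwargs.items() if k in known})
        for key, value in extra.items():
            setattr(self, key, value)

    cls.__init__ = __init__  # type: ignore[method-assign]
    return cls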
huggingface_hub/inference/_generated/types/zero_shot_image_classification.py
@@ -3,13 +3,12 @@
 # See:
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
-from dataclasses import dataclass
 from typing import List, Optional

-from .base import BaseInferenceType
+from .base import BaseInferenceType, dataclass_with_extra


-@dataclass
+@dataclass_with_extra
 class ZeroShotImageClassificationParameters(BaseInferenceType):
     """Additional inference parameters for Zero Shot Image Classification"""

@@ -21,7 +20,7 @@ class ZeroShotImageClassificationParameters(BaseInferenceType):
     """


-@dataclass
+@dataclass_with_extra
 class ZeroShotImageClassificationInput(BaseInferenceType):
     """Inputs for Zero Shot Image Classification inference"""

@@ -31,7 +30,7 @@ class ZeroShotImageClassificationInput(BaseInferenceType):
     """Additional inference parameters for Zero Shot Image Classification"""


-@dataclass
+@dataclass_with_extra
 class ZeroShotImageClassificationOutputElement(BaseInferenceType):
     """Outputs of inference for the Zero Shot Image Classification task"""

huggingface_hub/inference/_generated/types/zero_shot_object_detection.py
@@ -3,13 +3,12 @@
 # See:
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
-from dataclasses import dataclass
 from typing import List

-from .base import BaseInferenceType
+from .base import BaseInferenceType, dataclass_with_extra


-@dataclass
+@dataclass_with_extra
 class ZeroShotObjectDetectionParameters(BaseInferenceType):
     """Additional inference parameters for Zero Shot Object Detection"""

@@ -17,7 +16,7 @@ class ZeroShotObjectDetectionParameters(BaseInferenceType):
     """The candidate labels for this image"""


-@dataclass
+@dataclass_with_extra
 class ZeroShotObjectDetectionInput(BaseInferenceType):
     """Inputs for Zero Shot Object Detection inference"""

@@ -27,7 +26,7 @@ class ZeroShotObjectDetectionInput(BaseInferenceType):
     """Additional inference parameters for Zero Shot Object Detection"""


-@dataclass
+@dataclass_with_extra
 class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
     """The predicted bounding box. Coordinates are relative to the top left corner of the input
     image.
@@ -39,7 +38,7 @@ class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
     ymin: int


-@dataclass
+@dataclass_with_extra
 class ZeroShotObjectDetectionOutputElement(BaseInferenceType):
     """Outputs of inference for the Zero Shot Object Detection task"""

huggingface_hub/inference/_providers/__init__.py
@@ -1,27 +1,49 @@
 from typing import Dict, Literal

-from .._common import TaskProviderHelper
-from .fal_ai import FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask, FalAITextToVideoTask
+from ._common import TaskProviderHelper
+from .black_forest_labs import BlackForestLabsTextToImageTask
+from .fal_ai import (
+    FalAIAutomaticSpeechRecognitionTask,
+    FalAITextToImageTask,
+    FalAITextToSpeechTask,
+    FalAITextToVideoTask,
+)
+from .fireworks_ai import FireworksAIConversationalTask
 from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
+from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
+from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
+from .novita import NovitaConversationalTask, NovitaTextGenerationTask
 from .replicate import ReplicateTask, ReplicateTextToSpeechTask
 from .sambanova import SambanovaConversationalTask
-from .together import TogetherTextGenerationTask, TogetherTextToImageTask
+from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask


 PROVIDER_T = Literal[
+    "black-forest-labs",
     "fal-ai",
+    "fireworks-ai",
     "hf-inference",
+    "hyperbolic",
+    "nebius",
+    "novita",
     "replicate",
     "sambanova",
     "together",
 ]

 PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
+    "black-forest-labs": {
+        "text-to-image": BlackForestLabsTextToImageTask(),
+    },
     "fal-ai": {
-        "text-to-image": FalAITextToImageTask(),
         "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
+        "text-to-image": FalAITextToImageTask(),
+        "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
     },
+    "fireworks-ai": {
+        "conversational": FireworksAIConversationalTask(),
+    },
     "hf-inference": {
         "text-to-image": HFInferenceTask("text-to-image"),
         "conversational": HFInferenceConversational(),
@@ -35,9 +57,9 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "image-classification": HFInferenceBinaryInputTask("image-classification"),
         "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
         "document-question-answering": HFInferenceTask("document-question-answering"),
-        "image-to-text": HFInferenceTask("image-to-text"),
+        "image-to-text": HFInferenceBinaryInputTask("image-to-text"),
         "object-detection": HFInferenceBinaryInputTask("object-detection"),
-        "audio-to-audio": HFInferenceTask("audio-to-audio"),
+        "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
         "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
         "zero-shot-classification": HFInferenceTask("zero-shot-classification"),
         "image-to-image": HFInferenceBinaryInputTask("image-to-image"),
@@ -50,6 +72,20 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "summarization": HFInferenceTask("summarization"),
         "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
     },
+    "hyperbolic": {
+        "text-to-image": HyperbolicTextToImageTask(),
+        "conversational": HyperbolicTextGenerationTask("conversational"),
+        "text-generation": HyperbolicTextGenerationTask("text-generation"),
+    },
+    "nebius": {
+        "text-to-image": NebiusTextToImageTask(),
+        "conversational": NebiusConversationalTask(),
+        "text-generation": NebiusTextGenerationTask(),
+    },
+    "novita": {
+        "text-generation": NovitaTextGenerationTask(),
+        "conversational": NovitaConversationalTask(),
+    },
     "replicate": {
         "text-to-image": ReplicateTask("text-to-image"),
         "text-to-speech": ReplicateTextToSpeechTask(),
@@ -60,8 +96,8 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
     },
     "together": {
         "text-to-image": TogetherTextToImageTask(),
-        "conversational": TogetherTextGenerationTask("conversational"),
-        "text-generation": TogetherTextGenerationTask("text-generation"),
+        "conversational": TogetherConversationalTask(),
+        "text-generation": TogetherTextGenerationTask(),
     },
 }

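For context on how this registry is consumed: a provider name plus a task name resolves to a `TaskProviderHelper` instance. A small illustrative lookup (the provider/task pairs come straight from the mapping above; the surrounding client wiring is simplified):

from huggingface_hub.inference._providers import PROVIDERS

# Chat completion routed through Novita resolves to the new NovitaConversationalTask helper.
helper = PROVIDERS["novita"]["conversational"]
print(type(helper).__name__)  # NovitaConversationalTask

# Combinations not registered above are simply absent, e.g. Novita has no text-to-image entry in this release.
missing = PROVIDERS["novita"].get("text-to-image")  # None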
huggingface_hub/inference/_providers/_common.py (new file)
@@ -0,0 +1,239 @@
+from functools import lru_cache
+from typing import Any, Dict, Optional, Union
+
+from huggingface_hub import constants
+from huggingface_hub.inference._common import RequestParameters
+from huggingface_hub.utils import build_hf_headers, get_token, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# Dev purposes only.
+# If you want to try to run inference for a new model locally before it's registered on huggingface.co
+# for a given Inference Provider, you can add it to the following dictionary.
+HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
+    # "HF model ID" => "Model ID on Inference Provider's side"
+    #
+    # Example:
+    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+    "fal-ai": {},
+    "fireworks-ai": {},
+    "hf-inference": {},
+    "hyperbolic": {},
+    "nebius": {},
+    "replicate": {},
+    "sambanova": {},
+    "together": {},
+}
+
+
+def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
+    return {k: v for k, v in d.items() if v is not None}
+
+
+class TaskProviderHelper:
+    """Base class for task-specific provider helpers."""
+
+    def __init__(self, provider: str, base_url: str, task: str) -> None:
+        self.provider = provider
+        self.task = task
+        self.base_url = base_url
+
+    def prepare_request(
+        self,
+        *,
+        inputs: Any,
+        parameters: Dict[str, Any],
+        headers: Dict,
+        model: Optional[str],
+        api_key: Optional[str],
+        extra_payload: Optional[Dict[str, Any]] = None,
+    ) -> RequestParameters:
+        """
+        Prepare the request to be sent to the provider.
+
+        Each step (api_key, model, headers, url, payload) can be customized in subclasses.
+        """
+        # api_key from user, or local token, or raise error
+        api_key = self._prepare_api_key(api_key)
+
+        # mapped model from HF model ID
+        mapped_model = self._prepare_mapped_model(model)
+
+        # default HF headers + user headers (to customize in subclasses)
+        headers = self._prepare_headers(headers, api_key)
+
+        # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
+        url = self._prepare_url(api_key, mapped_model)
+
+        # prepare payload (to customize in subclasses)
+        payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
+        if payload is not None:
+            payload = recursive_merge(payload, extra_payload or {})
+
+        # body data (to customize in subclasses)
+        data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
+
+        # check if both payload and data are set and return
+        if payload is not None and data is not None:
+            raise ValueError("Both payload and data cannot be set in the same request.")
+        if payload is None and data is None:
+            raise ValueError("Either payload or data must be set in the request.")
+        return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)
+
+    def get_response(self, response: Union[bytes, Dict]) -> Any:
+        """
+        Return the response in the expected format.
+
+        Override this method in subclasses for customized response handling."""
+        return response
+
+    def _prepare_api_key(self, api_key: Optional[str]) -> str:
+        """Return the API key to use for the request.
+
+        Usually not overwritten in subclasses."""
+        if api_key is None:
+            api_key = get_token()
+        if api_key is None:
+            raise ValueError(
+                f"You must provide an api_key to work with {self.provider} API or log in with `huggingface-cli login`."
+            )
+        return api_key
+
+    def _prepare_mapped_model(self, model: Optional[str]) -> str:
+        """Return the mapped model ID to use for the request.
+
+        Usually not overwritten in subclasses."""
+        if model is None:
+            raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")
+
+        # hardcoded mapping for local testing
+        if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
+            return HARDCODED_MODEL_ID_MAPPING[self.provider][model]
+
+        provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
+        if provider_mapping is None:
+            raise ValueError(f"Model {model} is not supported by provider {self.provider}.")
+
+        if provider_mapping.task != self.task:
+            raise ValueError(
+                f"Model {model} is not supported for task {self.task} and provider {self.provider}. "
+                f"Supported task: {provider_mapping.task}."
+            )
+
+        if provider_mapping.status == "staging":
+            logger.warning(
+                f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
+            )
+        return provider_mapping.provider_id
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+        """Return the headers to use for the request.
+
+        Override this method in subclasses for customized headers.
+        """
+        return {**build_hf_headers(token=api_key), **headers}
+
+    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
+        """Return the URL to use for the request.
+
+        Usually not overwritten in subclasses."""
+        base_url = self._prepare_base_url(api_key)
+        route = self._prepare_route(mapped_model)
+        return f"{base_url.rstrip('/')}/{route.lstrip('/')}"
+
+    def _prepare_base_url(self, api_key: str) -> str:
+        """Return the base URL to use for the request.
+
+        Usually not overwritten in subclasses."""
+        # Route to the proxy if the api_key is a HF TOKEN
+        if api_key.startswith("hf_"):
+            logger.info(f"Calling '{self.provider}' provider through Hugging Face router.")
+            return constants.INFERENCE_PROXY_TEMPLATE.format(provider=self.provider)
+        else:
+            logger.info(f"Calling '{self.provider}' provider directly.")
+            return self.base_url
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        """Return the route to use for the request.
+
+        Override this method in subclasses for customized routes.
+        """
+        return ""
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        """Return the payload to use for the request, as a dict.
+
+        Override this method in subclasses for customized payloads.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
+        """
+        return None
+
+    def _prepare_payload_as_bytes(
+        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
+    ) -> Optional[bytes]:
+        """Return the body to use for the request, as bytes.
+
+        Override this method in subclasses for customized body data.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
+        """
+        return None
+
+
+class BaseConversationalTask(TaskProviderHelper):
+    """
+    Base class for conversational (chat completion) tasks.
+    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat
+    """
+
+    def __init__(self, provider: str, base_url: str):
+        super().__init__(provider=provider, base_url=base_url, task="conversational")
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/v1/chat/completions"
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
+
+
+class BaseTextGenerationTask(TaskProviderHelper):
+    """
+    Base class for text-generation (completion) tasks.
+    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions
+    """
+
+    def __init__(self, provider: str, base_url: str):
+        super().__init__(provider=provider, base_url=base_url, task="text-generation")
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return "/v1/completions"
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        return {"prompt": inputs, **filter_none(parameters), "model": mapped_model}
+
+
+@lru_cache(maxsize=None)
+def _fetch_inference_provider_mapping(model: str) -> Dict:
+    """
+    Fetch provider mappings for a model from the Hub.
+    """
+    from huggingface_hub.hf_api import HfApi
+
+    info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
+    provider_mapping = info.inference_provider_mapping
+    if provider_mapping is None:
+        raise ValueError(f"No provider mapping found for model {model}")
+    return provider_mapping
+
+
+def recursive_merge(dict1: Dict, dict2: Dict) -> Dict:
+    return {
+        **dict1,
+        **{
+            key: recursive_merge(dict1[key], value)
+            if (key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict))
+            else value
+            for key, value in dict2.items()
+        },
+    }
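The two base classes above are what the per-provider modules in this release build on: a conversational or text-generation provider typically only supplies its name and base URL, and optionally overrides one of the `_prepare_*` hooks. An illustrative sketch using a made-up provider ("acme" and its URL are not part of the diff), plus the behavior of the small helpers defined in this file:

from huggingface_hub.inference._providers._common import (
    BaseConversationalTask,
    filter_none,
    recursive_merge,
)

# A hypothetical provider: the base class already supplies the OpenAI-style route and payload.
helper = BaseConversationalTask(provider="acme", base_url="https://api.acme.example")
assert helper._prepare_route("acme/some-model") == "/v1/chat/completions"

payload = helper._prepare_payload_as_dict(
    inputs=[{"role": "user", "content": "Hello!"}],
    parameters={"max_tokens": 16, "temperature": None},  # None values are dropped via filter_none()
    mapped_model="acme/some-model",
)
# payload == {"messages": [...], "max_tokens": 16, "model": "acme/some-model"}

# recursive_merge() layers extra payload onto the prepared dict without clobbering nested dicts.
assert recursive_merge({"a": {"x": 1}}, {"a": {"y": 2}}) == {"a": {"x": 1, "y": 2}}
assert filter_none({"k": None, "n": 4}) == {"n": 4}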
huggingface_hub/inference/_providers/black_forest_labs.py (new file)
@@ -0,0 +1,66 @@
+import time
+from typing import Any, Dict, Optional, Union
+
+from huggingface_hub.inference._common import _as_dict
+from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
+from huggingface_hub.utils import logging
+from huggingface_hub.utils._http import get_session
+
+
+logger = logging.get_logger(__name__)
+
+MAX_POLLING_ATTEMPTS = 6
+POLLING_INTERVAL = 1.0
+
+
+class BlackForestLabsTextToImageTask(TaskProviderHelper):
+    def __init__(self):
+        super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai/v1", task="text-to-image")
+
+    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
+        headers = super()._prepare_headers(headers, api_key)
+        if not api_key.startswith("hf_"):
+            _ = headers.pop("authorization")
+            headers["X-Key"] = api_key
+        return headers
+
+    def _prepare_route(self, mapped_model: str) -> str:
+        return mapped_model
+
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        parameters = filter_none(parameters)
+        if "num_inference_steps" in parameters:
+            parameters["steps"] = parameters.pop("num_inference_steps")
+        if "guidance_scale" in parameters:
+            parameters["guidance"] = parameters.pop("guidance_scale")
+
+        return {"prompt": inputs, **parameters}
+
+    def get_response(self, response: Union[bytes, Dict]) -> Any:
+        """
+        Polling mechanism for Black Forest Labs since the API is asynchronous.
+        """
+        url = _as_dict(response).get("polling_url")
+        session = get_session()
+        for _ in range(MAX_POLLING_ATTEMPTS):
+            time.sleep(POLLING_INTERVAL)
+
+            response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
+            response.raise_for_status()  # type: ignore
+            response_json: Dict = response.json()  # type: ignore
+            status = response_json.get("status")
+            logger.info(
+                f"Polling generation result from {url}. Current status: {status}. "
+                f"Will retry after {POLLING_INTERVAL} seconds if not ready."
+            )
+
+            if (
+                status == "Ready"
+                and isinstance(response_json.get("result"), dict)
+                and (sample_url := response_json["result"].get("sample"))
+            ):
+                image_resp = session.get(sample_url)
+                image_resp.raise_for_status()
+                return image_resp.content
+
+        raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.")
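End to end, this helper is selected by passing the new provider name to `InferenceClient`. A hedged usage sketch (assumes huggingface_hub 0.29.0rc0, a valid HF token or Black Forest Labs API key, Pillow installed, and a model actually mapped to this provider; the model ID below is illustrative):

from huggingface_hub import InferenceClient

client = InferenceClient(provider="black-forest-labs")
# text_to_image() returns a PIL.Image.Image; the polling loop in get_response()
# above is what waits for the asynchronous BFL job and fetches the final sample.
image = client.text_to_image(
    "An astronaut riding a horse on the moon",
    model="black-forest-labs/FLUX.1-dev",
)
image.save("astronaut.png")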