huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff compares two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- huggingface_hub/__init__.py +160 -46
- huggingface_hub/_commit_api.py +277 -71
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +33 -22
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +241 -81
- huggingface_hub/_space_api.py +18 -10
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +196 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +15 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +83 -59
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +99 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +606 -346
- huggingface_hub/hf_api.py +2445 -1132
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +61 -66
- huggingface_hub/inference/_client.py +501 -630
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +536 -722
- huggingface_hub/inference/_generated/types/__init__.py +6 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
- huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +11 -11
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +149 -20
- huggingface_hub/inference/_providers/_common.py +160 -37
- huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
- huggingface_hub/inference/_providers/cerebras.py +6 -0
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +32 -0
- huggingface_hub/inference/_providers/fal_ai.py +231 -22
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +143 -33
- huggingface_hub/inference/_providers/hyperbolic.py +9 -5
- huggingface_hub/inference/_providers/nebius.py +47 -5
- huggingface_hub/inference/_providers/novita.py +48 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +25 -0
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +46 -9
- huggingface_hub/inference/_providers/sambanova.py +37 -1
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +34 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +79 -59
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +27 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +399 -237
- huggingface_hub/utils/_pagination.py +6 -6
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +74 -22
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +13 -11
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +235 -0
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +33 -4
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -428
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -299
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
- huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_providers/__init__.py:

@@ -1,49 +1,120 @@
-from typing import
+from typing import Literal, Optional, Union
 
-from .
+from huggingface_hub.inference._providers.featherless_ai import (
+    FeatherlessConversationalTask,
+    FeatherlessTextGenerationTask,
+)
+from huggingface_hub.utils import logging
+
+from ._common import AutoRouterConversationalTask, TaskProviderHelper, _fetch_inference_provider_mapping
 from .black_forest_labs import BlackForestLabsTextToImageTask
+from .cerebras import CerebrasConversationalTask
+from .clarifai import ClarifaiConversationalTask
+from .cohere import CohereConversationalTask
 from .fal_ai import (
     FalAIAutomaticSpeechRecognitionTask,
+    FalAIImageSegmentationTask,
+    FalAIImageToImageTask,
+    FalAIImageToVideoTask,
     FalAITextToImageTask,
     FalAITextToSpeechTask,
     FalAITextToVideoTask,
 )
 from .fireworks_ai import FireworksAIConversationalTask
-from .
+from .groq import GroqConversationalTask
+from .hf_inference import (
+    HFInferenceBinaryInputTask,
+    HFInferenceConversational,
+    HFInferenceFeatureExtractionTask,
+    HFInferenceTask,
+)
 from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
-from .nebius import
-
-
-
+from .nebius import (
+    NebiusConversationalTask,
+    NebiusFeatureExtractionTask,
+    NebiusTextGenerationTask,
+    NebiusTextToImageTask,
+)
+from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
+from .nscale import NscaleConversationalTask, NscaleTextToImageTask
+from .openai import OpenAIConversationalTask
+from .publicai import PublicAIConversationalTask
+from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
+from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
+from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
 from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
+from .wavespeed import (
+    WavespeedAIImageToImageTask,
+    WavespeedAIImageToVideoTask,
+    WavespeedAITextToImageTask,
+    WavespeedAITextToVideoTask,
+)
+from .zai_org import ZaiConversationalTask
+
+
+logger = logging.get_logger(__name__)
 
 
 PROVIDER_T = Literal[
     "black-forest-labs",
+    "cerebras",
+    "clarifai",
+    "cohere",
     "fal-ai",
+    "featherless-ai",
     "fireworks-ai",
+    "groq",
     "hf-inference",
     "hyperbolic",
     "nebius",
     "novita",
+    "nscale",
+    "openai",
+    "publicai",
     "replicate",
     "sambanova",
+    "scaleway",
     "together",
+    "wavespeed",
+    "zai-org",
 ]
 
-PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
+PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]
+
+CONVERSATIONAL_AUTO_ROUTER = AutoRouterConversationalTask()
+
+PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
     "black-forest-labs": {
         "text-to-image": BlackForestLabsTextToImageTask(),
     },
+    "cerebras": {
+        "conversational": CerebrasConversationalTask(),
+    },
+    "clarifai": {
+        "conversational": ClarifaiConversationalTask(),
+    },
+    "cohere": {
+        "conversational": CohereConversationalTask(),
+    },
     "fal-ai": {
         "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
         "text-to-image": FalAITextToImageTask(),
         "text-to-speech": FalAITextToSpeechTask(),
         "text-to-video": FalAITextToVideoTask(),
+        "image-to-video": FalAIImageToVideoTask(),
+        "image-to-image": FalAIImageToImageTask(),
+        "image-segmentation": FalAIImageSegmentationTask(),
+    },
+    "featherless-ai": {
+        "conversational": FeatherlessConversationalTask(),
+        "text-generation": FeatherlessTextGenerationTask(),
     },
     "fireworks-ai": {
         "conversational": FireworksAIConversationalTask(),
     },
+    "groq": {
+        "conversational": GroqConversationalTask(),
+    },
     "hf-inference": {
         "text-to-image": HFInferenceTask("text-to-image"),
         "conversational": HFInferenceConversational(),
@@ -53,7 +124,7 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "audio-classification": HFInferenceBinaryInputTask("audio-classification"),
         "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"),
         "fill-mask": HFInferenceTask("fill-mask"),
-        "feature-extraction":
+        "feature-extraction": HFInferenceFeatureExtractionTask(),
         "image-classification": HFInferenceBinaryInputTask("image-classification"),
         "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
         "document-question-answering": HFInferenceTask("document-question-answering"),
@@ -81,45 +152,103 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
         "text-to-image": NebiusTextToImageTask(),
         "conversational": NebiusConversationalTask(),
         "text-generation": NebiusTextGenerationTask(),
+        "feature-extraction": NebiusFeatureExtractionTask(),
     },
     "novita": {
         "text-generation": NovitaTextGenerationTask(),
         "conversational": NovitaConversationalTask(),
+        "text-to-video": NovitaTextToVideoTask(),
+    },
+    "nscale": {
+        "conversational": NscaleConversationalTask(),
+        "text-to-image": NscaleTextToImageTask(),
+    },
+    "openai": {
+        "conversational": OpenAIConversationalTask(),
+    },
+    "publicai": {
+        "conversational": PublicAIConversationalTask(),
     },
     "replicate": {
-        "
+        "image-to-image": ReplicateImageToImageTask(),
+        "text-to-image": ReplicateTextToImageTask(),
         "text-to-speech": ReplicateTextToSpeechTask(),
         "text-to-video": ReplicateTask("text-to-video"),
     },
     "sambanova": {
         "conversational": SambanovaConversationalTask(),
+        "feature-extraction": SambanovaFeatureExtractionTask(),
+    },
+    "scaleway": {
+        "conversational": ScalewayConversationalTask(),
+        "feature-extraction": ScalewayFeatureExtractionTask(),
     },
     "together": {
         "text-to-image": TogetherTextToImageTask(),
         "conversational": TogetherConversationalTask(),
         "text-generation": TogetherTextGenerationTask(),
     },
+    "wavespeed": {
+        "text-to-image": WavespeedAITextToImageTask(),
+        "text-to-video": WavespeedAITextToVideoTask(),
+        "image-to-image": WavespeedAIImageToImageTask(),
+        "image-to-video": WavespeedAIImageToVideoTask(),
+    },
+    "zai-org": {
+        "conversational": ZaiConversationalTask(),
+    },
 }
 
 
-def get_provider_helper(
+def get_provider_helper(
+    provider: Optional[PROVIDER_OR_POLICY_T], task: str, model: Optional[str]
+) -> TaskProviderHelper:
     """Get provider helper instance by name and task.
 
     Args:
-        provider (str):
-        task (str): Name of the task
-
+        provider (`str`, *optional*): name of the provider, or "auto" to automatically select the provider for the model.
+        task (`str`): Name of the task
+        model (`str`, *optional*): Name of the model
     Returns:
         TaskProviderHelper: Helper instance for the specified provider and task
 
     Raises:
         ValueError: If provider or task is not supported
     """
-
-
-
+
+    if (model is None and provider in (None, "auto")) or (
+        model is not None and model.startswith(("http://", "https://"))
+    ):
+        provider = "hf-inference"
+
+    if provider is None:
+        logger.info(
+            "No provider specified for task `conversational`. Defaulting to server-side auto routing."
+            if task == "conversational"
+            else "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+        )
+        provider = "auto"
+
+    if provider == "auto":
+        if model is None:
+            raise ValueError("Specifying a model is required when provider is 'auto'")
+        if task == "conversational":
+            # Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
+            return CONVERSATIONAL_AUTO_ROUTER
+
+        provider_mapping = _fetch_inference_provider_mapping(model)
+        provider = next(iter(provider_mapping)).provider
+
+    provider_tasks = PROVIDERS.get(provider)  # type: ignore
+    if provider_tasks is None:
+        raise ValueError(
+            f"Provider '{provider}' not supported. Available values: 'auto' or any provider from {list(PROVIDERS.keys())}."
+            "Passing 'auto' (default value) will automatically select the first provider available for the model, sorted "
+            "by the user's order in https://hf.co/settings/inference-providers."
+        )
+
+    if task not in provider_tasks:
         raise ValueError(
-            f"Task '{task}' not supported for provider '{provider}'. "
-            f"Available tasks: {list(PROVIDERS[provider].keys())}"
+            f"Task '{task}' not supported for provider '{provider}'. Available tasks: {list(provider_tasks.keys())}"
        )
-    return
+    return provider_tasks[task]
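The resolution order above (URL or missing model → `hf-inference`; no provider → `auto`; `auto` plus a conversational task → server-side router) can be exercised without any network call for the cases below. A minimal sketch; the model ID is illustrative, and the import path is the internal module changed in this diff:

```python
from huggingface_hub.inference._providers import get_provider_helper

# Explicit provider: direct lookup in the PROVIDERS registry.
helper = get_provider_helper("groq", task="conversational", model="meta-llama/Llama-3.1-8B-Instruct")

# No provider: falls back to the "auto" policy; for conversational tasks this
# returns the dedicated server-side auto-router without fetching any mapping.
helper = get_provider_helper(None, task="conversational", model="meta-llama/Llama-3.1-8B-Instruct")

# No provider and no model: resolved to "hf-inference", matching the 0.x default.
helper = get_provider_helper(None, task="text-to-image", model=None)
```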
huggingface_hub/inference/_providers/_common.py:

@@ -1,8 +1,10 @@
 from functools import lru_cache
-from typing import Any,
+from typing import Any, Optional, Union, overload
 
 from huggingface_hub import constants
-from huggingface_hub.
+from huggingface_hub.hf_api import InferenceProviderMapping
+from huggingface_hub.inference._common import MimeBytes, RequestParameters
+from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
 from huggingface_hub.utils import build_hf_headers, get_token, logging
 
 
@@ -12,24 +14,54 @@ logger = logging.get_logger(__name__)
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
 # for a given Inference Provider, you can add it to the following dictionary.
-
-# "HF model ID" => "Model ID on Inference Provider's side"
+HARDCODED_MODEL_INFERENCE_MAPPING: dict[str, dict[str, InferenceProviderMapping]] = {
+    # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side"
     #
     # Example:
-    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
+    # "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    #                                                             provider_id="Qwen2.5-Coder-32B-Instruct",
+    #                                                             task="conversational",
+    #                                                             status="live")
+    "cerebras": {},
+    "cohere": {},
+    "clarifai": {},
     "fal-ai": {},
     "fireworks-ai": {},
+    "groq": {},
     "hf-inference": {},
     "hyperbolic": {},
     "nebius": {},
+    "nscale": {},
     "replicate": {},
     "sambanova": {},
+    "scaleway": {},
     "together": {},
+    "wavespeed": {},
+    "zai-org": {},
 }
 
 
-
-
+@overload
+def filter_none(obj: dict[str, Any]) -> dict[str, Any]: ...
+@overload
+def filter_none(obj: list[Any]) -> list[Any]: ...
+
+
+def filter_none(obj: Union[dict[str, Any], list[Any]]) -> Union[dict[str, Any], list[Any]]:
+    if isinstance(obj, dict):
+        cleaned: dict[str, Any] = {}
+        for k, v in obj.items():
+            if v is None:
+                continue
+            if isinstance(v, (dict, list)):
+                v = filter_none(v)
+            cleaned[k] = v
+        return cleaned
+
+    if isinstance(obj, list):
+        return [filter_none(v) if isinstance(v, (dict, list)) else v for v in obj]
+
+    raise ValueError(f"Expected dict or list, got {type(obj)}")
 
 
 class TaskProviderHelper:
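The `filter_none` overloads above recurse into nested dicts and lists; note that `None` is dropped only from dict values, while scalar `None` elements inside lists are preserved. A quick illustration with made-up payload values:

```python
from huggingface_hub.inference._providers._common import filter_none

payload = {
    "messages": [{"role": "user", "content": "hi"}],
    "temperature": None,                                  # dropped (dict value)
    "response_format": {"type": "json", "schema": None},  # nested None dropped
    "stop": ["###", None],                                # None kept inside lists
}
print(filter_none(payload))
# {'messages': [{'role': 'user', 'content': 'hi'}],
#  'response_format': {'type': 'json'},
#  'stop': ['###', None]}
```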
@@ -44,11 +76,11 @@ class TaskProviderHelper:
         self,
         *,
         inputs: Any,
-        parameters:
-        headers:
+        parameters: dict[str, Any],
+        headers: dict,
         model: Optional[str],
         api_key: Optional[str],
-        extra_payload: Optional[
+        extra_payload: Optional[dict[str, Any]] = None,
     ) -> RequestParameters:
         """
         Prepare the request to be sent to the provider.
@@ -59,30 +91,45 @@ class TaskProviderHelper:
         api_key = self._prepare_api_key(api_key)
 
         # mapped model from HF model ID
-
+        provider_mapping_info = self._prepare_mapping_info(model)
 
         # default HF headers + user headers (to customize in subclasses)
         headers = self._prepare_headers(headers, api_key)
 
         # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses)
-        url = self._prepare_url(api_key,
+        url = self._prepare_url(api_key, provider_mapping_info.provider_id)
 
         # prepare payload (to customize in subclasses)
-        payload = self._prepare_payload_as_dict(inputs, parameters,
+        payload = self._prepare_payload_as_dict(inputs, parameters, provider_mapping_info=provider_mapping_info)
         if payload is not None:
-            payload = recursive_merge(payload, extra_payload or {})
+            payload = recursive_merge(payload, filter_none(extra_payload or {}))
 
         # body data (to customize in subclasses)
-        data = self._prepare_payload_as_bytes(inputs, parameters,
+        data = self._prepare_payload_as_bytes(inputs, parameters, provider_mapping_info, extra_payload)
 
         # check if both payload and data are set and return
         if payload is not None and data is not None:
             raise ValueError("Both payload and data cannot be set in the same request.")
         if payload is None and data is None:
             raise ValueError("Either payload or data must be set in the request.")
-        return RequestParameters(url=url, task=self.task, model=mapped_model, json=payload, data=data, headers=headers)
 
-
+        # normalize headers to lowercase and add content-type if not present
+        normalized_headers = self._normalize_headers(headers, payload, data)
+
+        return RequestParameters(
+            url=url,
+            task=self.task,
+            model=provider_mapping_info.provider_id,
+            json=payload,
+            data=data,
+            headers=normalized_headers,
+        )
+
+    def get_response(
+        self,
+        response: Union[bytes, dict],
+        request_params: Optional[RequestParameters] = None,
+    ) -> Any:
         """
         Return the response in the expected format.
 
@@ -97,11 +144,11 @@ class TaskProviderHelper:
         api_key = get_token()
         if api_key is None:
             raise ValueError(
-                f"You must provide an api_key to work with {self.provider} API or log in with `
+                f"You must provide an api_key to work with {self.provider} API or log in with `hf auth login`."
             )
         return api_key
 
-    def
+    def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
         """Return the mapped model ID to use for the request.
 
         Usually not overwritten in subclasses."""
@@ -109,10 +156,15 @@ class TaskProviderHelper:
             raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")
 
         # hardcoded mapping for local testing
-        if
-            return
+        if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model):
+            return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model]
+
+        provider_mapping = None
+        for mapping in _fetch_inference_provider_mapping(model):
+            if mapping.provider == self.provider:
+                provider_mapping = mapping
+                break
 
-        provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
         if provider_mapping is None:
             raise ValueError(f"Model {model} is not supported by provider {self.provider}.")
 
@@ -126,9 +178,29 @@ class TaskProviderHelper:
             logger.warning(
                 f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
             )
-
+        if provider_mapping.status == "error":
+            logger.warning(
+                f"Our latest automated health check on model '{model}' for provider '{self.provider}' did not complete successfully. "
+                "Inference call might fail."
+            )
+        return provider_mapping
+
+    def _normalize_headers(
+        self, headers: dict[str, Any], payload: Optional[dict[str, Any]], data: Optional[MimeBytes]
+    ) -> dict[str, Any]:
+        """Normalize the headers to use for the request.
 
-
+        Override this method in subclasses for customized headers.
+        """
+        normalized_headers = {key.lower(): value for key, value in headers.items() if value is not None}
+        if normalized_headers.get("content-type") is None:
+            if data is not None and data.mime_type is not None:
+                normalized_headers["content-type"] = data.mime_type
+            elif payload is not None:
+                normalized_headers["content-type"] = "application/json"
+        return normalized_headers
+
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         """Return the headers to use for the request.
 
         Override this method in subclasses for customized headers.
@@ -140,7 +212,7 @@ class TaskProviderHelper:
 
         Usually not overwritten in subclasses."""
         base_url = self._prepare_base_url(api_key)
-        route = self._prepare_route(mapped_model)
+        route = self._prepare_route(mapped_model, api_key)
         return f"{base_url.rstrip('/')}/{route.lstrip('/')}"
 
     def _prepare_base_url(self, api_key: str) -> str:
@@ -155,14 +227,16 @@ class TaskProviderHelper:
         logger.info(f"Calling '{self.provider}' provider directly.")
         return self.base_url
 
-    def _prepare_route(self, mapped_model: str) -> str:
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         """Return the route to use for the request.
 
         Override this method in subclasses for customized routes.
         """
         return ""
 
-    def _prepare_payload_as_dict(
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         """Return the payload to use for the request, as a dict.
 
         Override this method in subclasses for customized payloads.
@@ -171,8 +245,12 @@ class TaskProviderHelper:
         return None
 
     def _prepare_payload_as_bytes(
-        self,
-
+        self,
+        inputs: Any,
+        parameters: dict,
+        provider_mapping_info: InferenceProviderMapping,
+        extra_payload: Optional[dict],
+    ) -> Optional[MimeBytes]:
         """Return the body to use for the request, as bytes.
 
         Override this method in subclasses for customized body data.
@@ -190,11 +268,54 @@ class BaseConversationalTask(TaskProviderHelper):
     def __init__(self, provider: str, base_url: str):
         super().__init__(provider=provider, base_url=base_url, task="conversational")
 
-    def _prepare_route(self, mapped_model: str) -> str:
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return "/v1/chat/completions"
 
-    def _prepare_payload_as_dict(
-
+    def _prepare_payload_as_dict(
+        self,
+        inputs: list[Union[dict, ChatCompletionInputMessage]],
+        parameters: dict,
+        provider_mapping_info: InferenceProviderMapping,
+    ) -> Optional[dict]:
+        return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id})
+
+
+class AutoRouterConversationalTask(BaseConversationalTask):
+    """
+    Auto-router for conversational tasks.
+
+    We let the Hugging Face router select the best provider for the model, based on availability and user preferences.
+    This is a special case since the selection is done server-side (avoid 1 API call to fetch provider mapping).
+    """
+
+    def __init__(self):
+        super().__init__(provider="auto", base_url="https://router.huggingface.co")
+
+    def _prepare_base_url(self, api_key: str) -> str:
+        """Return the base URL to use for the request.
+
+        Usually not overwritten in subclasses."""
+        # Route to the proxy if the api_key is a HF TOKEN
+        if not api_key.startswith("hf_"):
+            raise ValueError("Cannot select auto-router when using non-Hugging Face API key.")
+        else:
+            return self.base_url  # No `/auto` suffix in the URL
+
+    def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
+        """
+        In auto-router, we don't need to fetch provider mapping info.
+        We just return a dummy mapping info with provider_id set to the HF model ID.
+        """
+        if model is None:
+            raise ValueError("Please provide an HF model ID.")
+
+        return InferenceProviderMapping(
+            provider="auto",
+            hf_model_id=model,
+            providerId=model,
+            status="live",
+            task="conversational",
+        )
 
 
 class BaseTextGenerationTask(TaskProviderHelper):
@@ -206,15 +327,17 @@ class BaseTextGenerationTask(TaskProviderHelper):
     def __init__(self, provider: str, base_url: str):
         super().__init__(provider=provider, base_url=base_url, task="text-generation")
 
-    def _prepare_route(self, mapped_model: str) -> str:
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return "/v1/completions"
 
-    def _prepare_payload_as_dict(
-
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
+        return filter_none({"prompt": inputs, **parameters, "model": provider_mapping_info.provider_id})
 
 
 @lru_cache(maxsize=None)
-def _fetch_inference_provider_mapping(model: str) -> Dict:
+def _fetch_inference_provider_mapping(model: str) -> list["InferenceProviderMapping"]:
     """
     Fetch provider mappings for a model from the Hub.
     """
@@ -227,7 +350,7 @@ def _fetch_inference_provider_mapping(model: str) -> Dict:
     return provider_mapping
 
 
-def recursive_merge(dict1:
+def recursive_merge(dict1: dict, dict2: dict) -> dict:
     return {
         **dict1,
         **{
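`_prepare_mapping_info` above consults `HARDCODED_MODEL_INFERENCE_MAPPING` before fetching from the Hub, which is the hook for testing a not-yet-registered model locally. A hedged sketch with illustrative model IDs, mirroring the `InferenceProviderMapping(...)` constructor call used by `AutoRouterConversationalTask` in this diff:

```python
from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._providers._common import HARDCODED_MODEL_INFERENCE_MAPPING

# Illustrative IDs: map an HF model ID to the provider-side model ID so that
# _prepare_mapping_info() resolves it without fetching the Hub mapping.
HARDCODED_MODEL_INFERENCE_MAPPING["together"]["my-org/my-model"] = InferenceProviderMapping(
    provider="together",
    hf_model_id="my-org/my-model",
    providerId="my-model-on-together",
    status="live",
    task="conversational",
)
```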
huggingface_hub/inference/_providers/black_forest_labs.py:

@@ -1,7 +1,8 @@
 import time
-from typing import Any,
+from typing import Any, Optional, Union
 
-from huggingface_hub.
+from huggingface_hub.hf_api import InferenceProviderMapping
+from huggingface_hub.inference._common import RequestParameters, _as_dict
 from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
 from huggingface_hub.utils import logging
 from huggingface_hub.utils._http import get_session
@@ -15,19 +16,21 @@ POLLING_INTERVAL = 1.0
 
 class BlackForestLabsTextToImageTask(TaskProviderHelper):
     def __init__(self):
-        super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai
+        super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image")
 
-    def _prepare_headers(self, headers:
+    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
         headers = super()._prepare_headers(headers, api_key)
         if not api_key.startswith("hf_"):
             _ = headers.pop("authorization")
             headers["X-Key"] = api_key
         return headers
 
-    def _prepare_route(self, mapped_model: str) -> str:
-        return mapped_model
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
+        return f"/v1/{mapped_model}"
 
-    def _prepare_payload_as_dict(
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[dict]:
         parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")
@@ -36,7 +39,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
 
         return {"prompt": inputs, **parameters}
 
-    def get_response(self, response: Union[bytes,
+    def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
         """
         Polling mechanism for Black Forest Labs since the API is asynchronous.
         """
@@ -47,7 +50,7 @@ class BlackForestLabsTextToImageTask(TaskProviderHelper):
 
         response = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
         response.raise_for_status()  # type: ignore
-        response_json:
+        response_json: dict = response.json()  # type: ignore
         status = response_json.get("status")
         logger.info(
             f"Polling generation result from {url}. Current status: {status}. "
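The `_prepare_headers` override above swaps authentication schemes depending on the key type. A sketch of the direct-key path; the key value is a placeholder, and the lowercase `"authorization"` key is assumed from the `pop()` call in the diff:

```python
from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask

task = BlackForestLabsTextToImageTask()

# Direct Black Forest Labs key: the HF "authorization" header is replaced by "X-Key".
headers = task._prepare_headers({}, "bfl-example-key")
assert "authorization" not in headers and headers["X-Key"] == "bfl-example-key"

# An HF token ("hf_...") instead keeps the standard Bearer header, and
# _prepare_base_url() routes the call through the Hugging Face proxy.
```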
huggingface_hub/inference/_providers/clarifai.py (new file):

@@ -0,0 +1,13 @@
+from ._common import BaseConversationalTask
+
+
+_PROVIDER = "clarifai"
+_BASE_URL = "https://api.clarifai.com"
+
+
+class ClarifaiConversationalTask(BaseConversationalTask):
+    def __init__(self):
+        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
+
+    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
+        return "/v2/ext/openai/v1/chat/completions"