huggingface-hub 0.28.1__py3-none-any.whl → 0.29.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +1 -4
- huggingface_hub/constants.py +16 -10
- huggingface_hub/file_download.py +10 -6
- huggingface_hub/hf_api.py +53 -23
- huggingface_hub/inference/_client.py +151 -84
- huggingface_hub/inference/_common.py +3 -27
- huggingface_hub/inference/_generated/_async_client.py +147 -83
- huggingface_hub/inference/_generated/types/__init__.py +1 -1
- huggingface_hub/inference/_generated/types/audio_classification.py +4 -5
- huggingface_hub/inference/_generated/types/audio_to_audio.py +3 -4
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +7 -8
- huggingface_hub/inference/_generated/types/base.py +21 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +29 -30
- huggingface_hub/inference/_generated/types/depth_estimation.py +3 -4
- huggingface_hub/inference/_generated/types/document_question_answering.py +5 -6
- huggingface_hub/inference/_generated/types/feature_extraction.py +5 -6
- huggingface_hub/inference/_generated/types/fill_mask.py +4 -5
- huggingface_hub/inference/_generated/types/image_classification.py +4 -5
- huggingface_hub/inference/_generated/types/image_segmentation.py +4 -5
- huggingface_hub/inference/_generated/types/image_to_image.py +5 -6
- huggingface_hub/inference/_generated/types/image_to_text.py +5 -6
- huggingface_hub/inference/_generated/types/object_detection.py +5 -6
- huggingface_hub/inference/_generated/types/question_answering.py +5 -6
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -4
- huggingface_hub/inference/_generated/types/summarization.py +4 -5
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -6
- huggingface_hub/inference/_generated/types/text2text_generation.py +4 -5
- huggingface_hub/inference/_generated/types/text_classification.py +4 -5
- huggingface_hub/inference/_generated/types/text_generation.py +12 -13
- huggingface_hub/inference/_generated/types/text_to_audio.py +5 -6
- huggingface_hub/inference/_generated/types/text_to_image.py +8 -15
- huggingface_hub/inference/_generated/types/text_to_speech.py +5 -6
- huggingface_hub/inference/_generated/types/text_to_video.py +4 -5
- huggingface_hub/inference/_generated/types/token_classification.py +4 -5
- huggingface_hub/inference/_generated/types/translation.py +4 -5
- huggingface_hub/inference/_generated/types/video_classification.py +4 -5
- huggingface_hub/inference/_generated/types/visual_question_answering.py +5 -6
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +4 -5
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +4 -5
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +5 -6
- huggingface_hub/inference/_providers/__init__.py +44 -8
- huggingface_hub/inference/_providers/_common.py +239 -0
- huggingface_hub/inference/_providers/black_forest_labs.py +66 -0
- huggingface_hub/inference/_providers/fal_ai.py +31 -100
- huggingface_hub/inference/_providers/fireworks_ai.py +6 -0
- huggingface_hub/inference/_providers/hf_inference.py +58 -142
- huggingface_hub/inference/_providers/hyperbolic.py +43 -0
- huggingface_hub/inference/_providers/nebius.py +41 -0
- huggingface_hub/inference/_providers/novita.py +26 -0
- huggingface_hub/inference/_providers/replicate.py +24 -119
- huggingface_hub/inference/_providers/sambanova.py +3 -86
- huggingface_hub/inference/_providers/together.py +36 -130
- huggingface_hub/utils/_headers.py +5 -0
- huggingface_hub/utils/_hf_folder.py +4 -32
- huggingface_hub/utils/_http.py +85 -2
- huggingface_hub/utils/_typing.py +1 -1
- huggingface_hub/utils/logging.py +6 -0
- {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/METADATA +1 -1
- {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/RECORD +63 -57
- {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.28.1.dist-info → huggingface_hub-0.29.0rc0.dist-info}/top_level.txt +0 -0
|
@@ -3,13 +3,12 @@
|
|
|
3
3
|
# See:
|
|
4
4
|
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
|
5
5
|
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
|
6
|
-
from dataclasses import dataclass
|
|
7
6
|
from typing import List, Optional
|
|
8
7
|
|
|
9
|
-
from .base import BaseInferenceType
|
|
8
|
+
from .base import BaseInferenceType, dataclass_with_extra
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
@
|
|
11
|
+
@dataclass_with_extra
|
|
13
12
|
class ZeroShotClassificationParameters(BaseInferenceType):
|
|
14
13
|
"""Additional inference parameters for Zero Shot Classification"""
|
|
15
14
|
|
|
@@ -26,7 +25,7 @@ class ZeroShotClassificationParameters(BaseInferenceType):
|
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
27
|
|
|
29
|
-
@
|
|
28
|
+
@dataclass_with_extra
|
|
30
29
|
class ZeroShotClassificationInput(BaseInferenceType):
|
|
31
30
|
"""Inputs for Zero Shot Classification inference"""
|
|
32
31
|
|
|
@@ -36,7 +35,7 @@ class ZeroShotClassificationInput(BaseInferenceType):
|
|
|
36
35
|
"""Additional inference parameters for Zero Shot Classification"""
|
|
37
36
|
|
|
38
37
|
|
|
39
|
-
@
|
|
38
|
+
@dataclass_with_extra
|
|
40
39
|
class ZeroShotClassificationOutputElement(BaseInferenceType):
|
|
41
40
|
"""Outputs of inference for the Zero Shot Classification task"""
|
|
42
41
|
|
|
@@ -3,13 +3,12 @@
|
|
|
3
3
|
# See:
|
|
4
4
|
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
|
5
5
|
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
|
6
|
-
from dataclasses import dataclass
|
|
7
6
|
from typing import List, Optional
|
|
8
7
|
|
|
9
|
-
from .base import BaseInferenceType
|
|
8
|
+
from .base import BaseInferenceType, dataclass_with_extra
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
@
|
|
11
|
+
@dataclass_with_extra
|
|
13
12
|
class ZeroShotImageClassificationParameters(BaseInferenceType):
|
|
14
13
|
"""Additional inference parameters for Zero Shot Image Classification"""
|
|
15
14
|
|
|
@@ -21,7 +20,7 @@ class ZeroShotImageClassificationParameters(BaseInferenceType):
|
|
|
21
20
|
"""
|
|
22
21
|
|
|
23
22
|
|
|
24
|
-
@
|
|
23
|
+
@dataclass_with_extra
|
|
25
24
|
class ZeroShotImageClassificationInput(BaseInferenceType):
|
|
26
25
|
"""Inputs for Zero Shot Image Classification inference"""
|
|
27
26
|
|
|
@@ -31,7 +30,7 @@ class ZeroShotImageClassificationInput(BaseInferenceType):
|
|
|
31
30
|
"""Additional inference parameters for Zero Shot Image Classification"""
|
|
32
31
|
|
|
33
32
|
|
|
34
|
-
@
|
|
33
|
+
@dataclass_with_extra
|
|
35
34
|
class ZeroShotImageClassificationOutputElement(BaseInferenceType):
|
|
36
35
|
"""Outputs of inference for the Zero Shot Image Classification task"""
|
|
37
36
|
|
|
@@ -3,13 +3,12 @@
|
|
|
3
3
|
# See:
|
|
4
4
|
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
|
5
5
|
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
|
6
|
-
from dataclasses import dataclass
|
|
7
6
|
from typing import List
|
|
8
7
|
|
|
9
|
-
from .base import BaseInferenceType
|
|
8
|
+
from .base import BaseInferenceType, dataclass_with_extra
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
@
|
|
11
|
+
@dataclass_with_extra
|
|
13
12
|
class ZeroShotObjectDetectionParameters(BaseInferenceType):
|
|
14
13
|
"""Additional inference parameters for Zero Shot Object Detection"""
|
|
15
14
|
|
|
@@ -17,7 +16,7 @@ class ZeroShotObjectDetectionParameters(BaseInferenceType):
|
|
|
17
16
|
"""The candidate labels for this image"""
|
|
18
17
|
|
|
19
18
|
|
|
20
|
-
@
|
|
19
|
+
@dataclass_with_extra
|
|
21
20
|
class ZeroShotObjectDetectionInput(BaseInferenceType):
|
|
22
21
|
"""Inputs for Zero Shot Object Detection inference"""
|
|
23
22
|
|
|
@@ -27,7 +26,7 @@ class ZeroShotObjectDetectionInput(BaseInferenceType):
|
|
|
27
26
|
"""Additional inference parameters for Zero Shot Object Detection"""
|
|
28
27
|
|
|
29
28
|
|
|
30
|
-
@
|
|
29
|
+
@dataclass_with_extra
|
|
31
30
|
class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
|
|
32
31
|
"""The predicted bounding box. Coordinates are relative to the top left corner of the input
|
|
33
32
|
image.
|
|
@@ -39,7 +38,7 @@ class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
|
|
|
39
38
|
ymin: int
|
|
40
39
|
|
|
41
40
|
|
|
42
|
-
@
|
|
41
|
+
@dataclass_with_extra
|
|
43
42
|
class ZeroShotObjectDetectionOutputElement(BaseInferenceType):
|
|
44
43
|
"""Outputs of inference for the Zero Shot Object Detection task"""
|
|
45
44
|
|
|
@@ -1,27 +1,49 @@
|
|
|
1
1
|
from typing import Dict, Literal
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from .
|
|
3
|
+
from ._common import TaskProviderHelper
|
|
4
|
+
from .black_forest_labs import BlackForestLabsTextToImageTask
|
|
5
|
+
from .fal_ai import (
|
|
6
|
+
FalAIAutomaticSpeechRecognitionTask,
|
|
7
|
+
FalAITextToImageTask,
|
|
8
|
+
FalAITextToSpeechTask,
|
|
9
|
+
FalAITextToVideoTask,
|
|
10
|
+
)
|
|
11
|
+
from .fireworks_ai import FireworksAIConversationalTask
|
|
5
12
|
from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask
|
|
13
|
+
from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
|
|
14
|
+
from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask
|
|
15
|
+
from .novita import NovitaConversationalTask, NovitaTextGenerationTask
|
|
6
16
|
from .replicate import ReplicateTask, ReplicateTextToSpeechTask
|
|
7
17
|
from .sambanova import SambanovaConversationalTask
|
|
8
|
-
from .together import TogetherTextGenerationTask, TogetherTextToImageTask
|
|
18
|
+
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
|
|
9
19
|
|
|
10
20
|
|
|
11
21
|
PROVIDER_T = Literal[
|
|
22
|
+
"black-forest-labs",
|
|
12
23
|
"fal-ai",
|
|
24
|
+
"fireworks-ai",
|
|
13
25
|
"hf-inference",
|
|
26
|
+
"hyperbolic",
|
|
27
|
+
"nebius",
|
|
28
|
+
"novita",
|
|
14
29
|
"replicate",
|
|
15
30
|
"sambanova",
|
|
16
31
|
"together",
|
|
17
32
|
]
|
|
18
33
|
|
|
19
34
|
PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
|
|
35
|
+
"black-forest-labs": {
|
|
36
|
+
"text-to-image": BlackForestLabsTextToImageTask(),
|
|
37
|
+
},
|
|
20
38
|
"fal-ai": {
|
|
21
|
-
"text-to-image": FalAITextToImageTask(),
|
|
22
39
|
"automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
|
|
40
|
+
"text-to-image": FalAITextToImageTask(),
|
|
41
|
+
"text-to-speech": FalAITextToSpeechTask(),
|
|
23
42
|
"text-to-video": FalAITextToVideoTask(),
|
|
24
43
|
},
|
|
44
|
+
"fireworks-ai": {
|
|
45
|
+
"conversational": FireworksAIConversationalTask(),
|
|
46
|
+
},
|
|
25
47
|
"hf-inference": {
|
|
26
48
|
"text-to-image": HFInferenceTask("text-to-image"),
|
|
27
49
|
"conversational": HFInferenceConversational(),
|
|
@@ -35,9 +57,9 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
|
|
|
35
57
|
"image-classification": HFInferenceBinaryInputTask("image-classification"),
|
|
36
58
|
"image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
|
|
37
59
|
"document-question-answering": HFInferenceTask("document-question-answering"),
|
|
38
|
-
"image-to-text":
|
|
60
|
+
"image-to-text": HFInferenceBinaryInputTask("image-to-text"),
|
|
39
61
|
"object-detection": HFInferenceBinaryInputTask("object-detection"),
|
|
40
|
-
"audio-to-audio":
|
|
62
|
+
"audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
|
|
41
63
|
"zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
|
|
42
64
|
"zero-shot-classification": HFInferenceTask("zero-shot-classification"),
|
|
43
65
|
"image-to-image": HFInferenceBinaryInputTask("image-to-image"),
|
|
@@ -50,6 +72,20 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
|
|
|
50
72
|
"summarization": HFInferenceTask("summarization"),
|
|
51
73
|
"visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"),
|
|
52
74
|
},
|
|
75
|
+
"hyperbolic": {
|
|
76
|
+
"text-to-image": HyperbolicTextToImageTask(),
|
|
77
|
+
"conversational": HyperbolicTextGenerationTask("conversational"),
|
|
78
|
+
"text-generation": HyperbolicTextGenerationTask("text-generation"),
|
|
79
|
+
},
|
|
80
|
+
"nebius": {
|
|
81
|
+
"text-to-image": NebiusTextToImageTask(),
|
|
82
|
+
"conversational": NebiusConversationalTask(),
|
|
83
|
+
"text-generation": NebiusTextGenerationTask(),
|
|
84
|
+
},
|
|
85
|
+
"novita": {
|
|
86
|
+
"text-generation": NovitaTextGenerationTask(),
|
|
87
|
+
"conversational": NovitaConversationalTask(),
|
|
88
|
+
},
|
|
53
89
|
"replicate": {
|
|
54
90
|
"text-to-image": ReplicateTask("text-to-image"),
|
|
55
91
|
"text-to-speech": ReplicateTextToSpeechTask(),
|
|
@@ -60,8 +96,8 @@ PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
|
|
|
60
96
|
},
|
|
61
97
|
"together": {
|
|
62
98
|
"text-to-image": TogetherTextToImageTask(),
|
|
63
|
-
"conversational":
|
|
64
|
-
"text-generation": TogetherTextGenerationTask(
|
|
99
|
+
"conversational": TogetherConversationalTask(),
|
|
100
|
+
"text-generation": TogetherTextGenerationTask(),
|
|
65
101
|
},
|
|
66
102
|
}
|
|
67
103
|
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from typing import Any, Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from huggingface_hub import constants
|
|
5
|
+
from huggingface_hub.inference._common import RequestParameters
|
|
6
|
+
from huggingface_hub.utils import build_hf_headers, get_token, logging
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
logger = logging.get_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Dev purposes only.
# If you want to try to run inference for a new model locally before it's registered on huggingface.co
# for a given Inference Provider, you can add it to the following dictionary.
#
# The outer key is the provider name, the inner dictionary maps
# "HF model ID" => "Model ID on Inference Provider's side".
# Every provider gets an (empty) slot so that a local override can be added without first
# having to create the provider entry.
HARDCODED_MODEL_ID_MAPPING: Dict[str, Dict[str, str]] = {
    # Example:
    # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
    "black-forest-labs": {},
    "fal-ai": {},
    "fireworks-ai": {},
    "hf-inference": {},
    "hyperbolic": {},
    "nebius": {},
    "novita": {},
    "replicate": {},
    "sambanova": {},
    "together": {},
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict holding the items of `d` whose value is not `None`.

    Only `None` is dropped: other falsy values (`0`, `""`, `[]`, `False`) are kept.
    The input dict is left untouched.
    """
    kept: Dict[str, Any] = {}
    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value
    return kept
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class TaskProviderHelper:
    """Base class for task-specific provider helpers.

    A helper turns the arguments of an inference call into a `RequestParameters`
    object. The work is split into small `_prepare_*` hooks so that a subclass only
    has to override the steps that differ for its provider/task.
    """

    def __init__(self, provider: str, base_url: str, task: str) -> None:
        self.provider = provider
        self.base_url = base_url
        self.task = task

    def prepare_request(
        self,
        *,
        inputs: Any,
        parameters: Dict[str, Any],
        headers: Dict,
        model: Optional[str],
        api_key: Optional[str],
        extra_payload: Optional[Dict[str, Any]] = None,
    ) -> RequestParameters:
        """
        Prepare the request to be sent to the provider.

        The hooks run in a fixed order: api_key, mapped model, headers, url, JSON payload, raw body.
        Exactly one of the JSON payload and the raw body must be produced by the hooks.

        Raises:
            ValueError: if the hooks produce both a JSON payload and a raw body, or neither
                (on top of whatever the individual hooks raise).
        """
        # api_key from user, or local token, or raise error
        resolved_key = self._prepare_api_key(api_key)

        # mapped model from HF model ID
        provider_model = self._prepare_mapped_model(model)

        # default HF headers + user headers
        request_headers = self._prepare_headers(headers, resolved_key)

        # routed URL if HF token, or direct URL
        target_url = self._prepare_url(resolved_key, provider_model)

        # JSON payload; `extra_payload` is merged on top of it only when a payload exists
        json_body = self._prepare_payload_as_dict(inputs, parameters, mapped_model=provider_model)
        if json_body is not None:
            json_body = recursive_merge(json_body, extra_payload or {})

        # raw body data
        raw_body = self._prepare_payload_as_bytes(inputs, parameters, provider_model, extra_payload)

        has_json = json_body is not None
        has_raw = raw_body is not None
        if has_json and has_raw:
            raise ValueError("Both payload and data cannot be set in the same request.")
        if not (has_json or has_raw):
            raise ValueError("Either payload or data must be set in the request.")

        return RequestParameters(
            url=target_url,
            task=self.task,
            model=provider_model,
            json=json_body,
            data=raw_body,
            headers=request_headers,
        )

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        """
        Return the response in the expected format.

        The base implementation hands `response` back unchanged; subclasses override it
        when the provider's answer needs post-processing.
        """
        return response

    def _prepare_api_key(self, api_key: Optional[str]) -> str:
        """Return the API key to use for the request.

        An explicit `api_key` wins; otherwise the value of `get_token()` is used.

        Raises:
            ValueError: if neither source yields a key.
        """
        token = api_key if api_key is not None else get_token()
        if token is None:
            raise ValueError(
                f"You must provide an api_key to work with {self.provider} API or log in with `huggingface-cli login`."
            )
        return token

    def _prepare_mapped_model(self, model: Optional[str]) -> str:
        """Return the provider-side model ID matching the HF model ID `model`.

        Raises:
            ValueError: if `model` is None, if the provider has no mapping for it, or if
                the mapping is registered for a different task than this helper's.
        """
        if model is None:
            raise ValueError(f"Please provide an HF model ID supported by {self.provider}.")

        # hardcoded mapping for local testing takes precedence over the fetched mapping
        hardcoded = HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model)
        if hardcoded:
            return hardcoded

        mapping = _fetch_inference_provider_mapping(model).get(self.provider)
        if mapping is None:
            raise ValueError(f"Model {model} is not supported by provider {self.provider}.")

        if mapping.task != self.task:
            raise ValueError(
                f"Model {model} is not supported for task {self.task} and provider {self.provider}. "
                f"Supported task: {mapping.task}."
            )

        # staging mappings are still usable, but the user gets a warning
        if mapping.status == "staging":
            logger.warning(
                f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only."
            )
        return mapping.provider_id

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Return the headers to use for the request.

        Starts from `build_hf_headers(token=api_key)` and lets the entries of `headers`
        override them. Override this method in subclasses for customized headers.
        """
        merged = dict(build_hf_headers(token=api_key))
        merged.update(headers)
        return merged

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        """Return the URL to use for the request.

        Joins the base URL and the route with exactly one "/" between them.
        """
        base = self._prepare_base_url(api_key)
        route = self._prepare_route(mapped_model)
        return "/".join((base.rstrip("/"), route.lstrip("/")))

    def _prepare_base_url(self, api_key: str) -> str:
        """Return the base URL to use for the request.

        Keys starting with "hf_" are treated as HF tokens and routed through the proxy
        template from `constants`; any other key targets `self.base_url` directly.
        """
        if not api_key.startswith("hf_"):
            logger.info(f"Calling '{self.provider}' provider directly.")
            return self.base_url

        logger.info(f"Calling '{self.provider}' provider through Hugging Face router.")
        return constants.INFERENCE_PROXY_TEMPLATE.format(provider=self.provider)

    def _prepare_route(self, mapped_model: str) -> str:
        """Return the route to append to the base URL (empty by default).

        Override this method in subclasses for customized routes.
        """
        return ""

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Return the payload to use for the request, as a dict (None by default).

        Override this method in subclasses for customized payloads.
        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
        return None

    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        """Return the body to use for the request, as bytes (None by default).

        Override this method in subclasses for customized body data.
        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
        """
        return None
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class BaseConversationalTask(TaskProviderHelper):
    """
    Base helper for the "conversational" (chat completion) task.

    Requests go to the "/v1/chat/completions" route.
    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat
    """

    def __init__(self, provider: str, base_url: str):
        super().__init__(task="conversational", provider=provider, base_url=base_url)

    def _prepare_route(self, mapped_model: str) -> str:
        """Return the chat-completion route; it does not depend on the model."""
        return "/v1/chat/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Build the JSON payload: messages, then the non-None parameters, then the model.

        A "messages" entry in `parameters` replaces `inputs`; "model" is always `mapped_model`.
        """
        payload: Dict[str, Any] = {"messages": inputs}
        payload.update(filter_none(parameters))
        payload["model"] = mapped_model
        return payload
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class BaseTextGenerationTask(TaskProviderHelper):
    """
    Base helper for the "text-generation" (completion) task.

    Requests go to the "/v1/completions" route.
    The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions
    """

    def __init__(self, provider: str, base_url: str):
        super().__init__(task="text-generation", provider=provider, base_url=base_url)

    def _prepare_route(self, mapped_model: str) -> str:
        """Return the completion route; it does not depend on the model."""
        return "/v1/completions"

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Build the JSON payload: prompt, then the non-None parameters, then the model.

        A "prompt" entry in `parameters` replaces `inputs`; "model" is always `mapped_model`.
        """
        payload: Dict[str, Any] = {"prompt": inputs}
        payload.update(filter_none(parameters))
        payload["model"] = mapped_model
        return payload
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@lru_cache(maxsize=None)
def _fetch_inference_provider_mapping(model: str) -> Dict:
    """
    Fetch provider mappings for a model from the Hub.

    Calls `HfApi().model_info` with `expand=["inferenceProviderMapping"]` and returns the
    `inference_provider_mapping` attribute of the result. Successful results are memoized
    per `model` by `lru_cache`, with no size limit.

    Raises:
        ValueError: if the model info carries no provider mapping.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from huggingface_hub.hf_api import HfApi

    mapping = HfApi().model_info(model, expand=["inferenceProviderMapping"]).inference_provider_mapping
    if mapping is not None:
        return mapping
    raise ValueError(f"No provider mapping found for model {model}")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def recursive_merge(dict1: Dict, dict2: Dict) -> Dict:
    """Return a new dict holding `dict1` updated with `dict2`, merging nested dicts.

    When a key is present in both inputs and both values are dicts, the two values are
    merged recursively. In every other case the value from `dict2` replaces the one
    from `dict1`. Neither input is modified.
    """
    merged = dict(dict1)
    for key, incoming in dict2.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merged[key] = recursive_merge(existing, incoming)
        else:
            merged[key] = incoming
    return merged
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import Any, Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from huggingface_hub.inference._common import _as_dict
|
|
5
|
+
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
|
|
6
|
+
from huggingface_hub.utils import logging
|
|
7
|
+
from huggingface_hub.utils._http import get_session
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
logger = logging.get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
MAX_POLLING_ATTEMPTS = 6
|
|
13
|
+
POLLING_INTERVAL = 1.0
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BlackForestLabsTextToImageTask(TaskProviderHelper):
    """Helper for the "text-to-image" task of the "black-forest-labs" provider."""

    def __init__(self):
        super().__init__(task="text-to-image", provider="black-forest-labs", base_url="https://api.us1.bfl.ai/v1")

    def _prepare_headers(self, headers: Dict, api_key: str) -> Dict:
        """Build the default headers; for keys not starting with "hf_", send the key as `X-Key`.

        In that case the `authorization` entry is removed and replaced by `X-Key`.
        """
        prepared = super()._prepare_headers(headers, api_key)
        if api_key.startswith("hf_"):
            return prepared
        prepared.pop("authorization")
        prepared["X-Key"] = api_key
        return prepared

    def _prepare_route(self, mapped_model: str) -> str:
        """Use the mapped model ID itself as the route."""
        return mapped_model

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        """Build the JSON payload: the prompt plus the non-None parameters.

        `num_inference_steps` is sent as `steps` and `guidance_scale` as `guidance`.
        """
        cleaned = filter_none(parameters)
        for source_name, target_name in (("num_inference_steps", "steps"), ("guidance_scale", "guidance")):
            if source_name in cleaned:
                cleaned[target_name] = cleaned.pop(source_name)
        return {"prompt": inputs, **cleaned}

    def get_response(self, response: Union[bytes, Dict]) -> Any:
        """
        Polling mechanism for Black Forest Labs since the API is asynchronous.

        Reads "polling_url" from the initial response, then polls it up to
        `MAX_POLLING_ATTEMPTS` times, sleeping `POLLING_INTERVAL` seconds before each
        attempt. Once the status is "Ready" and a sample URL is present, the sample is
        downloaded and its raw content returned.

        Raises:
            TimeoutError: if no sample URL is obtained within the allowed attempts.
        """
        url = _as_dict(response).get("polling_url")
        session = get_session()

        attempts_left = MAX_POLLING_ATTEMPTS
        while attempts_left > 0:
            attempts_left -= 1
            time.sleep(POLLING_INTERVAL)

            poll_resp = session.get(url, headers={"Content-Type": "application/json"})  # type: ignore
            poll_resp.raise_for_status()
            response_json: Dict = poll_resp.json()
            status = response_json.get("status")
            logger.info(
                f"Polling generation result from {url}. Current status: {status}. "
                f"Will retry after {POLLING_INTERVAL} seconds if not ready."
            )

            # Anything but a "Ready" status carrying a result dict with a sample URL means: poll again.
            if status != "Ready":
                continue
            result = response_json.get("result")
            if not isinstance(result, dict):
                continue
            sample_url = result.get("sample")
            if not sample_url:
                continue

            image_resp = session.get(sample_url)
            image_resp.raise_for_status()
            return image_resp.content

        raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.")
|