huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of huggingface-hub might be problematic. Click here for more details.
- huggingface_hub/__init__.py +217 -1
- huggingface_hub/_commit_api.py +14 -15
- huggingface_hub/_inference_endpoints.py +12 -11
- huggingface_hub/_login.py +1 -0
- huggingface_hub/_multi_commits.py +1 -0
- huggingface_hub/_snapshot_download.py +9 -1
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +1 -0
- huggingface_hub/_webhooks_server.py +1 -0
- huggingface_hub/commands/_cli_utils.py +1 -0
- huggingface_hub/commands/delete_cache.py +1 -0
- huggingface_hub/commands/download.py +1 -0
- huggingface_hub/commands/env.py +1 -0
- huggingface_hub/commands/scan_cache.py +1 -0
- huggingface_hub/commands/upload.py +1 -0
- huggingface_hub/community.py +1 -0
- huggingface_hub/constants.py +3 -1
- huggingface_hub/errors.py +38 -0
- huggingface_hub/file_download.py +102 -95
- huggingface_hub/hf_api.py +47 -35
- huggingface_hub/hf_file_system.py +77 -3
- huggingface_hub/hub_mixin.py +215 -54
- huggingface_hub/inference/_client.py +554 -239
- huggingface_hub/inference/_common.py +195 -41
- huggingface_hub/inference/_generated/_async_client.py +558 -239
- huggingface_hub/inference/_generated/types/__init__.py +115 -0
- huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
- huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
- huggingface_hub/inference/_generated/types/base.py +149 -0
- huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
- huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
- huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
- huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
- huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
- huggingface_hub/inference/_generated/types/image_classification.py +43 -0
- huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
- huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
- huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
- huggingface_hub/inference/_generated/types/object_detection.py +55 -0
- huggingface_hub/inference/_generated/types/question_answering.py +77 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
- huggingface_hub/inference/_generated/types/summarization.py +46 -0
- huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
- huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
- huggingface_hub/inference/_generated/types/text_classification.py +43 -0
- huggingface_hub/inference/_generated/types/text_generation.py +161 -0
- huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
- huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
- huggingface_hub/inference/_generated/types/token_classification.py +53 -0
- huggingface_hub/inference/_generated/types/translation.py +46 -0
- huggingface_hub/inference/_generated/types/video_classification.py +47 -0
- huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
- huggingface_hub/inference/_templating.py +105 -0
- huggingface_hub/inference/_types.py +4 -152
- huggingface_hub/keras_mixin.py +39 -17
- huggingface_hub/lfs.py +20 -8
- huggingface_hub/repocard.py +11 -3
- huggingface_hub/repocard_data.py +12 -2
- huggingface_hub/serialization/__init__.py +1 -0
- huggingface_hub/serialization/_base.py +1 -0
- huggingface_hub/serialization/_numpy.py +1 -0
- huggingface_hub/serialization/_tensorflow.py +1 -0
- huggingface_hub/serialization/_torch.py +1 -0
- huggingface_hub/utils/__init__.py +4 -1
- huggingface_hub/utils/_cache_manager.py +7 -0
- huggingface_hub/utils/_chunk_utils.py +1 -0
- huggingface_hub/utils/_datetime.py +1 -0
- huggingface_hub/utils/_errors.py +10 -1
- huggingface_hub/utils/_experimental.py +1 -0
- huggingface_hub/utils/_fixes.py +19 -3
- huggingface_hub/utils/_git_credential.py +1 -0
- huggingface_hub/utils/_headers.py +10 -3
- huggingface_hub/utils/_hf_folder.py +1 -0
- huggingface_hub/utils/_http.py +1 -0
- huggingface_hub/utils/_pagination.py +1 -0
- huggingface_hub/utils/_paths.py +1 -0
- huggingface_hub/utils/_runtime.py +22 -0
- huggingface_hub/utils/_subprocess.py +1 -0
- huggingface_hub/utils/_token.py +1 -0
- huggingface_hub/utils/_typing.py +29 -1
- huggingface_hub/utils/_validators.py +1 -0
- huggingface_hub/utils/endpoint_helpers.py +1 -0
- huggingface_hub/utils/logging.py +1 -1
- huggingface_hub/utils/sha.py +1 -0
- huggingface_hub/utils/tqdm.py +1 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/METADATA +14 -15
- huggingface_hub-0.22.0.dist-info/RECORD +113 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/WHEEL +1 -1
- huggingface_hub/inference/_text_generation.py +0 -551
- huggingface_hub-0.21.4.dist-info/RECORD +0 -81
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Inference code generated from the JSON schema spec in @huggingface/tasks.
|
|
2
|
+
#
|
|
3
|
+
# See:
|
|
4
|
+
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
|
|
5
|
+
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
from .base import BaseInferenceType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class ZeroShotObjectDetectionInputData(BaseInferenceType):
    """The input image data, with candidate labels"""

    # The candidate labels for this image
    candidate_labels: List[str]
    # The image data to generate bounding boxes from
    image: Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class ZeroShotObjectDetectionInput(BaseInferenceType):
    """Inputs for Zero Shot Object Detection inference"""

    # The input image data, with candidate labels
    inputs: ZeroShotObjectDetectionInputData
    # Additional inference parameters (optional, defaults to None)
    parameters: Optional[Dict[str, Any]] = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """

    # Field order is part of the generated constructor signature: xmax, xmin, ymax, ymin.
    xmax: int
    xmin: int
    ymax: int
    ymin: int
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class ZeroShotObjectDetectionOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Object Detection task"""

    # The predicted bounding box. Coordinates are relative to the top left corner of the
    # input image.
    box: ZeroShotObjectDetectionBoundingBox
    # A candidate label
    label: str
    # The associated score / probability
    score: float
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from typing import Callable, Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from ..utils import HfHubHTTPError, RepositoryNotFoundError, is_minijinja_available
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TemplateError(Exception):
    """Raised when a chat template cannot be fetched, compiled or rendered."""
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _import_minijinja():
    """Return the `minijinja` module, importing it lazily.

    Raises:
        `ImportError`: If `is_minijinja_available()` reports that minijinja is not installed.
    """
    if is_minijinja_available():
        import minijinja  # noqa: F401

        return minijinja

    raise ImportError("Cannot render template. Please install minijinja using `pip install minijinja`.")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def render_chat_prompt(
    *,
    model_id: str,
    messages: List[Dict[str, str]],
    token: Union[str, bool, None] = None,
    add_generation_prompt: bool = True,
    **kwargs,
) -> str:
    """Build the prompt for a conversation from the chat template of a model.

    Args:
        model_id (`str`):
            The model id.
        messages (`List[Dict[str, str]]`):
            The list of messages to render.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
        add_generation_prompt (`bool`, *optional*):
            Passed to the template as the `add_generation_prompt` variable. Defaults to `True`.
        **kwargs:
            Extra variables passed as-is to the template.

    Returns:
        `str`: The rendered chat prompt.

    Raises:
        `TemplateError`: If there's any issue while fetching, compiling or rendering the chat template.
    """
    # Resolve the dependency first so that a missing install is reported before the template is fetched.
    minijinja = _import_minijinja()
    compiled_template = _fetch_and_compile_template(model_id=model_id, token=token)

    try:
        prompt = compiled_template(messages=messages, add_generation_prompt=add_generation_prompt, **kwargs)
    except minijinja.TemplateError as err:
        raise TemplateError(f"Error while trying to render chat prompt for model '{model_id}': {err}") from err
    return prompt
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@lru_cache  # TODO: lru_cache for raised exceptions
def _fetch_and_compile_template(*, model_id: str, token: Union[str, None]) -> Callable:
    """Fetch and compile a model's chat template.

    Method is cached to avoid fetching the same model's config multiple times.

    Args:
        model_id (`str`):
            The model id.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.

    Returns:
        `Callable`: A callable that takes the template variables as keyword arguments (e.g. `messages`) and
        returns the rendered chat prompt. Special tokens found in the tokenizer config are made available to
        the template; a keyword argument with the same name as a special token overrides it.

    Raises:
        `TemplateError`: If the model or its config cannot be fetched, if the tokenizer config has no usable
        chat template, or if the template fails to compile.
        `ImportError`: If minijinja is not installed.
    """
    from huggingface_hub.hf_api import HfApi

    minijinja = _import_minijinja()

    # 1. fetch config from API
    try:
        config = HfApi(token=token).model_info(model_id).config
    except RepositoryNotFoundError as e:
        raise TemplateError(f"Cannot render chat template: model '{model_id}' not found.") from e
    except HfHubHTTPError as e:
        raise TemplateError(f"Error while trying to fetch chat template for model '{model_id}': {e}") from e

    # 2. check config validity
    if config is None:
        raise TemplateError(f"Config not found for model '{model_id}'.")
    tokenizer_config = config.get("tokenizer_config")
    if tokenizer_config is None:
        raise TemplateError(f"Tokenizer config not found for model '{model_id}'.")
    if tokenizer_config.get("chat_template") is None:
        raise TemplateError(f"Chat template not found in tokenizer_config for model '{model_id}'.")
    chat_template = tokenizer_config["chat_template"]
    if not isinstance(chat_template, str):
        raise TemplateError(f"Chat template must be a string, not '{type(chat_template)}' (model: {model_id}).")

    # Collect every "*token*" entry of the tokenizer config so the template can reference them
    # (plain strings are used as-is, serialized `AddedToken` dicts contribute their "content").
    special_tokens: Dict[str, Optional[str]] = {}
    for key, value in tokenizer_config.items():
        if "token" in key:
            if isinstance(value, str):
                special_tokens[key] = value
            elif isinstance(value, dict) and value.get("__type") == "AddedToken":
                special_tokens[key] = value.get("content")

    # 3. compile template and return
    env = minijinja.Environment()
    try:
        env.add_template("chat_template", chat_template)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to compile chat template for model '{model_id}': {e}") from e

    def _render(**kwargs) -> str:
        # Merge into a single mapping instead of unpacking both dicts in the call: unpacking
        # `**kwargs, **special_tokens` raises TypeError as soon as one key is present in both.
        # Caller-provided values are listed last so they take precedence over the config.
        return env.render_template("chat_template", **{**special_tokens, **kwargs})

    return _render
|
|
@@ -12,42 +12,13 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
from typing import TYPE_CHECKING, List, TypedDict
|
|
16
15
|
|
|
16
|
+
from typing import List, TypedDict
|
|
17
17
|
|
|
18
|
-
if TYPE_CHECKING:
|
|
19
|
-
from PIL import Image
|
|
20
18
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
label (`str`):
|
|
27
|
-
The label of the audio file.
|
|
28
|
-
content-type (`str`):
|
|
29
|
-
The content type of audio file.
|
|
30
|
-
blob (`bytes`):
|
|
31
|
-
The audio file in byte format.
|
|
32
|
-
"""
|
|
33
|
-
|
|
34
|
-
label: str
|
|
35
|
-
content_type: str
|
|
36
|
-
blob: bytes
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class ClassificationOutput(TypedDict):
|
|
40
|
-
"""Dictionary containing the output of a [`~InferenceClient.audio_classification`] and [`~InferenceClient.image_classification`] task.
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
label (`str`):
|
|
44
|
-
The label predicted by the model.
|
|
45
|
-
score (`float`):
|
|
46
|
-
The score of the label predicted by the model.
|
|
47
|
-
"""
|
|
48
|
-
|
|
49
|
-
label: str
|
|
50
|
-
score: float
|
|
19
|
+
# Legacy types
|
|
20
|
+
# Types are now generated from the JSON schema spec in @huggingface/tasks.
|
|
21
|
+
# See ./src/huggingface_hub/inference/_generated/types
|
|
51
22
|
|
|
52
23
|
|
|
53
24
|
class ConversationalOutputConversation(TypedDict):
|
|
@@ -79,122 +50,3 @@ class ConversationalOutput(TypedDict):
|
|
|
79
50
|
conversation: ConversationalOutputConversation
|
|
80
51
|
generated_text: str
|
|
81
52
|
warnings: List[str]
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
class FillMaskOutput(TypedDict):
|
|
85
|
-
"""Dictionary containing information about a [`~InferenceClient.fill_mask`] task.
|
|
86
|
-
|
|
87
|
-
Args:
|
|
88
|
-
score (`float`):
|
|
89
|
-
The probability of the token.
|
|
90
|
-
token (`int`):
|
|
91
|
-
The id of the token.
|
|
92
|
-
token_str (`str`):
|
|
93
|
-
The string representation of the token.
|
|
94
|
-
sequence (`str`):
|
|
95
|
-
The actual sequence of tokens that ran against the model (may contain special tokens).
|
|
96
|
-
"""
|
|
97
|
-
|
|
98
|
-
score: float
|
|
99
|
-
token: int
|
|
100
|
-
token_str: str
|
|
101
|
-
sequence: str
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
class ImageSegmentationOutput(TypedDict):
|
|
105
|
-
"""Dictionary containing information about a [`~InferenceClient.image_segmentation`] task. In practice, image segmentation returns a
|
|
106
|
-
list of `ImageSegmentationOutput` with 1 item per mask.
|
|
107
|
-
|
|
108
|
-
Args:
|
|
109
|
-
label (`str`):
|
|
110
|
-
The label corresponding to the mask.
|
|
111
|
-
mask (`Image`):
|
|
112
|
-
An Image object representing the mask predicted by the model.
|
|
113
|
-
score (`float`):
|
|
114
|
-
The score associated with the label for this mask.
|
|
115
|
-
"""
|
|
116
|
-
|
|
117
|
-
label: str
|
|
118
|
-
mask: "Image"
|
|
119
|
-
score: float
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
class ObjectDetectionOutput(TypedDict):
|
|
123
|
-
"""Dictionary containing information about a [`~InferenceClient.object_detection`] task.
|
|
124
|
-
|
|
125
|
-
Args:
|
|
126
|
-
label (`str`):
|
|
127
|
-
The label corresponding to the detected object.
|
|
128
|
-
box (`dict`):
|
|
129
|
-
A dict response of bounding box coordinates of
|
|
130
|
-
the detected object: xmin, ymin, xmax, ymax
|
|
131
|
-
score (`float`):
|
|
132
|
-
The score corresponding to the detected object.
|
|
133
|
-
"""
|
|
134
|
-
|
|
135
|
-
label: str
|
|
136
|
-
box: dict
|
|
137
|
-
score: float
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class QuestionAnsweringOutput(TypedDict):
|
|
141
|
-
"""Dictionary containing information about a [`~InferenceClient.question_answering`] task.
|
|
142
|
-
|
|
143
|
-
Args:
|
|
144
|
-
score (`float`):
|
|
145
|
-
A float that represents how likely that the answer is correct.
|
|
146
|
-
start (`int`):
|
|
147
|
-
The index (string wise) of the start of the answer within context.
|
|
148
|
-
end (`int`):
|
|
149
|
-
The index (string wise) of the end of the answer within context.
|
|
150
|
-
answer (`str`):
|
|
151
|
-
A string that is the answer within the text.
|
|
152
|
-
"""
|
|
153
|
-
|
|
154
|
-
score: float
|
|
155
|
-
start: int
|
|
156
|
-
end: int
|
|
157
|
-
answer: str
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
class TableQuestionAnsweringOutput(TypedDict):
|
|
161
|
-
"""Dictionary containing information about a [`~InferenceClient.table_question_answering`] task.
|
|
162
|
-
|
|
163
|
-
Args:
|
|
164
|
-
answer (`str`):
|
|
165
|
-
The plaintext answer.
|
|
166
|
-
coordinates (`List[List[int]]`):
|
|
167
|
-
A list of coordinates of the cells referenced in the answer.
|
|
168
|
-
cells (`List[int]`):
|
|
169
|
-
A list of coordinates of the cells contents.
|
|
170
|
-
aggregator (`str`):
|
|
171
|
-
The aggregator used to get the answer.
|
|
172
|
-
"""
|
|
173
|
-
|
|
174
|
-
answer: str
|
|
175
|
-
coordinates: List[List[int]]
|
|
176
|
-
cells: List[List[int]]
|
|
177
|
-
aggregator: str
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
class TokenClassificationOutput(TypedDict):
|
|
181
|
-
"""Dictionary containing the output of a [`~InferenceClient.token_classification`] task.
|
|
182
|
-
|
|
183
|
-
Args:
|
|
184
|
-
entity_group (`str`):
|
|
185
|
-
The type for the entity being recognized (model specific).
|
|
186
|
-
score (`float`):
|
|
187
|
-
The score of the label predicted by the model.
|
|
188
|
-
word (`str`):
|
|
189
|
-
The string that was captured.
|
|
190
|
-
start (`int`):
|
|
191
|
-
The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
|
|
192
|
-
end (`int`):
|
|
193
|
-
The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
|
|
194
|
-
"""
|
|
195
|
-
|
|
196
|
-
entity_group: str
|
|
197
|
-
score: float
|
|
198
|
-
word: str
|
|
199
|
-
start: int
|
|
200
|
-
end: int
|
huggingface_hub/keras_mixin.py
CHANGED
|
@@ -2,6 +2,7 @@ import collections.abc as collections
|
|
|
2
2
|
import json
|
|
3
3
|
import os
|
|
4
4
|
import warnings
|
|
5
|
+
from functools import wraps
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from shutil import copytree
|
|
7
8
|
from typing import Any, Dict, List, Optional, Union
|
|
@@ -18,12 +19,37 @@ from huggingface_hub.utils import (
|
|
|
18
19
|
from .constants import CONFIG_NAME
|
|
19
20
|
from .hf_api import HfApi
|
|
20
21
|
from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
|
|
22
|
+
from .utils._typing import CallableT
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
logger = logging.get_logger(__name__)
|
|
24
26
|
|
|
27
|
+
keras = None
|
|
25
28
|
if is_tf_available():
|
|
26
|
-
|
|
29
|
+
# Depending on which version of TensorFlow is installed, we need to import
|
|
30
|
+
# keras from the correct location.
|
|
31
|
+
# See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1.
|
|
32
|
+
# Note: saving a keras model only works with Keras<3.0.
|
|
33
|
+
try:
|
|
34
|
+
import tf_keras as keras # type: ignore
|
|
35
|
+
except ImportError:
|
|
36
|
+
import tensorflow as tf # type: ignore
|
|
37
|
+
|
|
38
|
+
keras = tf.keras
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _requires_keras_2_model(fn: CallableT) -> CallableT:
    """Decorate `fn` so that it refuses models that do not look like Keras 2.x models.

    The wrapped function must take the model as its first positional argument. The check
    happens on every call, before `fn` runs.

    Raises (from the wrapper):
        `NotImplementedError`: If the model has no `history` attribute.
    """

    @wraps(fn)
    def _guarded(model, *args, **kwargs):
        # hacky way to check if model is Keras 2.x: only such models expose `history`
        if hasattr(model, "history"):
            return fn(model, *args, **kwargs)
        raise NotImplementedError(
            f"Cannot use '{fn.__name__}': Keras 3.x is not supported."
            " Please save models manually and upload them using `upload_folder` or `huggingface-cli upload`."
        )

    return _guarded  # type: ignore [return-value]
|
|
27
53
|
|
|
28
54
|
|
|
29
55
|
def _flatten_dict(dictionary, parent_key=""):
|
|
@@ -57,21 +83,20 @@ def _flatten_dict(dictionary, parent_key=""):
|
|
|
57
83
|
|
|
58
84
|
def _create_hyperparameter_table(model):
|
|
59
85
|
"""Parse hyperparameter dictionary into a markdown table."""
|
|
86
|
+
table = None
|
|
60
87
|
if model.optimizer is not None:
|
|
61
88
|
optimizer_params = model.optimizer.get_config()
|
|
62
89
|
# flatten the configuration
|
|
63
90
|
optimizer_params = _flatten_dict(optimizer_params)
|
|
64
|
-
optimizer_params["training_precision"] =
|
|
91
|
+
optimizer_params["training_precision"] = keras.mixed_precision.global_policy().name
|
|
65
92
|
table = "| Hyperparameters | Value |\n| :-- | :-- |\n"
|
|
66
93
|
for key, value in optimizer_params.items():
|
|
67
94
|
table += f"| {key} | {value} |\n"
|
|
68
|
-
else:
|
|
69
|
-
table = None
|
|
70
95
|
return table
|
|
71
96
|
|
|
72
97
|
|
|
73
98
|
def _plot_network(model, save_directory):
|
|
74
|
-
|
|
99
|
+
keras.utils.plot_model(
|
|
75
100
|
model,
|
|
76
101
|
to_file=f"{save_directory}/model.png",
|
|
77
102
|
show_shapes=False,
|
|
@@ -128,6 +153,7 @@ def _create_model_card(
|
|
|
128
153
|
readme_path.write_text(model_card)
|
|
129
154
|
|
|
130
155
|
|
|
156
|
+
@_requires_keras_2_model
|
|
131
157
|
def save_pretrained_keras(
|
|
132
158
|
model,
|
|
133
159
|
save_directory: Union[str, Path],
|
|
@@ -162,9 +188,7 @@ def save_pretrained_keras(
|
|
|
162
188
|
model_save_kwargs will be passed to
|
|
163
189
|
[`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
|
|
164
190
|
"""
|
|
165
|
-
if
|
|
166
|
-
import tensorflow as tf
|
|
167
|
-
else:
|
|
191
|
+
if keras is None:
|
|
168
192
|
raise ImportError("Called a Tensorflow-specific function but could not import it.")
|
|
169
193
|
|
|
170
194
|
if not model.built:
|
|
@@ -210,7 +234,7 @@ def save_pretrained_keras(
|
|
|
210
234
|
json.dump(model.history.history, f, indent=2, sort_keys=True)
|
|
211
235
|
|
|
212
236
|
_create_model_card(model, save_directory, plot_model, metadata)
|
|
213
|
-
|
|
237
|
+
keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
|
|
214
238
|
|
|
215
239
|
|
|
216
240
|
def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
|
|
@@ -273,6 +297,7 @@ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
|
|
|
273
297
|
|
|
274
298
|
|
|
275
299
|
@validate_hf_hub_args
|
|
300
|
+
@_requires_keras_2_model
|
|
276
301
|
def push_to_hub_keras(
|
|
277
302
|
model,
|
|
278
303
|
repo_id: str,
|
|
@@ -444,6 +469,7 @@ class KerasModelHubMixin(ModelHubMixin):
|
|
|
444
469
|
resume_download,
|
|
445
470
|
local_files_only,
|
|
446
471
|
token,
|
|
472
|
+
config: Optional[Dict[str, Any]] = None,
|
|
447
473
|
**model_kwargs,
|
|
448
474
|
):
|
|
449
475
|
"""Here we just call [`from_pretrained_keras`] function so both the mixin and
|
|
@@ -452,14 +478,9 @@ class KerasModelHubMixin(ModelHubMixin):
|
|
|
452
478
|
TODO - Some args above aren't used since we are calling
|
|
453
479
|
snapshot_download instead of hf_hub_download.
|
|
454
480
|
"""
|
|
455
|
-
if
|
|
456
|
-
import tensorflow as tf
|
|
457
|
-
else:
|
|
481
|
+
if keras is None:
|
|
458
482
|
raise ImportError("Called a TensorFlow-specific function but could not import it.")
|
|
459
483
|
|
|
460
|
-
# TODO - Figure out what to do about these config values. Config is not going to be needed to load model
|
|
461
|
-
cfg = model_kwargs.pop("config", None)
|
|
462
|
-
|
|
463
484
|
# Root is either a local filepath matching model_id or a cached snapshot
|
|
464
485
|
if not os.path.isdir(model_id):
|
|
465
486
|
storage_folder = snapshot_download(
|
|
@@ -472,9 +493,10 @@ class KerasModelHubMixin(ModelHubMixin):
|
|
|
472
493
|
else:
|
|
473
494
|
storage_folder = model_id
|
|
474
495
|
|
|
475
|
-
|
|
496
|
+
# TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here...
|
|
497
|
+
model = keras.models.load_model(storage_folder)
|
|
476
498
|
|
|
477
499
|
# For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
|
|
478
|
-
model.config =
|
|
500
|
+
model.config = config
|
|
479
501
|
|
|
480
502
|
return model
|
huggingface_hub/lfs.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
"""Git LFS related type definitions and utilities"""
|
|
16
|
+
|
|
16
17
|
import inspect
|
|
17
18
|
import io
|
|
18
19
|
import os
|
|
@@ -27,10 +28,10 @@ from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tupl
|
|
|
27
28
|
from urllib.parse import unquote
|
|
28
29
|
|
|
29
30
|
from huggingface_hub.constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER, REPO_TYPES_URL_PREFIXES
|
|
30
|
-
from huggingface_hub.utils import get_session
|
|
31
31
|
|
|
32
32
|
from .utils import (
|
|
33
33
|
build_hf_headers,
|
|
34
|
+
get_session,
|
|
34
35
|
hf_raise_for_status,
|
|
35
36
|
http_backoff,
|
|
36
37
|
logging,
|
|
@@ -105,6 +106,7 @@ def post_lfs_batch_info(
|
|
|
105
106
|
repo_id: str,
|
|
106
107
|
revision: Optional[str] = None,
|
|
107
108
|
endpoint: Optional[str] = None,
|
|
109
|
+
headers: Optional[Dict[str, str]] = None,
|
|
108
110
|
) -> Tuple[List[dict], List[dict]]:
|
|
109
111
|
"""
|
|
110
112
|
Requests the LFS batch endpoint to retrieve upload instructions
|
|
@@ -120,10 +122,10 @@ def post_lfs_batch_info(
|
|
|
120
122
|
repo_id (`str`):
|
|
121
123
|
A namespace (user or an organization) and a repo name separated
|
|
122
124
|
by a `/`.
|
|
123
|
-
token (`str`, *optional*):
|
|
124
|
-
An authentication token ( See https://huggingface.co/settings/tokens )
|
|
125
125
|
revision (`str`, *optional*):
|
|
126
126
|
The git revision to upload to.
|
|
127
|
+
headers (`dict`, *optional*):
|
|
128
|
+
Additional headers to include in the request
|
|
127
129
|
|
|
128
130
|
Returns:
|
|
129
131
|
`LfsBatchInfo`: 2-tuple:
|
|
@@ -154,7 +156,12 @@ def post_lfs_batch_info(
|
|
|
154
156
|
}
|
|
155
157
|
if revision is not None:
|
|
156
158
|
payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted'
|
|
157
|
-
|
|
159
|
+
|
|
160
|
+
headers = {
|
|
161
|
+
**LFS_HEADERS,
|
|
162
|
+
**build_hf_headers(token=token),
|
|
163
|
+
**(headers or {}),
|
|
164
|
+
}
|
|
158
165
|
resp = get_session().post(batch_url, headers=headers, json=payload)
|
|
159
166
|
hf_raise_for_status(resp)
|
|
160
167
|
batch_info = resp.json()
|
|
@@ -181,7 +188,12 @@ class CompletionPayloadT(TypedDict):
|
|
|
181
188
|
parts: List[PayloadPartT]
|
|
182
189
|
|
|
183
190
|
|
|
184
|
-
def lfs_upload(
|
|
191
|
+
def lfs_upload(
|
|
192
|
+
operation: "CommitOperationAdd",
|
|
193
|
+
lfs_batch_action: Dict,
|
|
194
|
+
token: Optional[str] = None,
|
|
195
|
+
headers: Optional[Dict[str, str]] = None,
|
|
196
|
+
) -> None:
|
|
185
197
|
"""
|
|
186
198
|
Handles uploading a given object to the Hub with the LFS protocol.
|
|
187
199
|
|
|
@@ -193,8 +205,8 @@ def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: O
|
|
|
193
205
|
lfs_batch_action (`dict`):
|
|
194
206
|
Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for
|
|
195
207
|
more details.
|
|
196
|
-
|
|
197
|
-
|
|
208
|
+
headers (`dict`, *optional*):
|
|
209
|
+
Headers to include in the request, including authentication and user agent headers.
|
|
198
210
|
|
|
199
211
|
Raises:
|
|
200
212
|
- `ValueError` if `lfs_batch_action` is improperly formatted
|
|
@@ -234,7 +246,7 @@ def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: O
|
|
|
234
246
|
_validate_lfs_action(verify_action)
|
|
235
247
|
verify_resp = get_session().post(
|
|
236
248
|
verify_action["href"],
|
|
237
|
-
headers=build_hf_headers(token=token
|
|
249
|
+
headers=build_hf_headers(token=token, headers=headers),
|
|
238
250
|
json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size},
|
|
239
251
|
)
|
|
240
252
|
hf_raise_for_status(verify_resp)
|
huggingface_hub/repocard.py
CHANGED
|
@@ -294,6 +294,7 @@ class RepoCard:
|
|
|
294
294
|
cls,
|
|
295
295
|
card_data: CardData,
|
|
296
296
|
template_path: Optional[str] = None,
|
|
297
|
+
template_str: Optional[str] = None,
|
|
297
298
|
**template_kwargs,
|
|
298
299
|
):
|
|
299
300
|
"""Initialize a RepoCard from a template. By default, it uses the default template.
|
|
@@ -322,7 +323,12 @@ class RepoCard:
|
|
|
322
323
|
|
|
323
324
|
kwargs = card_data.to_dict().copy()
|
|
324
325
|
kwargs.update(template_kwargs) # Template_kwargs have priority
|
|
325
|
-
|
|
326
|
+
|
|
327
|
+
if template_path is not None:
|
|
328
|
+
template_str = Path(template_path).read_text()
|
|
329
|
+
if template_str is None:
|
|
330
|
+
template_str = Path(cls.default_template_path).read_text()
|
|
331
|
+
template = jinja2.Template(template_str)
|
|
326
332
|
content = template.render(card_data=card_data.to_yaml(), **kwargs)
|
|
327
333
|
return cls(content)
|
|
328
334
|
|
|
@@ -337,6 +343,7 @@ class ModelCard(RepoCard):
|
|
|
337
343
|
cls,
|
|
338
344
|
card_data: ModelCardData,
|
|
339
345
|
template_path: Optional[str] = None,
|
|
346
|
+
template_str: Optional[str] = None,
|
|
340
347
|
**template_kwargs,
|
|
341
348
|
):
|
|
342
349
|
"""Initialize a ModelCard from a template. By default, it uses the default template, which can be found here:
|
|
@@ -404,7 +411,7 @@ class ModelCard(RepoCard):
|
|
|
404
411
|
|
|
405
412
|
```
|
|
406
413
|
"""
|
|
407
|
-
return super().from_template(card_data, template_path, **template_kwargs)
|
|
414
|
+
return super().from_template(card_data, template_path, template_str, **template_kwargs)
|
|
408
415
|
|
|
409
416
|
|
|
410
417
|
class DatasetCard(RepoCard):
|
|
@@ -417,6 +424,7 @@ class DatasetCard(RepoCard):
|
|
|
417
424
|
cls,
|
|
418
425
|
card_data: DatasetCardData,
|
|
419
426
|
template_path: Optional[str] = None,
|
|
427
|
+
template_str: Optional[str] = None,
|
|
420
428
|
**template_kwargs,
|
|
421
429
|
):
|
|
422
430
|
"""Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here:
|
|
@@ -468,7 +476,7 @@ class DatasetCard(RepoCard):
|
|
|
468
476
|
|
|
469
477
|
```
|
|
470
478
|
"""
|
|
471
|
-
return super().from_template(card_data, template_path, **template_kwargs)
|
|
479
|
+
return super().from_template(card_data, template_path, template_str, **template_kwargs)
|
|
472
480
|
|
|
473
481
|
|
|
474
482
|
class SpaceCard(RepoCard):
|
huggingface_hub/repocard_data.py
CHANGED
|
@@ -312,7 +312,7 @@ class ModelCardData(CardData):
|
|
|
312
312
|
self.language = language
|
|
313
313
|
self.license = license
|
|
314
314
|
self.library_name = library_name
|
|
315
|
-
self.tags = tags
|
|
315
|
+
self.tags = _to_unique_list(tags)
|
|
316
316
|
self.base_model = base_model
|
|
317
317
|
self.datasets = datasets
|
|
318
318
|
self.metrics = metrics
|
|
@@ -507,7 +507,7 @@ class SpaceCardData(CardData):
|
|
|
507
507
|
self.duplicated_from = duplicated_from
|
|
508
508
|
self.models = models
|
|
509
509
|
self.datasets = datasets
|
|
510
|
-
self.tags = tags
|
|
510
|
+
self.tags = _to_unique_list(tags)
|
|
511
511
|
super().__init__(**kwargs)
|
|
512
512
|
|
|
513
513
|
|
|
@@ -717,3 +717,13 @@ def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult])
|
|
|
717
717
|
}
|
|
718
718
|
]
|
|
719
719
|
return _remove_none(model_index)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _to_unique_list(tags: Optional[List[str]]) -> Optional[List[str]]:
    """Return `tags` without duplicates, keeping the order of first appearance.

    Args:
        tags (`List[str]`, *optional*):
            The tags to deduplicate. A single `str` is accepted as well and is treated as one tag.

    Returns:
        `List[str]` or `None`: A new list holding each tag once, or `None` if `tags` is `None`.
    """
    if tags is None:
        return tags
    if isinstance(tags, str):
        # A bare string is itself iterable: without this guard "nlp" would become ["n", "l", "p"].
        return [tags]
    unique_tags = []  # make tags unique + keep order explicitly
    for tag in tags:
        # List membership (rather than a set) keeps working if an entry is not hashable.
        if tag not in unique_tags:
            unique_tags.append(tag)
    return unique_tags
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ruff: noqa: F401
|
|
15
15
|
"""Contains helpers to serialize tensors."""
|
|
16
|
+
|
|
16
17
|
from ._base import StateDictSplit, split_state_dict_into_shards_factory
|
|
17
18
|
from ._numpy import split_numpy_state_dict_into_shards
|
|
18
19
|
from ._tensorflow import split_tf_state_dict_into_shards
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Contains helpers to split tensors into shards."""
|
|
15
|
+
|
|
15
16
|
from dataclasses import dataclass, field
|
|
16
17
|
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
|
17
18
|
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Contains numpy-specific helpers."""
|
|
15
|
+
|
|
15
16
|
from typing import TYPE_CHECKING, Dict
|
|
16
17
|
|
|
17
18
|
from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
|