huggingface-hub 0.23.3__py3-none-any.whl → 0.24.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +47 -15
- huggingface_hub/_commit_api.py +38 -8
- huggingface_hub/_inference_endpoints.py +11 -4
- huggingface_hub/_local_folder.py +22 -13
- huggingface_hub/_snapshot_download.py +12 -7
- huggingface_hub/_webhooks_server.py +3 -1
- huggingface_hub/commands/huggingface_cli.py +4 -3
- huggingface_hub/commands/repo_files.py +128 -0
- huggingface_hub/constants.py +12 -0
- huggingface_hub/file_download.py +127 -91
- huggingface_hub/hf_api.py +979 -341
- huggingface_hub/hf_file_system.py +30 -3
- huggingface_hub/hub_mixin.py +103 -41
- huggingface_hub/inference/_client.py +373 -42
- huggingface_hub/inference/_common.py +0 -2
- huggingface_hub/inference/_generated/_async_client.py +390 -48
- huggingface_hub/inference/_generated/types/__init__.py +4 -1
- huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
- huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
- huggingface_hub/inference/_generated/types/text_generation.py +29 -0
- huggingface_hub/lfs.py +11 -6
- huggingface_hub/repocard_data.py +41 -29
- huggingface_hub/repository.py +6 -6
- huggingface_hub/serialization/__init__.py +8 -3
- huggingface_hub/serialization/_base.py +13 -16
- huggingface_hub/serialization/_tensorflow.py +4 -3
- huggingface_hub/serialization/_torch.py +399 -22
- huggingface_hub/utils/__init__.py +1 -2
- huggingface_hub/utils/_errors.py +1 -1
- huggingface_hub/utils/_fixes.py +14 -3
- huggingface_hub/utils/_paths.py +17 -6
- huggingface_hub/utils/_subprocess.py +0 -1
- huggingface_hub/utils/_telemetry.py +9 -1
- huggingface_hub/utils/_typing.py +26 -1
- huggingface_hub/utils/endpoint_helpers.py +2 -186
- huggingface_hub/utils/sha.py +36 -1
- huggingface_hub/utils/tqdm.py +0 -1
- {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/METADATA +12 -9
- {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/RECORD +43 -43
- huggingface_hub/serialization/_numpy.py +0 -68
- {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/types/__init__.py
CHANGED

@@ -20,10 +20,13 @@ from .base import BaseInferenceType
 from .chat_completion import (
     ChatCompletionInput,
     ChatCompletionInputFunctionDefinition,
+    ChatCompletionInputFunctionName,
+    ChatCompletionInputGrammarType,
     ChatCompletionInputMessage,
+    ChatCompletionInputMessageChunk,
     ChatCompletionInputTool,
-    ChatCompletionInputToolCall,
     ChatCompletionInputToolTypeClass,
+    ChatCompletionInputURL,
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
     ChatCompletionOutputFunctionDefinition,
huggingface_hub/inference/_generated/types/chat_completion.py
CHANGED

@@ -10,33 +10,55 @@ from .base import BaseInferenceType


 @dataclass
-class ChatCompletionInputFunctionDefinition(BaseInferenceType):
-    arguments: Any
-    name: str
-    description: Optional[str] = None
+class ChatCompletionInputURL(BaseInferenceType):
+    url: str
+
+
+ChatCompletionInputMessageChunkType = Literal["text", "image_url"]


 @dataclass
-class ChatCompletionInputToolCall(BaseInferenceType):
-    function: ChatCompletionInputFunctionDefinition
-    id: str
-    type: str
+class ChatCompletionInputMessageChunk(BaseInferenceType):
+    type: "ChatCompletionInputMessageChunkType"
+    image_url: Optional[ChatCompletionInputURL] = None
+    text: Optional[str] = None


 @dataclass
 class ChatCompletionInputMessage(BaseInferenceType):
+    content: Union[List[ChatCompletionInputMessageChunk], str]
     role: str
-    content: Optional[str] = None
     name: Optional[str] = None
-
+
+
+ChatCompletionInputGrammarTypeType = Literal["json", "regex"]
+
+
+@dataclass
+class ChatCompletionInputGrammarType(BaseInferenceType):
+    type: "ChatCompletionInputGrammarTypeType"
+    value: Any
+    """A string that represents a [JSON Schema](https://json-schema.org/).
+    JSON Schema is a declarative language that allows to annotate JSON documents
+    with types and descriptions.
+    """
+
+
+@dataclass
+class ChatCompletionInputFunctionName(BaseInferenceType):
+    name: str


 @dataclass
 class ChatCompletionInputToolTypeClass(BaseInferenceType):
-
+    function: Optional[ChatCompletionInputFunctionName] = None


-
+@dataclass
+class ChatCompletionInputFunctionDefinition(BaseInferenceType):
+    arguments: Any
+    name: str
+    description: Optional[str] = None


 @dataclass

@@ -55,10 +77,6 @@ class ChatCompletionInput(BaseInferenceType):

     messages: List[ChatCompletionInputMessage]
     """A list of messages comprising the conversation so far."""
-    model: str
-    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
-    on which models work with the Chat API.
-    """
     frequency_penalty: Optional[float] = None
     """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
     frequency in the text so far,

@@ -83,6 +101,10 @@ class ChatCompletionInput(BaseInferenceType):
     """
     max_tokens: Optional[int] = None
     """The maximum number of tokens that can be generated in the chat completion."""
+    model: Optional[str] = None
+    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
+    on which models work with the Chat API.
+    """
     n: Optional[int] = None
     """UNUSED
     How many chat completion choices to generate for each input message. Note that you will

@@ -94,6 +116,7 @@ class ChatCompletionInput(BaseInferenceType):
     appear in the text so far,
     increasing the model's likelihood to talk about new topics
     """
+    response_format: Optional[ChatCompletionInputGrammarType] = None
     seed: Optional[int] = None
     stop: Optional[List[str]] = None
     """Up to 4 sequences where the API will stop generating further tokens."""

@@ -104,7 +127,7 @@ class ChatCompletionInput(BaseInferenceType):
     lower values like 0.2 will make it more focused and deterministic.
     We generally recommend altering this or `top_p` but not both.
     """
-    tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+    tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None
     tool_prompt: Optional[str] = None
     """A prompt to be appended before the tools"""
     tools: Optional[List[ChatCompletionInputTool]] = None

@@ -153,7 +176,7 @@ class ChatCompletionOutputFunctionDefinition(BaseInferenceType):
 @dataclass
 class ChatCompletionOutputToolCall(BaseInferenceType):
     function: ChatCompletionOutputFunctionDefinition
-    id:
+    id: str
     type: str


@@ -161,7 +184,6 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
-    name: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None


@@ -192,7 +214,6 @@ class ChatCompletionOutput(BaseInferenceType):
     created: int
     id: str
     model: str
-    object: str
     system_fingerprint: str
     usage: ChatCompletionOutputUsage

@@ -256,5 +277,4 @@ class ChatCompletionStreamOutput(BaseInferenceType):
     created: int
     id: str
     model: str
-    object: str
     system_fingerprint: str
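To illustrate how the new input types compose, here is a minimal sketch built only from the dataclasses shown in the hunks above (the image URL and JSON schema are placeholder values):

from huggingface_hub.inference._generated.types.chat_completion import (
    ChatCompletionInputGrammarType,
    ChatCompletionInputMessage,
    ChatCompletionInputMessageChunk,
    ChatCompletionInputURL,
)

# A user message can now carry a list of text/image chunks instead of a plain string.
message = ChatCompletionInputMessage(
    role="user",
    content=[
        ChatCompletionInputMessageChunk(type="text", text="What is shown in this image?"),
        ChatCompletionInputMessageChunk(
            type="image_url",
            image_url=ChatCompletionInputURL(url="https://example.com/image.png"),  # placeholder URL
        ),
    ],
)

# Grammar-constrained output: `value` holds a JSON Schema when type="json".
response_format = ChatCompletionInputGrammarType(
    type="json",
    value={"type": "object", "properties": {"answer": {"type": "string"}}},
)

Note that, per the later hunks, `model` moves from a required field to an optional one on `ChatCompletionInput`, and the unused `object` field is dropped from the output types.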
huggingface_hub/inference/_generated/types/feature_extraction.py
CHANGED

@@ -4,16 +4,34 @@
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
 from dataclasses import dataclass
-from typing import
+from typing import Literal, Optional

 from .base import BaseInferenceType


+FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]
+
+
 @dataclass
 class FeatureExtractionInput(BaseInferenceType):
-    """
+    """Feature Extraction Input.
+    Auto-generated from TEI specs.
+    For more details, check out
+    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
+    """

     inputs: str
-    """The text to
-
-
+    """The text to embed."""
+    normalize: Optional[bool] = None
+    prompt_name: Optional[str] = None
+    """The name of the prompt that should be used by for encoding. If not set, no prompt
+    will be applied.
+    Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+    For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",
+    ...},
+    then the sentence "What is the capital of France?" will be encoded as
+    "query: What is the capital of France?" because the prompt text will be prepended before
+    any text to encode.
+    """
+    truncate: Optional[bool] = None
+    truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None
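As a rough illustration of the new TEI-aligned fields, a payload can be built from the dataclass above (the values are examples only):

from huggingface_hub.inference._generated.types.feature_extraction import FeatureExtractionInput

payload = FeatureExtractionInput(
    inputs="What is the capital of France?",
    normalize=True,
    prompt_name="query",           # must match a key in the Sentence Transformers `prompts` dictionary
    truncate=True,
    truncation_direction="Right",  # Literal["Left", "Right"]
)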
huggingface_hub/inference/_generated/types/text_generation.py
CHANGED

@@ -24,24 +24,53 @@ class TextGenerationInputGrammarType(BaseInferenceType):

 @dataclass
 class TextGenerationInputGenerateParameters(BaseInferenceType):
+    adapter_id: Optional[str] = None
+    """Lora adapter id"""
     best_of: Optional[int] = None
+    """Generate best_of sequences and return the one if the highest token logprobs."""
     decoder_input_details: Optional[bool] = None
+    """Whether to return decoder input token logprobs and ids."""
     details: Optional[bool] = None
+    """Whether to return generation details."""
     do_sample: Optional[bool] = None
+    """Activate logits sampling."""
     frequency_penalty: Optional[float] = None
+    """The parameter for frequency penalty. 1.0 means no penalty
+    Penalize new tokens based on their existing frequency in the text so far,
+    decreasing the model's likelihood to repeat the same line verbatim.
+    """
     grammar: Optional[TextGenerationInputGrammarType] = None
     max_new_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
     repetition_penalty: Optional[float] = None
+    """The parameter for repetition penalty. 1.0 means no penalty.
+    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    """
     return_full_text: Optional[bool] = None
+    """Whether to prepend the prompt to the generated text"""
     seed: Optional[int] = None
+    """Random sampling seed."""
     stop: Optional[List[str]] = None
+    """Stop generating tokens if a member of `stop` is generated."""
     temperature: Optional[float] = None
+    """The value used to module the logits distribution."""
     top_k: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
     top_n_tokens: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for top-n-filtering."""
     top_p: Optional[float] = None
+    """Top-p value for nucleus sampling."""
     truncate: Optional[int] = None
+    """Truncate inputs tokens to the given size."""
     typical_p: Optional[float] = None
+    """Typical Decoding mass
+    See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
+    for more information.
+    """
     watermark: Optional[bool] = None
+    """Watermarking with [A Watermark for Large Language
+    Models](https://arxiv.org/abs/2301.10226).
+    """


 @dataclass
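The hunk above mostly adds docstrings, plus the new `adapter_id` field; as a sketch, the parameters object can be constructed like this (the adapter id is a hypothetical example):

from huggingface_hub.inference._generated.types.text_generation import (
    TextGenerationInputGenerateParameters,
)

params = TextGenerationInputGenerateParameters(
    adapter_id="username/my-lora-adapter",  # hypothetical LoRA adapter id
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.95,
    stop=["\n\n"],
    seed=42,
)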
huggingface_hub/lfs.py
CHANGED
@@ -134,9 +134,10 @@ def post_lfs_batch_info(
         - Second element is an list of errors, if any

     Raises:
-        `ValueError`
-
-        `HTTPError`
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If an argument is invalid or the server response is malformed.
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the server returned an error.
     """
     endpoint = endpoint if endpoint is not None else ENDPOINT
     url_prefix = ""

@@ -211,8 +212,10 @@ def lfs_upload(
             Headers to include in the request, including authentication and user agent headers.

     Raises:
-
-
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If `lfs_batch_action` is improperly formatted
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the upload resulted in an error
     """
     # 0. If LFS file is already present, skip upload
     _validate_batch_actions(lfs_batch_action)

@@ -307,7 +310,9 @@ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None

     Returns: `requests.Response`

-    Raises:
+    Raises:
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the upload resulted in an error.
     """
     with operation.as_file(with_tqdm=True) as fileobj:
         # S3 might raise a transient 500 error -> let's retry if that happens
huggingface_hub/repocard_data.py
CHANGED
@@ -55,7 +55,7 @@ class EvalResult:
         source_name (`str`, *optional*):
             The name of the source of the evaluation result. Example: "Open LLM Leaderboard".
         source_url (`str`, *optional*):
-            The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/
+            The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard".
     """

@@ -128,7 +128,7 @@ class EvalResult:
     source_name: Optional[str] = None

     # The URL of the source of the evaluation result.
-    # Example: https://huggingface.co/spaces/
+    # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard
     source_url: Optional[str] = None

     @property

@@ -242,19 +242,6 @@ class ModelCardData(CardData):
     """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md

     Args:
-        language (`Union[str, List[str]]`, *optional*):
-            Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
-            639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
-        license (`str`, *optional*):
-            License of this model. Example: apache-2.0 or any license from
-            https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
-        library_name (`str`, *optional*):
-            Name of library used by this model. Example: keras or any library from
-            https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
-            Defaults to None.
-        tags (`List[str]`, *optional*):
-            List of tags to add to your model that can be used when filtering on the Hugging
-            Face Hub. Defaults to None.
         base_model (`str` or `List[str]`, *optional*):
             The identifier of the base model from which the model derives. This is applicable for example if your model is a
             fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs

@@ -262,17 +249,36 @@ class ModelCardData(CardData):
         datasets (`List[str]`, *optional*):
             List of datasets that were used to train this model. Should be a dataset ID
             found on https://hf.co/datasets. Defaults to None.
-        metrics (`List[str]`, *optional*):
-            List of metrics used to evaluate this model. Should be a metric name that can be found
-            at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
         eval_results (`Union[List[EvalResult], EvalResult]`, *optional*):
             List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided,
             `model_name` is used to as a name on PapersWithCode's leaderboards. Defaults to `None`.
+        language (`Union[str, List[str]]`, *optional*):
+            Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
+            639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
+        library_name (`str`, *optional*):
+            Name of library used by this model. Example: keras or any library from
+            https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts.
+            Defaults to None.
+        license (`str`, *optional*):
+            License of this model. Example: apache-2.0 or any license from
+            https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
+        license_name (`str`, *optional*):
+            Name of the license of this model. Defaults to None. To be used in conjunction with `license_link`.
+            Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a name. In that case, use `license` instead.
+        license_link (`str`, *optional*):
+            Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`.
+            Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead.
+        metrics (`List[str]`, *optional*):
+            List of metrics used to evaluate this model. Should be a metric name that can be found
+            at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
         model_name (`str`, *optional*):
             A name for this model. It is used along with
             `eval_results` to construct the `model-index` within the card's metadata. The name
             you supply here is what will be used on PapersWithCode's leaderboards. If None is provided
             then the repo name is used as a default. Defaults to None.
+        tags (`List[str]`, *optional*):
+            List of tags to add to your model that can be used when filtering on the Hugging
+            Face Hub. Defaults to None.
         ignore_metadata_errors (`str`):
             If True, errors while parsing the metadata section will be ignored. Some information might be lost during
             the process. Use it at your own risk.

@@ -297,27 +303,33 @@ class ModelCardData(CardData):
     def __init__(
         self,
         *,
-        language: Optional[Union[str, List[str]]] = None,
-        license: Optional[str] = None,
-        library_name: Optional[str] = None,
-        tags: Optional[List[str]] = None,
         base_model: Optional[Union[str, List[str]]] = None,
         datasets: Optional[List[str]] = None,
-        metrics: Optional[List[str]] = None,
         eval_results: Optional[List[EvalResult]] = None,
+        language: Optional[Union[str, List[str]]] = None,
+        library_name: Optional[str] = None,
+        license: Optional[str] = None,
+        license_name: Optional[str] = None,
+        license_link: Optional[str] = None,
+        metrics: Optional[List[str]] = None,
         model_name: Optional[str] = None,
+        pipeline_tag: Optional[str] = None,
+        tags: Optional[List[str]] = None,
         ignore_metadata_errors: bool = False,
         **kwargs,
     ):
-        self.language = language
-        self.license = license
-        self.library_name = library_name
-        self.tags = _to_unique_list(tags)
         self.base_model = base_model
         self.datasets = datasets
-        self.metrics = metrics
         self.eval_results = eval_results
+        self.language = language
+        self.library_name = library_name
+        self.license = license
+        self.license_name = license_name
+        self.license_link = license_link
+        self.metrics = metrics
         self.model_name = model_name
+        self.pipeline_tag = pipeline_tag
+        self.tags = _to_unique_list(tags)

         model_index = kwargs.pop("model-index", None)
         if model_index:

@@ -338,7 +350,7 @@ class ModelCardData(CardData):
         super().__init__(**kwargs)

         if self.eval_results:
-            if
+            if isinstance(self.eval_results, EvalResult):
                 self.eval_results = [self.eval_results]
             if self.model_name is None:
                 raise ValueError("Passing `eval_results` requires `model_name` to be set.")
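A short sketch of the new `ModelCardData` keyword arguments introduced above; the custom license name and link are hypothetical placeholders:

from huggingface_hub import ModelCardData

card_data = ModelCardData(
    language="en",
    library_name="transformers",
    license="other",
    license_name="my-custom-license",                # hypothetical, used together with license_link
    license_link="https://example.com/LICENSE.txt",  # hypothetical
    pipeline_tag="text-generation",
    tags=["example"],
)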
huggingface_hub/repository.py
CHANGED
@@ -507,8 +507,8 @@ class Repository:
                 instance will be created if this is left to `None`.

         Raises:
-
-
+            [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+                If the remote repository set in `clone_from` does not exist.
         """
         if isinstance(local_dir, Path):
             local_dir = str(local_dir)

@@ -542,10 +542,10 @@ class Repository:
             user = self.client.whoami(self.huggingface_token)

             if git_email is None:
-                git_email = user
+                git_email = user.get("email")

             if git_user is None:
-                git_user = user
+                git_user = user.get("fullname")

         if git_user is not None or git_email is not None:
             self.git_config_username_and_email(git_user, git_email)

@@ -580,8 +580,8 @@ class Repository:
        Checks that `git` and `git-lfs` can be run.

        Raises:
-
-
+            [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+                If `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess("git --version", self.local_dir).stdout.strip()
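The git identity fix above relies on `whoami()` returning a dict; a quick way to inspect the two fields it now reads (assuming you are logged in with a token):

from huggingface_hub import HfApi

user = HfApi().whoami()
print(user.get("fullname"), user.get("email"))  # the keys used for the git config above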
huggingface_hub/serialization/__init__.py
CHANGED

@@ -15,6 +15,11 @@
 """Contains helpers to serialize tensors."""

 from ._base import StateDictSplit, split_state_dict_into_shards_factory
-from .
-from .
-
+from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
+from ._torch import (
+    get_torch_storage_id,
+    get_torch_storage_size,
+    save_torch_model,
+    save_torch_state_dict,
+    split_torch_state_dict_into_shards,
+)
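A minimal sketch of the torch helpers newly re-exported here, assuming `split_torch_state_dict_into_shards(state_dict, max_shard_size=...)` and `save_torch_state_dict(state_dict, save_directory)` as suggested by the imports above (the target folder is a placeholder; `torch` and `safetensors` must be installed):

import torch
from huggingface_hub.serialization import save_torch_state_dict, split_torch_state_dict_into_shards

state_dict = {"layer.weight": torch.zeros(16, 16), "layer.bias": torch.zeros(16)}

# Plan a split without writing anything; shard size accepts "5GB"-style strings or byte counts.
split = split_torch_state_dict_into_shards(state_dict, max_shard_size="1GB")

# Write (possibly sharded) safetensors files to a local folder in one call (assumed signature).
save_torch_state_dict(state_dict, "path/to/folder")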
huggingface_hub/serialization/_base.py
CHANGED

@@ -23,8 +23,14 @@ TensorT = TypeVar("TensorT")
 TensorSizeFn_T = Callable[[TensorT], int]
 StorageIDFn_T = Callable[[TensorT], Optional[Any]]

-MAX_SHARD_SIZE =
-
+MAX_SHARD_SIZE = "5GB"
+SIZE_UNITS = {
+    "TB": 10**12,
+    "GB": 10**9,
+    "MB": 10**6,
+    "KB": 10**3,
+}
+

 logger = logging.get_logger(__file__)


@@ -43,9 +49,9 @@ class StateDictSplit:
 def split_state_dict_into_shards_factory(
     state_dict: Dict[str, TensorT],
     *,
-
+    get_storage_size: TensorSizeFn_T,
+    filename_pattern: str,
     get_storage_id: StorageIDFn_T = lambda tensor: None,
-    filename_pattern: str = FILENAME_PATTERN,
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """

@@ -66,8 +72,8 @@ def split_state_dict_into_shards_factory(
     Args:
         state_dict (`Dict[str, Tensor]`):
             The state dictionary to save.
-
-            A function that returns the size of a tensor in bytes.
+        get_storage_size (`Callable[[Tensor], int]`):
+            A function that returns the size of a tensor when saved on disk in bytes.
         get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
             A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
             same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage

@@ -75,7 +81,6 @@ def split_state_dict_into_shards_factory(
         filename_pattern (`str`, *optional*):
             The pattern to generate the files names in which the model will be saved. Pattern must be a string that
             can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
-            Defaults to `"model{suffix}.safetensors"`.
         max_shard_size (`int` or `str`, *optional*):
             The maximum size of each shard, in bytes. Defaults to 5GB.


@@ -112,7 +117,7 @@ def split_state_dict_into_shards_factory(
             storage_id_to_tensors[storage_id] = [key]

         # Compute tensor size
-        tensor_size =
+        tensor_size = get_storage_size(tensor)

         # If this tensor is bigger than the maximal size, we put it in its own shard
         if tensor_size > max_shard_size:

@@ -172,14 +177,6 @@ def split_state_dict_into_shards_factory(
     )


-SIZE_UNITS = {
-    "TB": 10**12,
-    "GB": 10**9,
-    "MB": 10**6,
-    "KB": 10**3,
-}
-
-
 def parse_size_to_int(size_as_str: str) -> int:
     """
     Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
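`MAX_SHARD_SIZE` is now the human-readable string "5GB"; `parse_size_to_int` converts such strings using the decimal `SIZE_UNITS` table above, for example:

from huggingface_hub.serialization._base import parse_size_to_int

print(parse_size_to_int("5GB"))    # 5000000000
print(parse_size_to_int("100MB"))  # 100000000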
huggingface_hub/serialization/_tensorflow.py
CHANGED

@@ -17,6 +17,7 @@ import math
 import re
 from typing import TYPE_CHECKING, Dict, Union

+from .. import constants
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory


@@ -27,7 +28,7 @@ if TYPE_CHECKING:
 def split_tf_state_dict_into_shards(
     state_dict: Dict[str, "tf.Tensor"],
     *,
-    filename_pattern: str =
+    filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """

@@ -62,11 +63,11 @@ def split_tf_state_dict_into_shards(
         state_dict,
         max_shard_size=max_shard_size,
         filename_pattern=filename_pattern,
-
+        get_storage_size=get_tf_storage_size,
     )


-def
+def get_tf_storage_size(tensor: "tf.Tensor") -> int:
     # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
     # Better to overestimate than underestimate.
     return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))
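A small sketch of the renamed TensorFlow helper shown above (assumes TensorFlow is installed):

import tensorflow as tf

from huggingface_hub.serialization._tensorflow import get_tf_storage_size

# 1024 * 1024 float32 values at 4 bytes each -> 4_194_304 bytes on disk.
print(get_tf_storage_size(tf.zeros((1024, 1024), dtype=tf.float32)))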