huggingface-hub 0.23.5__py3-none-any.whl → 0.24.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +47 -15
- huggingface_hub/_commit_api.py +38 -8
- huggingface_hub/_inference_endpoints.py +11 -4
- huggingface_hub/_local_folder.py +22 -13
- huggingface_hub/_snapshot_download.py +12 -7
- huggingface_hub/_webhooks_server.py +3 -1
- huggingface_hub/commands/huggingface_cli.py +4 -3
- huggingface_hub/commands/repo_files.py +128 -0
- huggingface_hub/constants.py +12 -0
- huggingface_hub/file_download.py +127 -91
- huggingface_hub/hf_api.py +979 -341
- huggingface_hub/hf_file_system.py +30 -3
- huggingface_hub/inference/_client.py +373 -42
- huggingface_hub/inference/_common.py +0 -2
- huggingface_hub/inference/_generated/_async_client.py +390 -48
- huggingface_hub/inference/_generated/types/__init__.py +4 -1
- huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
- huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
- huggingface_hub/inference/_generated/types/text_generation.py +29 -0
- huggingface_hub/lfs.py +11 -6
- huggingface_hub/repocard_data.py +3 -3
- huggingface_hub/repository.py +6 -6
- huggingface_hub/serialization/__init__.py +8 -3
- huggingface_hub/serialization/_base.py +13 -16
- huggingface_hub/serialization/_tensorflow.py +4 -3
- huggingface_hub/serialization/_torch.py +399 -22
- huggingface_hub/utils/__init__.py +0 -1
- huggingface_hub/utils/_errors.py +1 -1
- huggingface_hub/utils/_fixes.py +14 -3
- huggingface_hub/utils/_paths.py +17 -6
- huggingface_hub/utils/_subprocess.py +0 -1
- huggingface_hub/utils/_telemetry.py +9 -1
- huggingface_hub/utils/endpoint_helpers.py +2 -186
- huggingface_hub/utils/sha.py +36 -1
- huggingface_hub/utils/tqdm.py +0 -1
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/METADATA +12 -9
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/RECORD +41 -41
- huggingface_hub/serialization/_numpy.py +0 -68
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/types/__init__.py
CHANGED

@@ -20,10 +20,13 @@ from .base import BaseInferenceType
 from .chat_completion import (
     ChatCompletionInput,
     ChatCompletionInputFunctionDefinition,
+    ChatCompletionInputFunctionName,
+    ChatCompletionInputGrammarType,
     ChatCompletionInputMessage,
+    ChatCompletionInputMessageChunk,
     ChatCompletionInputTool,
-    ChatCompletionInputToolCall,
     ChatCompletionInputToolTypeClass,
+    ChatCompletionInputURL,
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
     ChatCompletionOutputFunctionDefinition,
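For orientation, the names added above are re-exported by this package and can be imported directly; a minimal sketch using only the import path shown in this diff:

    from huggingface_hub.inference._generated.types import (
        ChatCompletionInputFunctionName,
        ChatCompletionInputGrammarType,
        ChatCompletionInputMessageChunk,
        ChatCompletionInputURL,
    )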
huggingface_hub/inference/_generated/types/chat_completion.py
CHANGED

@@ -10,33 +10,55 @@ from .base import BaseInferenceType


 @dataclass
-class
-
-
-
+class ChatCompletionInputURL(BaseInferenceType):
+    url: str
+
+
+ChatCompletionInputMessageChunkType = Literal["text", "image_url"]


 @dataclass
-class
-
-
-
+class ChatCompletionInputMessageChunk(BaseInferenceType):
+    type: "ChatCompletionInputMessageChunkType"
+    image_url: Optional[ChatCompletionInputURL] = None
+    text: Optional[str] = None


 @dataclass
 class ChatCompletionInputMessage(BaseInferenceType):
+    content: Union[List[ChatCompletionInputMessageChunk], str]
     role: str
-    content: Optional[str] = None
     name: Optional[str] = None
-
+
+
+ChatCompletionInputGrammarTypeType = Literal["json", "regex"]
+
+
+@dataclass
+class ChatCompletionInputGrammarType(BaseInferenceType):
+    type: "ChatCompletionInputGrammarTypeType"
+    value: Any
+    """A string that represents a [JSON Schema](https://json-schema.org/).
+    JSON Schema is a declarative language that allows to annotate JSON documents
+    with types and descriptions.
+    """
+
+
+@dataclass
+class ChatCompletionInputFunctionName(BaseInferenceType):
+    name: str


 @dataclass
 class ChatCompletionInputToolTypeClass(BaseInferenceType):
-
+    function: Optional[ChatCompletionInputFunctionName] = None


-
+@dataclass
+class ChatCompletionInputFunctionDefinition(BaseInferenceType):
+    arguments: Any
+    name: str
+    description: Optional[str] = None


 @dataclass
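The new chunk types make message content multimodal. A minimal construction sketch based only on the dataclasses in the hunk above (the URL and the prompt text are made up):

    from huggingface_hub.inference._generated.types.chat_completion import (
        ChatCompletionInputMessage,
        ChatCompletionInputMessageChunk,
        ChatCompletionInputURL,
    )

    # A user message mixing a text chunk and an image_url chunk, as allowed by the new
    # `content: Union[List[ChatCompletionInputMessageChunk], str]` field.
    message = ChatCompletionInputMessage(
        role="user",
        content=[
            ChatCompletionInputMessageChunk(type="text", text="What is shown in this image?"),
            ChatCompletionInputMessageChunk(
                type="image_url",
                image_url=ChatCompletionInputURL(url="https://example.com/cat.png"),
            ),
        ],
    )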
@@ -55,10 +77,6 @@ class ChatCompletionInput(BaseInferenceType):

     messages: List[ChatCompletionInputMessage]
     """A list of messages comprising the conversation so far."""
-    model: str
-    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
-    on which models work with the Chat API.
-    """
     frequency_penalty: Optional[float] = None
     """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
     frequency in the text so far,
@@ -83,6 +101,10 @@ class ChatCompletionInput(BaseInferenceType):
     """
     max_tokens: Optional[int] = None
     """The maximum number of tokens that can be generated in the chat completion."""
+    model: Optional[str] = None
+    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
+    on which models work with the Chat API.
+    """
     n: Optional[int] = None
     """UNUSED
     How many chat completion choices to generate for each input message. Note that you will
@@ -94,6 +116,7 @@ class ChatCompletionInput(BaseInferenceType):
     appear in the text so far,
     increasing the model's likelihood to talk about new topics
     """
+    response_format: Optional[ChatCompletionInputGrammarType] = None
     seed: Optional[int] = None
     stop: Optional[List[str]] = None
     """Up to 4 sequences where the API will stop generating further tokens."""
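The new `response_format` field takes a `ChatCompletionInputGrammarType`. A hedged sketch of building one (the JSON schema is illustrative only, not taken from this diff):

    from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputGrammarType

    # Constrain generation to JSON matching a schema; `value` holds the schema itself.
    response_format = ChatCompletionInputGrammarType(
        type="json",
        value={
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
    )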
@@ -104,7 +127,7 @@ class ChatCompletionInput(BaseInferenceType):
     lower values like 0.2 will make it more focused and deterministic.
     We generally recommend altering this or `top_p` but not both.
     """
-    tool_choice: Optional[Union[ChatCompletionInputToolTypeClass,
+    tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None
     tool_prompt: Optional[str] = None
     """A prompt to be appended before the tools"""
     tools: Optional[List[ChatCompletionInputTool]] = None
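`tool_choice` now accepts either a plain string or a `ChatCompletionInputToolTypeClass` naming a specific function. A small sketch (the tool name is hypothetical):

    from huggingface_hub.inference._generated.types.chat_completion import (
        ChatCompletionInputFunctionName,
        ChatCompletionInputToolTypeClass,
    )

    tool_choice_auto = "auto"  # plain string form
    tool_choice_named = ChatCompletionInputToolTypeClass(
        function=ChatCompletionInputFunctionName(name="get_current_weather")  # hypothetical tool
    )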
@@ -153,7 +176,7 @@ class ChatCompletionOutputFunctionDefinition(BaseInferenceType):
 @dataclass
 class ChatCompletionOutputToolCall(BaseInferenceType):
     function: ChatCompletionOutputFunctionDefinition
-    id:
+    id: str
     type: str

@@ -161,7 +184,6 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
-    name: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None

@@ -192,7 +214,6 @@
     created: int
     id: str
     model: str
-    object: str
     system_fingerprint: str
     usage: ChatCompletionOutputUsage

@@ -256,5 +277,4 @@ class ChatCompletionStreamOutput(BaseInferenceType):
     created: int
     id: str
     model: str
-    object: str
     system_fingerprint: str
huggingface_hub/inference/_generated/types/feature_extraction.py
CHANGED

@@ -4,16 +4,34 @@
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
 from dataclasses import dataclass
-from typing import
+from typing import Literal, Optional

 from .base import BaseInferenceType


+FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]
+
+
 @dataclass
 class FeatureExtractionInput(BaseInferenceType):
-    """
+    """Feature Extraction Input.
+    Auto-generated from TEI specs.
+    For more details, check out
+    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
+    """

     inputs: str
-    """The text to
-
-
+    """The text to embed."""
+    normalize: Optional[bool] = None
+    prompt_name: Optional[str] = None
+    """The name of the prompt that should be used by for encoding. If not set, no prompt
+    will be applied.
+    Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+    For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",
+    ...},
+    then the sentence "What is the capital of France?" will be encoded as
+    "query: What is the capital of France?" because the prompt text will be prepended before
+    any text to encode.
+    """
+    truncate: Optional[bool] = None
+    truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None
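The expanded `FeatureExtractionInput` can be built directly from the fields above; a sketch (the `prompt_name` value assumes a matching key in the deployed Sentence Transformers `prompts` dictionary):

    from huggingface_hub.inference._generated.types.feature_extraction import FeatureExtractionInput

    payload = FeatureExtractionInput(
        inputs="What is the capital of France?",
        prompt_name="query",          # must exist in the model's `prompts` configuration
        normalize=True,
        truncation_direction="Right",
    )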
huggingface_hub/inference/_generated/types/text_generation.py
CHANGED

@@ -24,24 +24,53 @@ class TextGenerationInputGrammarType(BaseInferenceType):

 @dataclass
 class TextGenerationInputGenerateParameters(BaseInferenceType):
+    adapter_id: Optional[str] = None
+    """Lora adapter id"""
     best_of: Optional[int] = None
+    """Generate best_of sequences and return the one if the highest token logprobs."""
     decoder_input_details: Optional[bool] = None
+    """Whether to return decoder input token logprobs and ids."""
     details: Optional[bool] = None
+    """Whether to return generation details."""
     do_sample: Optional[bool] = None
+    """Activate logits sampling."""
     frequency_penalty: Optional[float] = None
+    """The parameter for frequency penalty. 1.0 means no penalty
+    Penalize new tokens based on their existing frequency in the text so far,
+    decreasing the model's likelihood to repeat the same line verbatim.
+    """
     grammar: Optional[TextGenerationInputGrammarType] = None
     max_new_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
     repetition_penalty: Optional[float] = None
+    """The parameter for repetition penalty. 1.0 means no penalty.
+    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    """
     return_full_text: Optional[bool] = None
+    """Whether to prepend the prompt to the generated text"""
     seed: Optional[int] = None
+    """Random sampling seed."""
     stop: Optional[List[str]] = None
+    """Stop generating tokens if a member of `stop` is generated."""
     temperature: Optional[float] = None
+    """The value used to module the logits distribution."""
     top_k: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
     top_n_tokens: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for top-n-filtering."""
     top_p: Optional[float] = None
+    """Top-p value for nucleus sampling."""
     truncate: Optional[int] = None
+    """Truncate inputs tokens to the given size."""
     typical_p: Optional[float] = None
+    """Typical Decoding mass
+    See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
+    for more information.
+    """
     watermark: Optional[bool] = None
+    """Watermarking with [A Watermark for Large Language
+    Models](https://arxiv.org/abs/2301.10226).
+    """


 @dataclass
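The newly documented `adapter_id` combines with the other generation parameters as in this sketch (the adapter id is hypothetical):

    from huggingface_hub.inference._generated.types.text_generation import (
        TextGenerationInputGenerateParameters,
    )

    params = TextGenerationInputGenerateParameters(
        adapter_id="my-user/my-lora-adapter",  # hypothetical LoRA adapter id
        max_new_tokens=128,
        temperature=0.7,
        repetition_penalty=1.1,
        watermark=False,
    )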
huggingface_hub/lfs.py
CHANGED
@@ -134,9 +134,10 @@ def post_lfs_batch_info(
         - Second element is an list of errors, if any

     Raises:
-        `ValueError
-
-        `HTTPError
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If an argument is invalid or the server response is malformed.
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the server returned an error.
     """
     endpoint = endpoint if endpoint is not None else ENDPOINT
     url_prefix = ""
@@ -211,8 +212,10 @@ def lfs_upload(
             Headers to include in the request, including authentication and user agent headers.

     Raises:
-
-
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If `lfs_batch_action` is improperly formatted
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the upload resulted in an error
     """
     # 0. If LFS file is already present, skip upload
     _validate_batch_actions(lfs_batch_action)
@@ -307,7 +310,9 @@ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None:

     Returns: `requests.Response`

-    Raises:
+    Raises:
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the upload resulted in an error.
     """
     with operation.as_file(with_tqdm=True) as fileobj:
         # S3 might raise a transient 500 error -> let's retry if that happens
huggingface_hub/repocard_data.py
CHANGED
@@ -55,7 +55,7 @@ class EvalResult:
         source_name (`str`, *optional*):
             The name of the source of the evaluation result. Example: "Open LLM Leaderboard".
         source_url (`str`, *optional*):
-            The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/
+            The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard".
     """

     # Required
@@ -128,7 +128,7 @@ class EvalResult:
     source_name: Optional[str] = None

     # The URL of the source of the evaluation result.
-    # Example: https://huggingface.co/spaces/
+    # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard
     source_url: Optional[str] = None

     @property
@@ -350,7 +350,7 @@ class ModelCardData(CardData):
        super().__init__(**kwargs)

        if self.eval_results:
-            if
+            if isinstance(self.eval_results, EvalResult):
                self.eval_results = [self.eval_results]
            if self.model_name is None:
                raise ValueError("Passing `eval_results` requires `model_name` to be set.")
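The `isinstance` check above means a single `EvalResult` is wrapped into a list automatically during init. A sketch with made-up metric values (field names follow the public EvalResult/ModelCardData API):

    from huggingface_hub import EvalResult, ModelCardData

    card_data = ModelCardData(
        model_name="my-cool-model",          # required whenever eval_results is passed
        eval_results=EvalResult(             # a single result; wrapped into a list for you
            task_type="text-classification",
            dataset_type="imdb",
            dataset_name="IMDb",
            metric_type="accuracy",
            metric_value=0.91,
        ),
    )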
huggingface_hub/repository.py
CHANGED
@@ -507,8 +507,8 @@ class Repository:
            instance will be created if this is left to `None`.

        Raises:
-
-
+            [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+                If the remote repository set in `clone_from` does not exist.
        """
        if isinstance(local_dir, Path):
            local_dir = str(local_dir)
@@ -542,10 +542,10 @@ class Repository:
        user = self.client.whoami(self.huggingface_token)

        if git_email is None:
-            git_email = user
+            git_email = user.get("email")

        if git_user is None:
-            git_user = user
+            git_user = user.get("fullname")

        if git_user is not None or git_email is not None:
            self.git_config_username_and_email(git_user, git_email)
@@ -580,8 +580,8 @@ class Repository:
        Checks that `git` and `git-lfs` can be run.

        Raises:
-
-
+            [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+                If `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess("git --version", self.local_dir).stdout.strip()
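The fix above reads `email` and `fullname` from the `whoami()` payload instead of using the raw dict. For context, a sketch of inspecting that payload (fields may be missing or None depending on the token, so treat this as an assumption):

    from huggingface_hub import HfApi

    user = HfApi().whoami()  # dict describing the authenticated user
    print(user.get("fullname"), user.get("email"))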
huggingface_hub/serialization/__init__.py
CHANGED

@@ -15,6 +15,11 @@
 """Contains helpers to serialize tensors."""

 from ._base import StateDictSplit, split_state_dict_into_shards_factory
-from .
-from .
-
+from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
+from ._torch import (
+    get_torch_storage_id,
+    get_torch_storage_size,
+    save_torch_model,
+    save_torch_state_dict,
+    split_torch_state_dict_into_shards,
+)
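The torch helpers exported above can be used roughly as follows; a sketch only, assuming the 0.24 signatures and with `torch` (and `safetensors` for the default serialization) installed:

    import os

    import torch
    from huggingface_hub.serialization import save_torch_state_dict, split_torch_state_dict_into_shards

    state_dict = {"weight": torch.zeros(16, 16), "bias": torch.zeros(16)}

    # Plan the split without writing anything; the returned StateDictSplit describes the layout.
    split = split_torch_state_dict_into_shards(state_dict)
    print(split.filename_to_tensors)

    # Write the state dict (sharded if needed) as safetensors files into a local folder.
    os.makedirs("tmp_checkpoint", exist_ok=True)
    save_torch_state_dict(state_dict, "tmp_checkpoint")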
huggingface_hub/serialization/_base.py
CHANGED

@@ -23,8 +23,14 @@ TensorT = TypeVar("TensorT")
 TensorSizeFn_T = Callable[[TensorT], int]
 StorageIDFn_T = Callable[[TensorT], Optional[Any]]

-MAX_SHARD_SIZE =
-
+MAX_SHARD_SIZE = "5GB"
+SIZE_UNITS = {
+    "TB": 10**12,
+    "GB": 10**9,
+    "MB": 10**6,
+    "KB": 10**3,
+}
+

 logger = logging.get_logger(__file__)

@@ -43,9 +49,9 @@ class StateDictSplit:
 def split_state_dict_into_shards_factory(
     state_dict: Dict[str, TensorT],
     *,
-
+    get_storage_size: TensorSizeFn_T,
+    filename_pattern: str,
     get_storage_id: StorageIDFn_T = lambda tensor: None,
-    filename_pattern: str = FILENAME_PATTERN,
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """
@@ -66,8 +72,8 @@ def split_state_dict_into_shards_factory(
     Args:
         state_dict (`Dict[str, Tensor]`):
             The state dictionary to save.
-
-            A function that returns the size of a tensor in bytes.
+        get_storage_size (`Callable[[Tensor], int]`):
+            A function that returns the size of a tensor when saved on disk in bytes.
         get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
             A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
             same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
@@ -75,7 +81,6 @@ def split_state_dict_into_shards_factory(
         filename_pattern (`str`, *optional*):
             The pattern to generate the files names in which the model will be saved. Pattern must be a string that
             can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
-            Defaults to `"model{suffix}.safetensors"`.
         max_shard_size (`int` or `str`, *optional*):
             The maximum size of each shard, in bytes. Defaults to 5GB.

@@ -112,7 +117,7 @@ def split_state_dict_into_shards_factory(
             storage_id_to_tensors[storage_id] = [key]

         # Compute tensor size
-        tensor_size =
+        tensor_size = get_storage_size(tensor)

         # If this tensor is bigger than the maximal size, we put it in its own shard
         if tensor_size > max_shard_size:
@@ -172,14 +177,6 @@ def split_state_dict_into_shards_factory(
     )


-SIZE_UNITS = {
-    "TB": 10**12,
-    "GB": 10**9,
-    "MB": 10**6,
-    "KB": 10**3,
-}
-
-
 def parse_size_to_int(size_as_str: str) -> int:
     """
     Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
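With `SIZE_UNITS` now defined next to `MAX_SHARD_SIZE`, `parse_size_to_int` keeps resolving human-readable sizes to decimal byte counts, e.g.:

    from huggingface_hub.serialization._base import parse_size_to_int

    print(parse_size_to_int("5GB"))    # 5000000000
    print(parse_size_to_int("100MB"))  # 100000000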
huggingface_hub/serialization/_tensorflow.py
CHANGED

@@ -17,6 +17,7 @@ import math
 import re
 from typing import TYPE_CHECKING, Dict, Union

+from .. import constants
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory


@@ -27,7 +28,7 @@ if TYPE_CHECKING:
 def split_tf_state_dict_into_shards(
     state_dict: Dict[str, "tf.Tensor"],
     *,
-    filename_pattern: str =
+    filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """
@@ -62,11 +63,11 @@ def split_tf_state_dict_into_shards(
         state_dict,
         max_shard_size=max_shard_size,
         filename_pattern=filename_pattern,
-
+        get_storage_size=get_tf_storage_size,
     )


-def
+def get_tf_storage_size(tensor: "tf.Tensor") -> int:
     # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
     # Better to overestimate than underestimate.
     return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))