huggingface-hub 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


Files changed (35)
  1. huggingface_hub/__init__.py +19 -1
  2. huggingface_hub/_commit_api.py +49 -20
  3. huggingface_hub/_inference_endpoints.py +10 -0
  4. huggingface_hub/_login.py +2 -2
  5. huggingface_hub/commands/download.py +1 -1
  6. huggingface_hub/file_download.py +57 -21
  7. huggingface_hub/hf_api.py +269 -54
  8. huggingface_hub/hf_file_system.py +131 -8
  9. huggingface_hub/hub_mixin.py +204 -42
  10. huggingface_hub/inference/_client.py +56 -9
  11. huggingface_hub/inference/_common.py +4 -3
  12. huggingface_hub/inference/_generated/_async_client.py +57 -9
  13. huggingface_hub/inference/_text_generation.py +5 -0
  14. huggingface_hub/inference/_types.py +17 -0
  15. huggingface_hub/lfs.py +6 -3
  16. huggingface_hub/repocard.py +5 -3
  17. huggingface_hub/repocard_data.py +11 -3
  18. huggingface_hub/serialization/__init__.py +19 -0
  19. huggingface_hub/serialization/_base.py +168 -0
  20. huggingface_hub/serialization/_numpy.py +67 -0
  21. huggingface_hub/serialization/_tensorflow.py +93 -0
  22. huggingface_hub/serialization/_torch.py +199 -0
  23. huggingface_hub/templates/datasetcard_template.md +1 -1
  24. huggingface_hub/templates/modelcard_template.md +1 -4
  25. huggingface_hub/utils/__init__.py +14 -10
  26. huggingface_hub/utils/_datetime.py +4 -11
  27. huggingface_hub/utils/_errors.py +29 -0
  28. huggingface_hub/utils/_runtime.py +21 -15
  29. huggingface_hub/utils/endpoint_helpers.py +27 -1
  30. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/METADATA +7 -3
  31. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/RECORD +35 -30
  32. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/LICENSE +0 -0
  33. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/WHEEL +0 -0
  34. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/entry_points.txt +0 -0
  35. {huggingface_hub-0.20.3.dist-info → huggingface_hub-0.21.0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_client.py CHANGED
@@ -33,6 +33,7 @@
 # - Images are parsed as PIL.Image for easier manipulation.
 # - Provides a "recommended model" for each task => suboptimal but user-wise quicker to get a first script running.
 # - Only the main parameters are publicly exposed. Power users can always read the docs for more options.
+import base64
 import logging
 import time
 import warnings
@@ -78,6 +79,7 @@ from huggingface_hub.inference._text_generation import (
     raise_text_generation_error,
 )
 from huggingface_hub.inference._types import (
+    AudioToAudioOutput,
     ClassificationOutput,
     ConversationalOutput,
     FillMaskOutput,
@@ -299,6 +301,49 @@ class InferenceClient:
         response = self.post(data=audio, model=model, task="audio-classification")
         return _bytes_to_list(response)
 
+    def audio_to_audio(
+        self,
+        audio: ContentT,
+        *,
+        model: Optional[str] = None,
+    ) -> List[AudioToAudioOutput]:
+        """
+        Performs multiple tasks related to audio-to-audio depending on the model (e.g. speech enhancement, source separation).
+
+        Args:
+            audio (Union[str, Path, bytes, BinaryIO]):
+                The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an
+                audio file.
+            model (`str`, *optional*):
+                The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub
+                or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
+                audio_to_audio will be used.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, one per output audio, each containing the audio's label, content type, and the audio content as a blob.
+
+        Raises:
+            `InferenceTimeoutError`:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+        >>> audio_output = client.audio_to_audio("audio.flac")
+        >>> for i, item in enumerate(audio_output):
+        >>>     with open(f"output_{i}.flac", "wb") as f:
+        >>>         f.write(item["blob"])
+        ```
+        """
+        response = self.post(data=audio, model=model, task="audio-to-audio")
+        audio_output = _bytes_to_list(response)
+        for item in audio_output:
+            item["blob"] = base64.b64decode(item["blob"])
+        return audio_output
+
     def automatic_speech_recognition(
         self,
         audio: ContentT,
@@ -1063,16 +1108,17 @@ class InferenceClient:
         )
         return _bytes_to_dict(response)  # type: ignore
 
-    def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
+    def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
         """
         Classifying a target category (a group) based on a set of attributes.
 
         Args:
             table (`Dict[str, Any]`):
                 Set of attributes to classify.
-            model (`str`):
-                The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint.
+            model (`str`, *optional*):
+                The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used.
+                Defaults to None.
 
         Returns:
             `List`: a list of labels, one per row in the initial table.
@@ -1107,16 +1153,17 @@ class InferenceClient:
         response = self.post(json={"table": table}, model=model, task="tabular-classification")
         return _bytes_to_list(response)
 
-    def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]:
+    def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
         """
         Predicting a numerical target value given a set of attributes/features in a table.
 
         Args:
             table (`Dict[str, Any]`):
                 Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
-            model (`str`):
-                The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint.
+            model (`str`, *optional*):
+                The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used.
+                Defaults to None.
 
         Returns:
             `List`: a list of predicted numerical target values.
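Since `model` is now optional for both tabular tasks, the minimal call no longer needs an explicit model ID. A short sketch, assuming the column names and values below are placeholder data and that the recommended default model accepts them:

```python
from huggingface_hub import InferenceClient

client = InferenceClient()

# Placeholder table: keys are column names, values are per-row entries as strings.
table = {
    "Height": ["11.52", "12.48", "12.38"],
    "Length": ["23.2", "24.0", "23.9"],
}

# Both calls fall back to the recommended model for the task when `model` is omitted.
labels = client.tabular_classification(table)  # one label per row
values = client.tabular_regression(table)      # one float per row
print(labels, values)
```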
@@ -1483,7 +1530,7 @@ class InferenceClient:
         # Remove some parameters if not a TGI server
         if not _is_tgi_server(model):
             ignored_parameters = []
-            for key in "watermark", "stop", "details", "decoder_input_details":
+            for key in "watermark", "stop", "details", "decoder_input_details", "best_of":
                 if payload["parameters"][key] is not None:
                     ignored_parameters.append(key)
                 del payload["parameters"][key]
huggingface_hub/inference/_common.py CHANGED
@@ -84,8 +84,9 @@ class ModelStatus:
             backend. Loadable models are automatically loaded when the user first
             requests inference on the endpoint. This means it is transparent for the
             user to load a model, except that the first call takes longer to complete.
-        compute_type (`str`):
-            The type of compute resource the model is using or will use, such as 'gpu' or 'cpu'.
+        compute_type (`Dict`):
+            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
+            replicas.
         framework (`str`):
             The name of the framework that the model was built with, such as 'transformers'
             or 'text-generation-inference'.
@@ -93,7 +94,7 @@ class ModelStatus:
 
     loaded: bool
     state: str
-    compute_type: str
+    compute_type: Dict
     framework: str
 
 
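Because `compute_type` is now a dict rather than a plain string, code that printed it keeps working, but code that compared it to `"gpu"` or `"cpu"` needs updating. A hedged sketch (the model ID is illustrative and the exact keys inside `compute_type` depend on the serverless backend, not on the client):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()
status = client.get_model_status("meta-llama/Llama-2-7b-chat-hf")  # illustrative model ID

print(status.loaded, status.state, status.framework)
# Previously a string such as "gpu"; now a dict describing the compute resource,
# e.g. the accelerator type and number of replicas.
print(status.compute_type)
```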
huggingface_hub/inference/_generated/_async_client.py CHANGED
@@ -19,6 +19,7 @@
 # To re-generate the code, run `make style` or `python ./utils/generate_async_inference_client.py --update`.
 # WARNING
 import asyncio
+import base64
 import logging
 import time
 import warnings
@@ -63,6 +64,7 @@ from huggingface_hub.inference._text_generation import (
     raise_text_generation_error,
 )
 from huggingface_hub.inference._types import (
+    AudioToAudioOutput,
     ClassificationOutput,
     ConversationalOutput,
     FillMaskOutput,
@@ -295,6 +297,50 @@ class AsyncInferenceClient:
         response = await self.post(data=audio, model=model, task="audio-classification")
         return _bytes_to_list(response)
 
+    async def audio_to_audio(
+        self,
+        audio: ContentT,
+        *,
+        model: Optional[str] = None,
+    ) -> List[AudioToAudioOutput]:
+        """
+        Performs multiple tasks related to audio-to-audio depending on the model (e.g. speech enhancement, source separation).
+
+        Args:
+            audio (Union[str, Path, bytes, BinaryIO]):
+                The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an
+                audio file.
+            model (`str`, *optional*):
+                The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub
+                or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for
+                audio_to_audio will be used.
+
+        Returns:
+            `List[Dict]`: A list of dictionaries, one per output audio, each containing the audio's label, content type, and the audio content as a blob.
+
+        Raises:
+            `InferenceTimeoutError`:
+                If the model is unavailable or the request times out.
+            `aiohttp.ClientResponseError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> audio_output = await client.audio_to_audio("audio.flac")
+        >>> for i, item in enumerate(audio_output):
+        >>>     with open(f"output_{i}.flac", "wb") as f:
+        >>>         f.write(item["blob"])
+        ```
+        """
+        response = await self.post(data=audio, model=model, task="audio-to-audio")
+        audio_output = _bytes_to_list(response)
+        for item in audio_output:
+            item["blob"] = base64.b64decode(item["blob"])
+        return audio_output
+
     async def automatic_speech_recognition(
         self,
         audio: ContentT,
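The async variant mirrors the sync one; the only extra step is running the coroutine. A minimal sketch of driving it with `asyncio` (file names are placeholders):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    client = AsyncInferenceClient()
    outputs = await client.audio_to_audio("audio.flac")  # placeholder input file
    for i, item in enumerate(outputs):
        with open(f"output_{i}.flac", "wb") as f:
            f.write(item["blob"])  # already base64-decoded by the client


asyncio.run(main())
```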
@@ -1080,16 +1126,17 @@ class AsyncInferenceClient:
         )
         return _bytes_to_dict(response)  # type: ignore
 
-    async def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
+    async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]:
         """
         Classifying a target category (a group) based on a set of attributes.
 
         Args:
             table (`Dict[str, Any]`):
                 Set of attributes to classify.
-            model (`str`):
-                The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint.
+            model (`str`, *optional*):
+                The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used.
+                Defaults to None.
 
         Returns:
             `List`: a list of labels, one per row in the initial table.
@@ -1125,16 +1172,17 @@ class AsyncInferenceClient:
         response = await self.post(json={"table": table}, model=model, task="tabular-classification")
         return _bytes_to_list(response)
 
-    async def tabular_regression(self, table: Dict[str, Any], *, model: str) -> List[float]:
+    async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]:
         """
         Predicting a numerical target value given a set of attributes/features in a table.
 
         Args:
             table (`Dict[str, Any]`):
                 Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical.
-            model (`str`):
-                The model to use for the tabular-regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
-                a deployed Inference Endpoint.
+            model (`str`, *optional*):
+                The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used.
+                Defaults to None.
 
         Returns:
             `List`: a list of predicted numerical target values.
@@ -1504,7 +1552,7 @@ class AsyncInferenceClient:
         # Remove some parameters if not a TGI server
         if not _is_tgi_server(model):
             ignored_parameters = []
-            for key in "watermark", "stop", "details", "decoder_input_details":
+            for key in "watermark", "stop", "details", "decoder_input_details", "best_of":
                 if payload["parameters"][key] is not None:
                     ignored_parameters.append(key)
                 del payload["parameters"][key]
huggingface_hub/inference/_text_generation.py CHANGED
@@ -451,6 +451,8 @@ class TextGenerationStreamResponse:
     Args:
         token (`Token`):
             The generated token.
+        index (`Optional[int]`, *optional*):
+            The token index within the stream. Optional to support older clients that omit it.
         generated_text (`Optional[str]`, *optional*):
             The complete generated text. Only available when the generation is finished.
         details (`Optional[StreamDetails]`, *optional*):
@@ -459,6 +461,9 @@ class TextGenerationStreamResponse:
 
     # Generated token
     token: Token
+    # The token index within the stream
+    # Optional to support older clients that omit it.
+    index: Optional[int] = None
     # Complete generated text
     # Only available when the generation is finished
     generated_text: Optional[str] = None
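The new `index` field shows up when streaming; servers that do not send it simply leave it as `None`. A hedged sketch of reading it (the model ID is only an illustrative TGI-served model):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()

# Each yielded item is a TextGenerationStreamResponse.
for response in client.text_generation(
    "The huggingface_hub library is ",
    model="HuggingFaceH4/zephyr-7b-beta",  # illustrative choice
    max_new_tokens=12,
    stream=True,
    details=True,
):
    # `index` may be None when the server omits it.
    print(response.index, response.token.text)
```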
huggingface_hub/inference/_types.py CHANGED
@@ -19,6 +19,23 @@ if TYPE_CHECKING:
     from PIL import Image
 
 
+class AudioToAudioOutput(TypedDict):
+    """Dictionary containing the output of a [`~InferenceClient.audio_to_audio`] task.
+
+    Args:
+        label (`str`):
+            The label of the audio file.
+        content-type (`str`):
+            The content type of the audio file.
+        blob (`bytes`):
+            The audio file in byte format.
+    """
+
+    label: str
+    content_type: str
+    blob: bytes
+
+
 class ClassificationOutput(TypedDict):
     """Dictionary containing the output of a [`~InferenceClient.audio_classification`] and [`~InferenceClient.image_classification`] task.
 
huggingface_hub/lfs.py CHANGED
@@ -295,7 +295,7 @@ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> Non
     """
     with operation.as_file(with_tqdm=True) as fileobj:
         # S3 might raise a transient 500 error -> let's retry if that happens
-        response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 503))
+        response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 502, 503, 504))
     hf_raise_for_status(response)
 
 
@@ -380,7 +380,7 @@ def _upload_parts_iteratively(
         ) as fileobj_slice:
             # S3 might raise a transient 500 error -> let's retry if that happens
             part_upload_res = http_backoff(
-                "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 503)
+                "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 502, 503, 504)
             )
             hf_raise_for_status(part_upload_res)
             headers.append(part_upload_res.headers)
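`http_backoff` is the retry helper used in both hunks; the change simply widens the set of transient S3/gateway status codes it retries on. A small sketch of the same pattern outside of `lfs.py` (the URL and payload are placeholders):

```python
from huggingface_hub.utils import hf_raise_for_status, http_backoff

# Retry the PUT on transient 5xx errors (500/502/503/504), with exponential backoff.
response = http_backoff(
    "PUT",
    "https://example.com/presigned-upload-url",  # placeholder URL
    data=b"some payload",
    retry_on_status_codes=(500, 502, 503, 504),
)
hf_raise_for_status(response)
```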
@@ -409,7 +409,10 @@ def _upload_parts_hf_transfer(
     desc = operation.path_in_repo
     if len(desc) > 40:
         desc = f"(…){desc[-40:]}"
-    disable = bool(logger.getEffectiveLevel() == logging.NOTSET)
+
+    # set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached
+    # see https://github.com/huggingface/huggingface_hub/pull/2000
+    disable = True if (logger.getEffectiveLevel() == logging.NOTSET) else None
 
     with tqdm(unit="B", unit_scale=True, total=total, initial=0, desc=desc, disable=disable) as progress:
         try:
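The `disable=None` trick relies on tqdm's documented behavior: `None` keeps the bar only when the output stream is attached to a TTY, whereas `False` always shows it. A standalone illustration of that behavior using tqdm directly:

```python
from tqdm.auto import tqdm

# With disable=None, tqdm shows the bar on an interactive terminal
# and silently disables it when stderr is not a TTY (e.g. in CI logs).
for _ in tqdm(range(3), desc="demo", disable=None):
    pass
```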
huggingface_hub/repocard.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import re
-import warnings
 from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Type, Union
 
@@ -21,7 +20,10 @@ from huggingface_hub.repocard_data import (
 from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump
 
 from .constants import REPOCARD_NAME
-from .utils import EntryNotFoundError, SoftTemporaryDirectory, validate_hf_hub_args
+from .utils import EntryNotFoundError, SoftTemporaryDirectory, logging, validate_hf_hub_args
+
+
+logger = logging.get_logger(__name__)
 
 
 TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md"
@@ -102,7 +104,7 @@ class RepoCard:
                 raise ValueError("repo card metadata block should be a dict")
         else:
             # Model card without metadata... create empty metadata
-            warnings.warn("Repo card metadata block was not found. Setting CardData to empty.")
+            logger.warning("Repo card metadata block was not found. Setting CardData to empty.")
             data_dict = {}
         self.text = content
 
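Switching from `warnings.warn` to `logger.warning` means the message is now controlled by `huggingface_hub`'s logging verbosity instead of the `warnings` filters. A short sketch of silencing it, assuming a README without a YAML metadata block:

```python
from huggingface_hub import RepoCard
from huggingface_hub.utils import logging

# Raise the verbosity threshold so the "metadata block was not found" warning is not emitted.
logging.set_verbosity_error()

card = RepoCard("# My model\n\nNo YAML metadata block here.")
print(card.data.to_dict())  # empty CardData, as described in the hunk above
```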
huggingface_hub/repocard_data.py CHANGED
@@ -1,10 +1,12 @@
 import copy
-import warnings
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from huggingface_hub.utils import yaml_dump
+from huggingface_hub.utils import logging, yaml_dump
+
+
+logger = logging.get_logger(__name__)
 
 
 @dataclass
@@ -253,6 +255,10 @@ class ModelCardData(CardData):
         tags (`List[str]`, *optional*):
             List of tags to add to your model that can be used when filtering on the Hugging
             Face Hub. Defaults to None.
+        base_model (`str` or `List[str]`, *optional*):
+            The identifier of the base model from which the model derives. This is applicable for example if your model is a
+            fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs
+            if your model derives from multiple models). Defaults to None.
         datasets (`List[str]`, *optional*):
             List of datasets that were used to train this model. Should be a dataset ID
             found on https://hf.co/datasets. Defaults to None.
@@ -295,6 +301,7 @@ class ModelCardData(CardData):
         license: Optional[str] = None,
         library_name: Optional[str] = None,
         tags: Optional[List[str]] = None,
+        base_model: Optional[Union[str, List[str]]] = None,
         datasets: Optional[List[str]] = None,
         metrics: Optional[List[str]] = None,
         eval_results: Optional[List[EvalResult]] = None,
@@ -306,6 +313,7 @@ class ModelCardData(CardData):
         self.license = license
         self.library_name = library_name
         self.tags = tags
+        self.base_model = base_model
         self.datasets = datasets
         self.metrics = metrics
         self.eval_results = eval_results
@@ -319,7 +327,7 @@ class ModelCardData(CardData):
                 self.eval_results = eval_results
             except (KeyError, TypeError) as error:
                 if ignore_metadata_errors:
-                    warnings.warn("Invalid model-index. Not loading eval results into CardData.")
+                    logger.warning("Invalid model-index. Not loading eval results into CardData.")
                 else:
                     raise ValueError(
                         f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass"
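The new `base_model` attribute ends up in the card's YAML metadata like any other field. A minimal sketch (model IDs and tags are illustrative):

```python
from huggingface_hub import ModelCardData

data = ModelCardData(
    language="en",
    license="apache-2.0",
    base_model="mistralai/Mistral-7B-v0.1",  # or a list of IDs for models derived from several bases
    tags=["text-generation", "lora"],
)

# The serialized metadata now contains a `base_model: mistralai/Mistral-7B-v0.1` entry.
print(data.to_yaml())
```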
huggingface_hub/serialization/__init__.py ADDED
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ruff: noqa: F401
+"""Contains helpers to serialize tensors."""
+from ._base import StateDictSplit, split_state_dict_into_shards_factory
+from ._numpy import split_numpy_state_dict_into_shards
+from ._tensorflow import split_tf_state_dict_into_shards
+from ._torch import split_torch_state_dict_into_shards
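The public entry points re-exported here all share the same shape: pass a state dict, get back a `StateDictSplit` describing which tensors go into which file. A hedged sketch with the torch variant (requires `torch`; the tensor names and shard size are toy values, and actually writing the shards, e.g. with `safetensors`, is left as a comment):

```python
import torch

from huggingface_hub.serialization import split_torch_state_dict_into_shards

state_dict = {
    "embedding.weight": torch.zeros(1024, 1024),  # ~4 MB each in float32
    "lm_head.weight": torch.zeros(1024, 1024),
}

# Tiny max_shard_size so the toy example actually produces several shards.
split = split_torch_state_dict_into_shards(state_dict, max_shard_size=3_000_000)

print(split.is_sharded)           # True
print(split.filename_to_tensors)  # shard filename -> list of tensor names
# Each shard could then be written with safetensors.torch.save_file(...).
```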
huggingface_hub/serialization/_base.py ADDED
@@ -0,0 +1,168 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains helpers to split tensors into shards."""
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional, TypeVar
+
+from .. import logging
+
+
+TensorT = TypeVar("TensorT")
+TensorSizeFn_T = Callable[[TensorT], int]
+StorageIDFn_T = Callable[[TensorT], Optional[Any]]
+
+MAX_SHARD_SIZE = 5_000_000_000  # 5GB
+FILENAME_PATTERN = "model{suffix}.safetensors"
+
+logger = logging.get_logger(__file__)
+
+
+@dataclass
+class StateDictSplit:
+    is_sharded: bool = field(init=False)
+    metadata: Dict[str, Any]
+    filename_to_tensors: Dict[str, List[str]]
+    tensor_to_filename: Dict[str, str]
+
+    def __post_init__(self):
+        self.is_sharded = len(self.filename_to_tensors) > 1
+
+
+def split_state_dict_into_shards_factory(
+    state_dict: Dict[str, TensorT],
+    *,
+    get_tensor_size: TensorSizeFn_T,
+    get_storage_id: StorageIDFn_T = lambda tensor: None,
+    filename_pattern: str = FILENAME_PATTERN,
+    max_shard_size: int = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+    """
+    Split a model state dictionary in shards so that each shard is smaller than a given size.
+
+    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+    [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, Tensor]`):
+            The state dictionary to save.
+        get_tensor_size (`Callable[[Tensor], int]`):
+            A function that returns the size of a tensor in bytes.
+        get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
+            A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
+            same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
+            during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the files names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+            Defaults to `"model{suffix}.safetensors"`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+
+    Returns:
+        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+    """
+    storage_id_to_tensors: Dict[Any, List[str]] = {}
+
+    shard_list: List[Dict[str, TensorT]] = []
+    current_shard: Dict[str, TensorT] = {}
+    current_shard_size = 0
+    total_size = 0
+
+    for key, tensor in state_dict.items():
+        # when bnb serialization is used the weights in the state dict can be strings
+        # check: https://github.com/huggingface/transformers/pull/24416 for more details
+        if isinstance(tensor, str):
+            logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
+            continue
+
+        # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block`
+        storage_id = get_storage_id(tensor)
+        if storage_id is not None:
+            if storage_id in storage_id_to_tensors:
+                # We skip this tensor for now and will reassign to correct shard later
+                storage_id_to_tensors[storage_id].append(key)
+                continue
+            else:
+                # This is the first tensor with this storage_id, we create a new entry
+                # in the storage_id_to_tensors dict => we will assign the shard id later
+                storage_id_to_tensors[storage_id] = [key]
+
+        # Compute tensor size
+        tensor_size = get_tensor_size(tensor)
+
+        # If this tensor is bigger than the maximal size, we put it in its own shard
+        if tensor_size > max_shard_size:
+            total_size += tensor_size
+            shard_list.append({key: tensor})
+            continue
+
+        # If this tensor is going to tip up over the maximal size, we split.
+        # Current shard already has some tensors, we add it to the list of shards and create a new one.
+        if current_shard_size + tensor_size > max_shard_size:
+            shard_list.append(current_shard)
+            current_shard = {}
+            current_shard_size = 0
+
+        # Add the tensor to the current shard
+        current_shard[key] = tensor
+        current_shard_size += tensor_size
+        total_size += tensor_size
+
+    # Add the last shard
+    if len(current_shard) > 0:
+        shard_list.append(current_shard)
+    nb_shards = len(shard_list)
+
+    # Loop over the tensors that share the same storage and assign them together
+    for storage_id, keys in storage_id_to_tensors.items():
+        # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard
+        for shard in shard_list:
+            if keys[0] in shard:
+                for key in keys:
+                    shard[key] = state_dict[key]
+                break
+
+    # If we only have one shard, we return it => no need to build the index
+    if nb_shards == 1:
+        filename = filename_pattern.format(suffix="")
+        return StateDictSplit(
+            metadata={"total_size": total_size},
+            filename_to_tensors={filename: list(state_dict.keys())},
+            tensor_to_filename={key: filename for key in state_dict.keys()},
+        )
+
+    # Now that each tensor is assigned to a shard, let's assign a filename to each shard
+    tensor_name_to_filename = {}
+    filename_to_tensors = {}
+    for idx, shard in enumerate(shard_list):
+        filename = filename_pattern.format(suffix=f"-{idx+1:05d}-of-{nb_shards:05d}")
+        for key in shard:
+            tensor_name_to_filename[key] = filename
+        filename_to_tensors[filename] = list(shard.keys())
+
+    # Build the index and return
+    return StateDictSplit(
+        metadata={"total_size": total_size},
+        filename_to_tensors=filename_to_tensors,
+        tensor_to_filename=tensor_name_to_filename,
+    )
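Because the factory is generic over the tensor type, it can be exercised without any ML framework: all it needs is a size callback (and optionally a storage-id callback). A toy sketch using raw `bytes` objects as "tensors", with made-up layer names:

```python
from huggingface_hub.serialization import split_state_dict_into_shards_factory

# Toy "state dict" where each tensor is just a bytes buffer.
state_dict = {
    "layer1.weight": b"\x00" * 600,
    "layer1.bias": b"\x00" * 600,
    "layer2.weight": b"\x00" * 600,
}

split = split_state_dict_into_shards_factory(
    state_dict,
    get_tensor_size=len,               # size in bytes of each "tensor"
    max_shard_size=1_000,              # force sharding for the example
    filename_pattern="toy{suffix}.bin",
)

print(split.is_sharded)          # True: 600 + 600 > 1_000 starts a new shard
print(split.tensor_to_filename)  # e.g. {"layer1.weight": "toy-00001-of-00003.bin", ...}
```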
huggingface_hub/serialization/_numpy.py ADDED
@@ -0,0 +1,67 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains numpy-specific helpers."""
+from typing import TYPE_CHECKING, Dict
+
+from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
+
+
+if TYPE_CHECKING:
+    import numpy as np
+
+
+def split_numpy_state_dict_into_shards(
+    state_dict: Dict[str, "np.ndarray"],
+    *,
+    filename_pattern: str = FILENAME_PATTERN,
+    max_shard_size: int = MAX_SHARD_SIZE,
+) -> StateDictSplit:
+    """
+    Split a model state dictionary in shards so that each shard is smaller than a given size.
+
+    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
+    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
+    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
+    [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
+    size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, np.ndarray]`):
+            The state dictionary to save.
+        filename_pattern (`str`, *optional*):
+            The pattern to generate the files names in which the model will be saved. Pattern must be a string that
+            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
+            Defaults to `"model{suffix}.safetensors"`.
+        max_shard_size (`int` or `str`, *optional*):
+            The maximum size of each shard, in bytes. Defaults to 5GB.
+
+    Returns:
+        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
+    """
+    return split_state_dict_into_shards_factory(
+        state_dict,
+        max_shard_size=max_shard_size,
+        filename_pattern=filename_pattern,
+        get_tensor_size=get_tensor_size,
+    )
+
+
+def get_tensor_size(tensor: "np.ndarray") -> int:
+    return tensor.nbytes
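And the same through the numpy helper, which simply plugs `ndarray.nbytes` into the factory (array names and shapes below are arbitrary):

```python
import numpy as np

from huggingface_hub.serialization import split_numpy_state_dict_into_shards

state_dict = {
    "weight": np.zeros((1000, 1000), dtype=np.float32),  # ~4 MB
    "bias": np.zeros((1000,), dtype=np.float32),         # ~4 KB
}

split = split_numpy_state_dict_into_shards(state_dict, max_shard_size=1_000_000)

print(split.is_sharded)  # True: the 4 MB array exceeds the 1 MB limit and gets its own shard
print(split.metadata)    # {"total_size": 4004000}
for filename, tensors in split.filename_to_tensors.items():
    print(filename, tensors)
```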