huggingface-hub 0.23.5__py3-none-any.whl → 0.24.0rc0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of huggingface-hub might be problematic.

Files changed (42)
  1. huggingface_hub/__init__.py +47 -15
  2. huggingface_hub/_commit_api.py +38 -8
  3. huggingface_hub/_inference_endpoints.py +11 -4
  4. huggingface_hub/_local_folder.py +22 -13
  5. huggingface_hub/_snapshot_download.py +12 -7
  6. huggingface_hub/_webhooks_server.py +3 -1
  7. huggingface_hub/commands/huggingface_cli.py +4 -3
  8. huggingface_hub/commands/repo_files.py +128 -0
  9. huggingface_hub/constants.py +12 -0
  10. huggingface_hub/file_download.py +127 -91
  11. huggingface_hub/hf_api.py +979 -341
  12. huggingface_hub/hf_file_system.py +30 -3
  13. huggingface_hub/inference/_client.py +373 -42
  14. huggingface_hub/inference/_common.py +0 -2
  15. huggingface_hub/inference/_generated/_async_client.py +390 -48
  16. huggingface_hub/inference/_generated/types/__init__.py +4 -1
  17. huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
  18. huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
  19. huggingface_hub/inference/_generated/types/text_generation.py +29 -0
  20. huggingface_hub/lfs.py +11 -6
  21. huggingface_hub/repocard_data.py +3 -3
  22. huggingface_hub/repository.py +6 -6
  23. huggingface_hub/serialization/__init__.py +8 -3
  24. huggingface_hub/serialization/_base.py +13 -16
  25. huggingface_hub/serialization/_tensorflow.py +4 -3
  26. huggingface_hub/serialization/_torch.py +399 -22
  27. huggingface_hub/utils/__init__.py +0 -1
  28. huggingface_hub/utils/_errors.py +1 -1
  29. huggingface_hub/utils/_fixes.py +14 -3
  30. huggingface_hub/utils/_paths.py +17 -6
  31. huggingface_hub/utils/_subprocess.py +0 -1
  32. huggingface_hub/utils/_telemetry.py +9 -1
  33. huggingface_hub/utils/endpoint_helpers.py +2 -186
  34. huggingface_hub/utils/sha.py +36 -1
  35. huggingface_hub/utils/tqdm.py +0 -1
  36. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/METADATA +12 -9
  37. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/RECORD +41 -41
  38. huggingface_hub/serialization/_numpy.py +0 -68
  39. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/LICENSE +0 -0
  40. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/WHEEL +0 -0
  41. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/entry_points.txt +0 -0
  42. {huggingface_hub-0.23.5.dist-info → huggingface_hub-0.24.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/inference/_generated/types/__init__.py CHANGED
@@ -20,10 +20,13 @@ from .base import BaseInferenceType
 from .chat_completion import (
     ChatCompletionInput,
     ChatCompletionInputFunctionDefinition,
+    ChatCompletionInputFunctionName,
+    ChatCompletionInputGrammarType,
     ChatCompletionInputMessage,
+    ChatCompletionInputMessageChunk,
     ChatCompletionInputTool,
-    ChatCompletionInputToolCall,
     ChatCompletionInputToolTypeClass,
+    ChatCompletionInputURL,
     ChatCompletionOutput,
     ChatCompletionOutputComplete,
     ChatCompletionOutputFunctionDefinition,
huggingface_hub/inference/_generated/types/chat_completion.py CHANGED
@@ -10,33 +10,55 @@ from .base import BaseInferenceType
 
 
 @dataclass
-class ChatCompletionInputFunctionDefinition(BaseInferenceType):
-    arguments: Any
-    name: str
-    description: Optional[str] = None
+class ChatCompletionInputURL(BaseInferenceType):
+    url: str
+
+
+ChatCompletionInputMessageChunkType = Literal["text", "image_url"]
 
 
 @dataclass
-class ChatCompletionInputToolCall(BaseInferenceType):
-    function: ChatCompletionInputFunctionDefinition
-    id: int
-    type: str
+class ChatCompletionInputMessageChunk(BaseInferenceType):
+    type: "ChatCompletionInputMessageChunkType"
+    image_url: Optional[ChatCompletionInputURL] = None
+    text: Optional[str] = None
 
 
 @dataclass
 class ChatCompletionInputMessage(BaseInferenceType):
+    content: Union[List[ChatCompletionInputMessageChunk], str]
     role: str
-    content: Optional[str] = None
     name: Optional[str] = None
-    tool_calls: Optional[List[ChatCompletionInputToolCall]] = None
+
+
+ChatCompletionInputGrammarTypeType = Literal["json", "regex"]
+
+
+@dataclass
+class ChatCompletionInputGrammarType(BaseInferenceType):
+    type: "ChatCompletionInputGrammarTypeType"
+    value: Any
+    """A string that represents a [JSON Schema](https://json-schema.org/).
+    JSON Schema is a declarative language that allows to annotate JSON documents
+    with types and descriptions.
+    """
+
+
+@dataclass
+class ChatCompletionInputFunctionName(BaseInferenceType):
+    name: str
 
 
 @dataclass
 class ChatCompletionInputToolTypeClass(BaseInferenceType):
-    function_name: str
+    function: Optional[ChatCompletionInputFunctionName] = None
 
 
-ChatCompletionInputToolTypeEnum = Literal["OneOf"]
+@dataclass
+class ChatCompletionInputFunctionDefinition(BaseInferenceType):
+    arguments: Any
+    name: str
+    description: Optional[str] = None
 
 
 @dataclass
@@ -55,10 +77,6 @@ class ChatCompletionInput(BaseInferenceType):
 
     messages: List[ChatCompletionInputMessage]
     """A list of messages comprising the conversation so far."""
-    model: str
-    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
-    on which models work with the Chat API.
-    """
     frequency_penalty: Optional[float] = None
     """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
    frequency in the text so far,
@@ -83,6 +101,10 @@ class ChatCompletionInput(BaseInferenceType):
     """
     max_tokens: Optional[int] = None
     """The maximum number of tokens that can be generated in the chat completion."""
+    model: Optional[str] = None
+    """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details
+    on which models work with the Chat API.
+    """
     n: Optional[int] = None
     """UNUSED
    How many chat completion choices to generate for each input message. Note that you will
@@ -94,6 +116,7 @@ class ChatCompletionInput(BaseInferenceType):
    appear in the text so far,
    increasing the model's likelihood to talk about new topics
     """
+    response_format: Optional[ChatCompletionInputGrammarType] = None
     seed: Optional[int] = None
     stop: Optional[List[str]] = None
     """Up to 4 sequences where the API will stop generating further tokens."""
@@ -104,7 +127,7 @@ class ChatCompletionInput(BaseInferenceType):
    lower values like 0.2 will make it more focused and deterministic.
    We generally recommend altering this or `top_p` but not both.
     """
-    tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, "ChatCompletionInputToolTypeEnum"]] = None
+    tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None
     tool_prompt: Optional[str] = None
     """A prompt to be appended before the tools"""
     tools: Optional[List[ChatCompletionInputTool]] = None
@@ -153,7 +176,7 @@ class ChatCompletionOutputFunctionDefinition(BaseInferenceType):
 @dataclass
 class ChatCompletionOutputToolCall(BaseInferenceType):
     function: ChatCompletionOutputFunctionDefinition
-    id: int
+    id: str
     type: str
 
 
@@ -161,7 +184,6 @@ class ChatCompletionOutputToolCall(BaseInferenceType):
 class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
-    name: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
 
 
@@ -192,7 +214,6 @@ class ChatCompletionOutput(BaseInferenceType):
     created: int
     id: str
     model: str
-    object: str
     system_fingerprint: str
     usage: ChatCompletionOutputUsage
 
@@ -256,5 +277,4 @@ class ChatCompletionStreamOutput(BaseInferenceType):
     created: int
     id: str
     model: str
-    object: str
     system_fingerprint: str
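
The new input types above compose as plain dataclasses. A minimal sketch of building a multimodal message and a grammar-constrained `response_format` with them (the URL and JSON schema below are purely illustrative, not taken from this diff):

```python
from huggingface_hub.inference._generated.types import (
    ChatCompletionInputGrammarType,
    ChatCompletionInputMessage,
    ChatCompletionInputMessageChunk,
    ChatCompletionInputURL,
)

# One user message mixing a text chunk and an image_url chunk, now possible
# because `content` accepts either a plain string or a list of chunks.
message = ChatCompletionInputMessage(
    role="user",
    content=[
        ChatCompletionInputMessageChunk(type="text", text="What is shown in this image?"),
        ChatCompletionInputMessageChunk(
            type="image_url",
            image_url=ChatCompletionInputURL(url="https://example.com/cat.png"),  # illustrative URL
        ),
    ],
)

# Grammar-constrained output: request JSON that matches a schema.
response_format = ChatCompletionInputGrammarType(
    type="json",
    value={"type": "object", "properties": {"answer": {"type": "string"}}},  # illustrative schema
)
```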
huggingface_hub/inference/_generated/types/feature_extraction.py CHANGED
@@ -4,16 +4,34 @@
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Literal, Optional
 
 from .base import BaseInferenceType
 
 
+FeatureExtractionInputTruncationDirection = Literal["Left", "Right"]
+
+
 @dataclass
 class FeatureExtractionInput(BaseInferenceType):
-    """Inputs for Text Embedding inference"""
+    """Feature Extraction Input.
+    Auto-generated from TEI specs.
+    For more details, check out
+    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
+    """
 
     inputs: str
-    """The text to get the embeddings of"""
-    parameters: Optional[Dict[str, Any]] = None
-    """Additional inference parameters"""
+    """The text to embed."""
+    normalize: Optional[bool] = None
+    prompt_name: Optional[str] = None
+    """The name of the prompt that should be used by for encoding. If not set, no prompt
+    will be applied.
+    Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+    For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",
+    ...},
+    then the sentence "What is the capital of France?" will be encoded as
+    "query: What is the capital of France?" because the prompt text will be prepended before
+    any text to encode.
+    """
+    truncate: Optional[bool] = None
+    truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None
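
The reworked FeatureExtractionInput mirrors the Text Embeddings Inference (TEI) request options. A small sketch of constructing a payload with the new fields (values are illustrative; the exact server-side semantics of `normalize` are a TEI detail, not spelled out in this diff):

```python
from huggingface_hub.inference._generated.types.feature_extraction import FeatureExtractionInput

payload = FeatureExtractionInput(
    inputs="What is the capital of France?",
    prompt_name="query",           # must be a key of the model's Sentence Transformers `prompts` config
    normalize=True,                # TEI option (assumption: normalizes the returned embedding)
    truncate=True,                 # truncate inputs that exceed the model's maximum length
    truncation_direction="Right",  # drop tokens from the right when truncating
)
```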
huggingface_hub/inference/_generated/types/text_generation.py CHANGED
@@ -24,24 +24,53 @@ class TextGenerationInputGrammarType(BaseInferenceType):
 
 @dataclass
 class TextGenerationInputGenerateParameters(BaseInferenceType):
+    adapter_id: Optional[str] = None
+    """Lora adapter id"""
     best_of: Optional[int] = None
+    """Generate best_of sequences and return the one if the highest token logprobs."""
     decoder_input_details: Optional[bool] = None
+    """Whether to return decoder input token logprobs and ids."""
     details: Optional[bool] = None
+    """Whether to return generation details."""
     do_sample: Optional[bool] = None
+    """Activate logits sampling."""
     frequency_penalty: Optional[float] = None
+    """The parameter for frequency penalty. 1.0 means no penalty
+    Penalize new tokens based on their existing frequency in the text so far,
+    decreasing the model's likelihood to repeat the same line verbatim.
+    """
     grammar: Optional[TextGenerationInputGrammarType] = None
     max_new_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
     repetition_penalty: Optional[float] = None
+    """The parameter for repetition penalty. 1.0 means no penalty.
+    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    """
     return_full_text: Optional[bool] = None
+    """Whether to prepend the prompt to the generated text"""
     seed: Optional[int] = None
+    """Random sampling seed."""
     stop: Optional[List[str]] = None
+    """Stop generating tokens if a member of `stop` is generated."""
     temperature: Optional[float] = None
+    """The value used to module the logits distribution."""
     top_k: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for top-k-filtering."""
     top_n_tokens: Optional[int] = None
+    """The number of highest probability vocabulary tokens to keep for top-n-filtering."""
     top_p: Optional[float] = None
+    """Top-p value for nucleus sampling."""
     truncate: Optional[int] = None
+    """Truncate inputs tokens to the given size."""
     typical_p: Optional[float] = None
+    """Typical Decoding mass
+    See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
+    for more information.
+    """
     watermark: Optional[bool] = None
+    """Watermarking with [A Watermark for Large Language
+    Models](https://arxiv.org/abs/2301.10226).
+    """
 
 
 @dataclass
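
The newly documented fields correspond to text-generation-inference sampling options. A brief sketch of filling a few of them (values are illustrative):

```python
from huggingface_hub.inference._generated.types.text_generation import (
    TextGenerationInputGenerateParameters,
)

params = TextGenerationInputGenerateParameters(
    max_new_tokens=128,      # maximum number of tokens to generate
    do_sample=True,          # activate logits sampling
    temperature=0.7,         # modulate the logits distribution
    top_p=0.95,              # nucleus sampling mass
    repetition_penalty=1.1,  # 1.0 means no penalty
    stop=["\n\n"],           # stop once this sequence is generated
    seed=42,                 # random sampling seed
)
```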
huggingface_hub/lfs.py CHANGED
@@ -134,9 +134,10 @@ def post_lfs_batch_info(
         - Second element is an list of errors, if any
 
     Raises:
-        `ValueError`: If an argument is invalid or the server response is malformed
-
-        `HTTPError`: If the server returned an error
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If an argument is invalid or the server response is malformed.
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the server returned an error.
     """
     endpoint = endpoint if endpoint is not None else ENDPOINT
     url_prefix = ""
@@ -211,8 +212,10 @@ def lfs_upload(
             Headers to include in the request, including authentication and user agent headers.
 
     Raises:
-        - `ValueError` if `lfs_batch_action` is improperly formatted
-        - `HTTPError` if the upload resulted in an error
+        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+            If `lfs_batch_action` is improperly formatted
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the upload resulted in an error
     """
     # 0. If LFS file is already present, skip upload
     _validate_batch_actions(lfs_batch_action)
@@ -307,7 +310,7 @@ def _upload_single_part(operation: "CommitOperationAdd", upload_url: str) -> None:
 
     Returns: `requests.Response`
 
-    Raises: `requests.HTTPError` if the upload resulted in an error
+    Raises:
+        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
+            If the upload resulted in an error.
     """
     with operation.as_file(with_tqdm=True) as fileobj:
         # S3 might raise a transient 500 error -> let's retry if that happens
huggingface_hub/repocard_data.py CHANGED
@@ -55,7 +55,7 @@ class EvalResult:
         source_name (`str`, *optional*):
             The name of the source of the evaluation result. Example: "Open LLM Leaderboard".
         source_url (`str`, *optional*):
-            The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard".
+            The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard".
     """
 
     # Required
@@ -128,7 +128,7 @@ class EvalResult:
     source_name: Optional[str] = None
 
     # The URL of the source of the evaluation result.
-    # Example: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard
+    # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard
     source_url: Optional[str] = None
 
     @property
@@ -350,7 +350,7 @@ class ModelCardData(CardData):
         super().__init__(**kwargs)
 
         if self.eval_results:
-            if type(self.eval_results) == EvalResult:
+            if isinstance(self.eval_results, EvalResult):
                 self.eval_results = [self.eval_results]
             if self.model_name is None:
                 raise ValueError("Passing `eval_results` requires `model_name` to be set.")
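
With the `isinstance()` check, a single EvalResult (including subclasses) passed as `eval_results` is wrapped into a list. A hedged sketch of that behavior, assuming EvalResult's usual required fields (task_type, dataset_type, dataset_name, metric_type, metric_value; example values are made up):

```python
from huggingface_hub.repocard_data import EvalResult, ModelCardData

result = EvalResult(
    task_type="text-classification",
    dataset_type="imdb",
    dataset_name="IMDb",
    metric_type="accuracy",
    metric_value=0.93,
)

# A bare EvalResult is normalized to [result]; `model_name` is still required
# whenever `eval_results` is provided.
card_data = ModelCardData(model_name="my-model", eval_results=result)
assert card_data.eval_results == [result]
```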
huggingface_hub/repository.py CHANGED
@@ -507,8 +507,8 @@ class Repository:
                 instance will be created if this is left to `None`.
 
         Raises:
-            - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
-              if the remote repository set in `clone_from` does not exist.
+            [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+                If the remote repository set in `clone_from` does not exist.
         """
         if isinstance(local_dir, Path):
             local_dir = str(local_dir)
@@ -542,10 +542,10 @@ class Repository:
             user = self.client.whoami(self.huggingface_token)
 
             if git_email is None:
-                git_email = user["email"]
+                git_email = user.get("email")
 
             if git_user is None:
-                git_user = user["fullname"]
+                git_user = user.get("fullname")
 
         if git_user is not None or git_email is not None:
             self.git_config_username_and_email(git_user, git_email)
@@ -580,8 +580,8 @@ class Repository:
         Checks that `git` and `git-lfs` can be run.
 
         Raises:
-            - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
-              if `git` or `git-lfs` are not installed.
+            [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+                If `git` or `git-lfs` are not installed.
         """
         try:
             git_version = run_subprocess("git --version", self.local_dir).stdout.strip()
huggingface_hub/serialization/__init__.py CHANGED
@@ -15,6 +15,11 @@
 """Contains helpers to serialize tensors."""
 
 from ._base import StateDictSplit, split_state_dict_into_shards_factory
-from ._numpy import split_numpy_state_dict_into_shards
-from ._tensorflow import split_tf_state_dict_into_shards
-from ._torch import split_torch_state_dict_into_shards
+from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards
+from ._torch import (
+    get_torch_storage_id,
+    get_torch_storage_size,
+    save_torch_model,
+    save_torch_state_dict,
+    split_torch_state_dict_into_shards,
+)
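
`save_torch_state_dict` is among the newly exposed helpers. A minimal usage sketch, assuming it accepts a state dict, a target directory, and the string-based `max_shard_size` introduced in `_base.py` (the exact signature lives in `_torch.py`, which is not shown in this diff):

```python
import torch

from huggingface_hub.serialization import save_torch_state_dict

state_dict = {
    "embeddings.weight": torch.zeros(50_000, 768),
    "lm_head.weight": torch.zeros(50_000, 768),
}

# Assumed call pattern: split the checkpoint into shards no larger than 2GB
# and write them (plus an index file when sharding happens) to ./checkpoint.
save_torch_state_dict(state_dict, "./checkpoint", max_shard_size="2GB")
```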
huggingface_hub/serialization/_base.py CHANGED
@@ -23,8 +23,14 @@ TensorT = TypeVar("TensorT")
 TensorSizeFn_T = Callable[[TensorT], int]
 StorageIDFn_T = Callable[[TensorT], Optional[Any]]
 
-MAX_SHARD_SIZE = 5_000_000_000  # 5GB
-FILENAME_PATTERN = "model{suffix}.safetensors"
+MAX_SHARD_SIZE = "5GB"
+SIZE_UNITS = {
+    "TB": 10**12,
+    "GB": 10**9,
+    "MB": 10**6,
+    "KB": 10**3,
+}
+
 
 logger = logging.get_logger(__file__)
 
@@ -43,9 +49,9 @@ class StateDictSplit:
 def split_state_dict_into_shards_factory(
     state_dict: Dict[str, TensorT],
     *,
-    get_tensor_size: TensorSizeFn_T,
+    get_storage_size: TensorSizeFn_T,
+    filename_pattern: str,
     get_storage_id: StorageIDFn_T = lambda tensor: None,
-    filename_pattern: str = FILENAME_PATTERN,
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """
@@ -66,8 +72,8 @@ def split_state_dict_into_shards_factory(
     Args:
         state_dict (`Dict[str, Tensor]`):
             The state dictionary to save.
-        get_tensor_size (`Callable[[Tensor], int]`):
-            A function that returns the size of a tensor in bytes.
+        get_storage_size (`Callable[[Tensor], int]`):
+            A function that returns the size of a tensor when saved on disk in bytes.
         get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
             A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
             same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
@@ -75,7 +81,6 @@ def split_state_dict_into_shards_factory(
         filename_pattern (`str`, *optional*):
             The pattern to generate the files names in which the model will be saved. Pattern must be a string that
             can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
-            Defaults to `"model{suffix}.safetensors"`.
         max_shard_size (`int` or `str`, *optional*):
             The maximum size of each shard, in bytes. Defaults to 5GB.
 
@@ -112,7 +117,7 @@ def split_state_dict_into_shards_factory(
             storage_id_to_tensors[storage_id] = [key]
 
         # Compute tensor size
-        tensor_size = get_tensor_size(tensor)
+        tensor_size = get_storage_size(tensor)
 
         # If this tensor is bigger than the maximal size, we put it in its own shard
         if tensor_size > max_shard_size:
@@ -172,14 +177,6 @@ def split_state_dict_into_shards_factory(
     )
 
 
-SIZE_UNITS = {
-    "TB": 10**12,
-    "GB": 10**9,
-    "MB": 10**6,
-    "KB": 10**3,
-}
-
-
 def parse_size_to_int(size_as_str: str) -> int:
     """
     Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
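
With MAX_SHARD_SIZE now the string "5GB", `parse_size_to_int` and `SIZE_UNITS` (both shown above) convert such strings to byte counts using decimal units. A quick illustration:

```python
from huggingface_hub.serialization._base import SIZE_UNITS, parse_size_to_int

assert SIZE_UNITS["GB"] == 10**9                 # decimal gigabytes, not GiB
assert parse_size_to_int("5GB") == 5 * 10**9     # the new MAX_SHARD_SIZE default
assert parse_size_to_int("500MB") == 500 * 10**6
```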
huggingface_hub/serialization/_tensorflow.py CHANGED
@@ -17,6 +17,7 @@ import math
 import re
 from typing import TYPE_CHECKING, Dict, Union
 
+from .. import constants
 from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
 
 
@@ -27,7 +28,7 @@ if TYPE_CHECKING:
 def split_tf_state_dict_into_shards(
     state_dict: Dict[str, "tf.Tensor"],
     *,
-    filename_pattern: str = "tf_model{suffix}.h5",
+    filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN,
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """
@@ -62,11 +63,11 @@ def split_tf_state_dict_into_shards(
         state_dict,
         max_shard_size=max_shard_size,
         filename_pattern=filename_pattern,
-        get_tensor_size=get_tensor_size,
+        get_storage_size=get_tf_storage_size,
     )
 
 
-def get_tensor_size(tensor: "tf.Tensor") -> int:
+def get_tf_storage_size(tensor: "tf.Tensor") -> int:
     # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool).
     # Better to overestimate than underestimate.
     return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype))