huggingface-hub 0.21.4__py3-none-any.whl → 0.22.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. Click here for more details.

Files changed (96)
  1. huggingface_hub/__init__.py +217 -1
  2. huggingface_hub/_commit_api.py +14 -15
  3. huggingface_hub/_inference_endpoints.py +12 -11
  4. huggingface_hub/_login.py +1 -0
  5. huggingface_hub/_multi_commits.py +1 -0
  6. huggingface_hub/_snapshot_download.py +9 -1
  7. huggingface_hub/_tensorboard_logger.py +1 -0
  8. huggingface_hub/_webhooks_payload.py +1 -0
  9. huggingface_hub/_webhooks_server.py +1 -0
  10. huggingface_hub/commands/_cli_utils.py +1 -0
  11. huggingface_hub/commands/delete_cache.py +1 -0
  12. huggingface_hub/commands/download.py +1 -0
  13. huggingface_hub/commands/env.py +1 -0
  14. huggingface_hub/commands/scan_cache.py +1 -0
  15. huggingface_hub/commands/upload.py +1 -0
  16. huggingface_hub/community.py +1 -0
  17. huggingface_hub/constants.py +3 -1
  18. huggingface_hub/errors.py +38 -0
  19. huggingface_hub/file_download.py +24 -24
  20. huggingface_hub/hf_api.py +47 -35
  21. huggingface_hub/hub_mixin.py +210 -54
  22. huggingface_hub/inference/_client.py +554 -239
  23. huggingface_hub/inference/_common.py +195 -41
  24. huggingface_hub/inference/_generated/_async_client.py +558 -239
  25. huggingface_hub/inference/_generated/types/__init__.py +115 -0
  26. huggingface_hub/inference/_generated/types/audio_classification.py +43 -0
  27. huggingface_hub/inference/_generated/types/audio_to_audio.py +31 -0
  28. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +116 -0
  29. huggingface_hub/inference/_generated/types/base.py +149 -0
  30. huggingface_hub/inference/_generated/types/chat_completion.py +106 -0
  31. huggingface_hub/inference/_generated/types/depth_estimation.py +29 -0
  32. huggingface_hub/inference/_generated/types/document_question_answering.py +85 -0
  33. huggingface_hub/inference/_generated/types/feature_extraction.py +19 -0
  34. huggingface_hub/inference/_generated/types/fill_mask.py +50 -0
  35. huggingface_hub/inference/_generated/types/image_classification.py +43 -0
  36. huggingface_hub/inference/_generated/types/image_segmentation.py +52 -0
  37. huggingface_hub/inference/_generated/types/image_to_image.py +55 -0
  38. huggingface_hub/inference/_generated/types/image_to_text.py +105 -0
  39. huggingface_hub/inference/_generated/types/object_detection.py +55 -0
  40. huggingface_hub/inference/_generated/types/question_answering.py +77 -0
  41. huggingface_hub/inference/_generated/types/sentence_similarity.py +28 -0
  42. huggingface_hub/inference/_generated/types/summarization.py +46 -0
  43. huggingface_hub/inference/_generated/types/table_question_answering.py +45 -0
  44. huggingface_hub/inference/_generated/types/text2text_generation.py +45 -0
  45. huggingface_hub/inference/_generated/types/text_classification.py +43 -0
  46. huggingface_hub/inference/_generated/types/text_generation.py +161 -0
  47. huggingface_hub/inference/_generated/types/text_to_audio.py +105 -0
  48. huggingface_hub/inference/_generated/types/text_to_image.py +57 -0
  49. huggingface_hub/inference/_generated/types/token_classification.py +53 -0
  50. huggingface_hub/inference/_generated/types/translation.py +46 -0
  51. huggingface_hub/inference/_generated/types/video_classification.py +47 -0
  52. huggingface_hub/inference/_generated/types/visual_question_answering.py +53 -0
  53. huggingface_hub/inference/_generated/types/zero_shot_classification.py +56 -0
  54. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +51 -0
  55. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +55 -0
  56. huggingface_hub/inference/_templating.py +105 -0
  57. huggingface_hub/inference/_types.py +4 -152
  58. huggingface_hub/keras_mixin.py +39 -17
  59. huggingface_hub/lfs.py +20 -8
  60. huggingface_hub/repocard.py +11 -3
  61. huggingface_hub/repocard_data.py +12 -2
  62. huggingface_hub/serialization/__init__.py +1 -0
  63. huggingface_hub/serialization/_base.py +1 -0
  64. huggingface_hub/serialization/_numpy.py +1 -0
  65. huggingface_hub/serialization/_tensorflow.py +1 -0
  66. huggingface_hub/serialization/_torch.py +1 -0
  67. huggingface_hub/utils/__init__.py +4 -1
  68. huggingface_hub/utils/_cache_manager.py +7 -0
  69. huggingface_hub/utils/_chunk_utils.py +1 -0
  70. huggingface_hub/utils/_datetime.py +1 -0
  71. huggingface_hub/utils/_errors.py +10 -1
  72. huggingface_hub/utils/_experimental.py +1 -0
  73. huggingface_hub/utils/_fixes.py +19 -3
  74. huggingface_hub/utils/_git_credential.py +1 -0
  75. huggingface_hub/utils/_headers.py +10 -3
  76. huggingface_hub/utils/_hf_folder.py +1 -0
  77. huggingface_hub/utils/_http.py +1 -0
  78. huggingface_hub/utils/_pagination.py +1 -0
  79. huggingface_hub/utils/_paths.py +1 -0
  80. huggingface_hub/utils/_runtime.py +22 -0
  81. huggingface_hub/utils/_subprocess.py +1 -0
  82. huggingface_hub/utils/_token.py +1 -0
  83. huggingface_hub/utils/_typing.py +29 -1
  84. huggingface_hub/utils/_validators.py +1 -0
  85. huggingface_hub/utils/endpoint_helpers.py +1 -0
  86. huggingface_hub/utils/logging.py +1 -1
  87. huggingface_hub/utils/sha.py +1 -0
  88. huggingface_hub/utils/tqdm.py +1 -0
  89. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/METADATA +14 -15
  90. huggingface_hub-0.22.0rc0.dist-info/RECORD +113 -0
  91. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/WHEEL +1 -1
  92. huggingface_hub/inference/_text_generation.py +0 -551
  93. huggingface_hub-0.21.4.dist-info/RECORD +0 -81
  94. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/LICENSE +0 -0
  95. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/entry_points.txt +0 -0
  96. {huggingface_hub-0.21.4.dist-info → huggingface_hub-0.22.0rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,55 @@
1
+ # Inference code generated from the JSON schema spec in @huggingface/tasks.
2
+ #
3
+ # See:
4
+ # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
5
+ # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ from .base import BaseInferenceType
10
+
11
+
12
@dataclass
class ZeroShotObjectDetectionInputData(BaseInferenceType):
    """The input image data, with candidate labels.

    This is the payload carried by the `inputs` field of `ZeroShotObjectDetectionInput`.
    """

    candidate_labels: List[str]
    """The candidate labels for this image"""
    # NOTE(review): typed as `Any`; the accepted image encodings are not visible in this module — confirm
    # against the task spec before relying on a specific format.
    image: Any
    """The image data to generate bounding boxes from"""
20
+
21
+
22
@dataclass
class ZeroShotObjectDetectionInput(BaseInferenceType):
    """Inputs for Zero Shot Object Detection inference.

    Bundles the image and its candidate labels (`inputs`) with an optional, free-form mapping of
    additional parameters (`parameters`, defaults to `None`).
    """

    inputs: ZeroShotObjectDetectionInputData
    """The input image data, with candidate labels"""
    parameters: Optional[Dict[str, Any]] = None
    """Additional inference parameters"""
30
+
31
+
32
@dataclass
class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.

    Note that fields are declared max-before-min, which is also the positional order of the
    dataclass-generated `__init__`; prefer keyword arguments when constructing instances.
    """

    # Upper and lower bounds of the box along the x axis.
    xmax: int
    xmin: int
    # Upper and lower bounds of the box along the y axis.
    ymax: int
    ymin: int
42
+
43
+
44
@dataclass
class ZeroShotObjectDetectionOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Object Detection task.

    One element describes a single detection: its bounding box, the candidate label it matched, and
    the associated score.
    """

    box: ZeroShotObjectDetectionBoundingBox
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """
    label: str
    """A candidate label"""
    score: float
    """The associated score / probability"""
@@ -0,0 +1,105 @@
1
+ from functools import lru_cache
2
+ from typing import Callable, Dict, List, Optional, Union
3
+
4
+ from ..utils import HfHubHTTPError, RepositoryNotFoundError, is_minijinja_available
5
+
6
+
7
class TemplateError(Exception):
    """Raised when a chat template cannot be fetched, compiled or rendered."""

    pass
9
+
10
+
11
def _import_minijinja():
    """Lazily import and return the optional `minijinja` module.

    Returns:
        The imported `minijinja` module.

    Raises:
        `ImportError`: If `minijinja` is not installed.
    """
    if is_minijinja_available():
        import minijinja  # noqa: F401

        return minijinja
    raise ImportError("Cannot render template. Please install minijinja using `pip install minijinja`.")
17
+
18
+
19
def render_chat_prompt(
    *,
    model_id: str,
    messages: List[Dict[str, str]],
    token: Union[str, bool, None] = None,
    add_generation_prompt: bool = True,
    **kwargs,
) -> str:
    """Render `messages` into a single prompt string with the chat template of `model_id`.

    Args:
        model_id (`str`):
            The model id whose chat template is used.
        messages (`List[Dict[str, str]]`):
            The list of messages to render.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
        add_generation_prompt (`bool`, *optional*, defaults to `True`):
            Passed to the template as the `add_generation_prompt` variable.
        **kwargs:
            Any extra variables to expose to the template.

    Returns:
        `str`: The rendered chat prompt.

    Raises:
        `TemplateError`: If there's any issue while fetching, compiling or rendering the chat template.
        `ImportError`: If `minijinja` is not installed.
    """
    # Import first so a missing dependency is reported before any network call.
    minijinja = _import_minijinja()
    render = _fetch_and_compile_template(model_id=model_id, token=token)

    try:
        prompt = render(messages=messages, add_generation_prompt=add_generation_prompt, **kwargs)
    except minijinja.TemplateError as err:
        raise TemplateError(f"Error while trying to render chat prompt for model '{model_id}': {err}") from err
    return prompt
50
+
51
+
52
@lru_cache  # TODO: lru_cache for raised exceptions
def _fetch_and_compile_template(*, model_id: str, token: Union[str, bool, None]) -> Callable:
    """Fetch and compile a model's chat template.

    Method is cached to avoid fetching the same model's config multiple times. The cache is keyed on
    `(model_id, token)`. Raised exceptions are not cached (see the TODO above), so a failing lookup is
    retried on the next call.

    Args:
        model_id (`str`):
            The model id.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.

    Returns:
        `Callable`: A callable that accepts the template variables as keyword arguments (e.g. `messages`,
        `add_generation_prompt`) and returns the rendered chat prompt. Tokenizer-config entries whose key
        contains "token" (string values, or the "content" of serialized `AddedToken` dicts) are also exposed
        to the template. Keyword arguments passed to the callable take precedence over them.

    Raises:
        `TemplateError`: If the model config cannot be fetched, if it has no usable chat template, or if the
            template fails to compile.
        `ImportError`: If `minijinja` is not installed.
    """
    from huggingface_hub.hf_api import HfApi

    minijinja = _import_minijinja()

    # 1. fetch config from API (the more specific "not found" error is handled first)
    try:
        config = HfApi(token=token).model_info(model_id).config
    except RepositoryNotFoundError as e:
        raise TemplateError(f"Cannot render chat template: model '{model_id}' not found.") from e
    except HfHubHTTPError as e:
        raise TemplateError(f"Error while trying to fetch chat template for model '{model_id}': {e}") from e

    # 2. check config validity
    if config is None:
        raise TemplateError(f"Config not found for model '{model_id}'.")
    tokenizer_config = config.get("tokenizer_config")
    if tokenizer_config is None:
        raise TemplateError(f"Tokenizer config not found for model '{model_id}'.")
    if not isinstance(tokenizer_config, dict):
        # A malformed config would otherwise surface as a bare AttributeError on `.get()` / `.items()` below.
        raise TemplateError(
            f"Tokenizer config must be a dict, not '{type(tokenizer_config)}' (model: {model_id})."
        )
    chat_template = tokenizer_config.get("chat_template")
    if chat_template is None:
        raise TemplateError(f"Chat template not found in tokenizer_config for model '{model_id}'.")
    if not isinstance(chat_template, str):
        raise TemplateError(f"Chat template must be a string, not '{type(chat_template)}' (model: {model_id}).")

    # Collect token-like entries so the template can reference them. They are stored either as plain
    # strings or as serialized `AddedToken` dicts whose text lives under "content".
    special_tokens: Dict[str, Optional[str]] = {}
    for key, value in tokenizer_config.items():
        if "token" in key:
            if isinstance(value, str):
                special_tokens[key] = value
            elif isinstance(value, dict) and value.get("__type") == "AddedToken":
                special_tokens[key] = value.get("content")

    # 3. compile template and return
    env = minijinja.Environment()
    try:
        env.add_template("chat_template", chat_template)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to compile chat template for model '{model_id}': {e}") from e

    def _render(**kwargs) -> str:
        # Merge into a single mapping so caller-provided variables override special tokens. Expanding the two
        # mappings as separate `**` arguments raises `TypeError: got multiple values for keyword argument`
        # whenever a caller passes a variable that is also a special token.
        return env.render_template("chat_template", **{**special_tokens, **kwargs})

    return _render
@@ -12,42 +12,13 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- from typing import TYPE_CHECKING, List, TypedDict
16
15
 
16
+ from typing import List, TypedDict
17
17
 
18
- if TYPE_CHECKING:
19
- from PIL import Image
20
18
 
21
-
22
- class AudioToAudioOutput(TypedDict):
23
- """Dictionary containing the output of a [`~InferenceClient.audio_to_audio`] task.
24
-
25
- Args:
26
- label (`str`):
27
- The label of the audio file.
28
- content-type (`str`):
29
- The content type of audio file.
30
- blob (`bytes`):
31
- The audio file in byte format.
32
- """
33
-
34
- label: str
35
- content_type: str
36
- blob: bytes
37
-
38
-
39
- class ClassificationOutput(TypedDict):
40
- """Dictionary containing the output of a [`~InferenceClient.audio_classification`] and [`~InferenceClient.image_classification`] task.
41
-
42
- Args:
43
- label (`str`):
44
- The label predicted by the model.
45
- score (`float`):
46
- The score of the label predicted by the model.
47
- """
48
-
49
- label: str
50
- score: float
19
+ # Legacy types
20
+ # Types are now generated from the JSON schema spec in @huggingface/tasks.
21
+ # See ./src/huggingface_hub/inference/_generated/types
51
22
 
52
23
 
53
24
  class ConversationalOutputConversation(TypedDict):
@@ -79,122 +50,3 @@ class ConversationalOutput(TypedDict):
79
50
  conversation: ConversationalOutputConversation
80
51
  generated_text: str
81
52
  warnings: List[str]
82
-
83
-
84
- class FillMaskOutput(TypedDict):
85
- """Dictionary containing information about a [`~InferenceClient.fill_mask`] task.
86
-
87
- Args:
88
- score (`float`):
89
- The probability of the token.
90
- token (`int`):
91
- The id of the token.
92
- token_str (`str`):
93
- The string representation of the token.
94
- sequence (`str`):
95
- The actual sequence of tokens that ran against the model (may contain special tokens).
96
- """
97
-
98
- score: float
99
- token: int
100
- token_str: str
101
- sequence: str
102
-
103
-
104
- class ImageSegmentationOutput(TypedDict):
105
- """Dictionary containing information about a [`~InferenceClient.image_segmentation`] task. In practice, image segmentation returns a
106
- list of `ImageSegmentationOutput` with 1 item per mask.
107
-
108
- Args:
109
- label (`str`):
110
- The label corresponding to the mask.
111
- mask (`Image`):
112
- An Image object representing the mask predicted by the model.
113
- score (`float`):
114
- The score associated with the label for this mask.
115
- """
116
-
117
- label: str
118
- mask: "Image"
119
- score: float
120
-
121
-
122
- class ObjectDetectionOutput(TypedDict):
123
- """Dictionary containing information about a [`~InferenceClient.object_detection`] task.
124
-
125
- Args:
126
- label (`str`):
127
- The label corresponding to the detected object.
128
- box (`dict`):
129
- A dict response of bounding box coordinates of
130
- the detected object: xmin, ymin, xmax, ymax
131
- score (`float`):
132
- The score corresponding to the detected object.
133
- """
134
-
135
- label: str
136
- box: dict
137
- score: float
138
-
139
-
140
- class QuestionAnsweringOutput(TypedDict):
141
- """Dictionary containing information about a [`~InferenceClient.question_answering`] task.
142
-
143
- Args:
144
- score (`float`):
145
- A float that represents how likely that the answer is correct.
146
- start (`int`):
147
- The index (string wise) of the start of the answer within context.
148
- end (`int`):
149
- The index (string wise) of the end of the answer within context.
150
- answer (`str`):
151
- A string that is the answer within the text.
152
- """
153
-
154
- score: float
155
- start: int
156
- end: int
157
- answer: str
158
-
159
-
160
- class TableQuestionAnsweringOutput(TypedDict):
161
- """Dictionary containing information about a [`~InferenceClient.table_question_answering`] task.
162
-
163
- Args:
164
- answer (`str`):
165
- The plaintext answer.
166
- coordinates (`List[List[int]]`):
167
- A list of coordinates of the cells referenced in the answer.
168
- cells (`List[int]`):
169
- A list of coordinates of the cells contents.
170
- aggregator (`str`):
171
- The aggregator used to get the answer.
172
- """
173
-
174
- answer: str
175
- coordinates: List[List[int]]
176
- cells: List[List[int]]
177
- aggregator: str
178
-
179
-
180
- class TokenClassificationOutput(TypedDict):
181
- """Dictionary containing the output of a [`~InferenceClient.token_classification`] task.
182
-
183
- Args:
184
- entity_group (`str`):
185
- The type for the entity being recognized (model specific).
186
- score (`float`):
187
- The score of the label predicted by the model.
188
- word (`str`):
189
- The string that was captured.
190
- start (`int`):
191
- The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
192
- end (`int`):
193
- The offset stringwise where the answer is located. Useful to disambiguate if word occurs multiple times.
194
- """
195
-
196
- entity_group: str
197
- score: float
198
- word: str
199
- start: int
200
- end: int
@@ -2,6 +2,7 @@ import collections.abc as collections
2
2
  import json
3
3
  import os
4
4
  import warnings
5
+ from functools import wraps
5
6
  from pathlib import Path
6
7
  from shutil import copytree
7
8
  from typing import Any, Dict, List, Optional, Union
@@ -18,12 +19,37 @@ from huggingface_hub.utils import (
18
19
  from .constants import CONFIG_NAME
19
20
  from .hf_api import HfApi
20
21
  from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
22
+ from .utils._typing import CallableT
21
23
 
22
24
 
23
25
  logger = logging.get_logger(__name__)
24
26
 
27
+ keras = None
25
28
  if is_tf_available():
26
- import tensorflow as tf # type: ignore
29
+ # Depending on which version of TensorFlow is installed, we need to import
30
+ # keras from the correct location.
31
+ # See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1.
32
+ # Note: saving a keras model only works with Keras<3.0.
33
+ try:
34
+ import tf_keras as keras # type: ignore
35
+ except ImportError:
36
+ import tensorflow as tf # type: ignore
37
+
38
+ keras = tf.keras
39
+
40
+
41
+ def _requires_keras_2_model(fn: CallableT) -> CallableT:
42
+ # Wrapper to raise if user tries to save a Keras 3.x model
43
+ @wraps(fn)
44
+ def _inner(model, *args, **kwargs):
45
+ if not hasattr(model, "history"): # hacky way to check if model is Keras 2.x
46
+ raise NotImplementedError(
47
+ f"Cannot use '{fn.__name__}': Keras 3.x is not supported."
48
+ " Please save models manually and upload them using `upload_folder` or `huggingface-cli upload`."
49
+ )
50
+ return fn(model, *args, **kwargs)
51
+
52
+ return _inner # type: ignore [return-value]
27
53
 
28
54
 
29
55
  def _flatten_dict(dictionary, parent_key=""):
@@ -57,21 +83,20 @@ def _flatten_dict(dictionary, parent_key=""):
57
83
 
58
84
  def _create_hyperparameter_table(model):
59
85
  """Parse hyperparameter dictionary into a markdown table."""
86
+ table = None
60
87
  if model.optimizer is not None:
61
88
  optimizer_params = model.optimizer.get_config()
62
89
  # flatten the configuration
63
90
  optimizer_params = _flatten_dict(optimizer_params)
64
- optimizer_params["training_precision"] = tf.keras.mixed_precision.global_policy().name
91
+ optimizer_params["training_precision"] = keras.mixed_precision.global_policy().name
65
92
  table = "| Hyperparameters | Value |\n| :-- | :-- |\n"
66
93
  for key, value in optimizer_params.items():
67
94
  table += f"| {key} | {value} |\n"
68
- else:
69
- table = None
70
95
  return table
71
96
 
72
97
 
73
98
  def _plot_network(model, save_directory):
74
- tf.keras.utils.plot_model(
99
+ keras.utils.plot_model(
75
100
  model,
76
101
  to_file=f"{save_directory}/model.png",
77
102
  show_shapes=False,
@@ -128,6 +153,7 @@ def _create_model_card(
128
153
  readme_path.write_text(model_card)
129
154
 
130
155
 
156
+ @_requires_keras_2_model
131
157
  def save_pretrained_keras(
132
158
  model,
133
159
  save_directory: Union[str, Path],
@@ -162,9 +188,7 @@ def save_pretrained_keras(
162
188
  model_save_kwargs will be passed to
163
189
  [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
164
190
  """
165
- if is_tf_available():
166
- import tensorflow as tf
167
- else:
191
+ if keras is None:
168
192
  raise ImportError("Called a Tensorflow-specific function but could not import it.")
169
193
 
170
194
  if not model.built:
@@ -210,7 +234,7 @@ def save_pretrained_keras(
210
234
  json.dump(model.history.history, f, indent=2, sort_keys=True)
211
235
 
212
236
  _create_model_card(model, save_directory, plot_model, metadata)
213
- tf.keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
237
+ keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
214
238
 
215
239
 
216
240
  def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
@@ -273,6 +297,7 @@ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
273
297
 
274
298
 
275
299
  @validate_hf_hub_args
300
+ @_requires_keras_2_model
276
301
  def push_to_hub_keras(
277
302
  model,
278
303
  repo_id: str,
@@ -444,6 +469,7 @@ class KerasModelHubMixin(ModelHubMixin):
444
469
  resume_download,
445
470
  local_files_only,
446
471
  token,
472
+ config: Optional[Dict[str, Any]] = None,
447
473
  **model_kwargs,
448
474
  ):
449
475
  """Here we just call [`from_pretrained_keras`] function so both the mixin and
@@ -452,14 +478,9 @@ class KerasModelHubMixin(ModelHubMixin):
452
478
  TODO - Some args above aren't used since we are calling
453
479
  snapshot_download instead of hf_hub_download.
454
480
  """
455
- if is_tf_available():
456
- import tensorflow as tf
457
- else:
481
+ if keras is None:
458
482
  raise ImportError("Called a TensorFlow-specific function but could not import it.")
459
483
 
460
- # TODO - Figure out what to do about these config values. Config is not going to be needed to load model
461
- cfg = model_kwargs.pop("config", None)
462
-
463
484
  # Root is either a local filepath matching model_id or a cached snapshot
464
485
  if not os.path.isdir(model_id):
465
486
  storage_folder = snapshot_download(
@@ -472,9 +493,10 @@ class KerasModelHubMixin(ModelHubMixin):
472
493
  else:
473
494
  storage_folder = model_id
474
495
 
475
- model = tf.keras.models.load_model(storage_folder, **model_kwargs)
496
+ # TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here...
497
+ model = keras.models.load_model(storage_folder)
476
498
 
477
499
  # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
478
- model.config = cfg
500
+ model.config = config
479
501
 
480
502
  return model
huggingface_hub/lfs.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
  """Git LFS related type definitions and utilities"""
16
+
16
17
  import inspect
17
18
  import io
18
19
  import os
@@ -27,10 +28,10 @@ from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tupl
27
28
  from urllib.parse import unquote
28
29
 
29
30
  from huggingface_hub.constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER, REPO_TYPES_URL_PREFIXES
30
- from huggingface_hub.utils import get_session
31
31
 
32
32
  from .utils import (
33
33
  build_hf_headers,
34
+ get_session,
34
35
  hf_raise_for_status,
35
36
  http_backoff,
36
37
  logging,
@@ -105,6 +106,7 @@ def post_lfs_batch_info(
105
106
  repo_id: str,
106
107
  revision: Optional[str] = None,
107
108
  endpoint: Optional[str] = None,
109
+ headers: Optional[Dict[str, str]] = None,
108
110
  ) -> Tuple[List[dict], List[dict]]:
109
111
  """
110
112
  Requests the LFS batch endpoint to retrieve upload instructions
@@ -120,10 +122,10 @@ def post_lfs_batch_info(
120
122
  repo_id (`str`):
121
123
  A namespace (user or an organization) and a repo name separated
122
124
  by a `/`.
123
- token (`str`, *optional*):
124
- An authentication token ( See https://huggingface.co/settings/tokens )
125
125
  revision (`str`, *optional*):
126
126
  The git revision to upload to.
127
+ headers (`dict`, *optional*):
128
+ Additional headers to include in the request
127
129
 
128
130
  Returns:
129
131
  `LfsBatchInfo`: 2-tuple:
@@ -154,7 +156,12 @@ def post_lfs_batch_info(
154
156
  }
155
157
  if revision is not None:
156
158
  payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted'
157
- headers = {**LFS_HEADERS, **build_hf_headers(token=token or True)} # Token must be provided or retrieved
159
+
160
+ headers = {
161
+ **LFS_HEADERS,
162
+ **build_hf_headers(token=token),
163
+ **(headers or {}),
164
+ }
158
165
  resp = get_session().post(batch_url, headers=headers, json=payload)
159
166
  hf_raise_for_status(resp)
160
167
  batch_info = resp.json()
@@ -181,7 +188,12 @@ class CompletionPayloadT(TypedDict):
181
188
  parts: List[PayloadPartT]
182
189
 
183
190
 
184
- def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: Optional[str]) -> None:
191
+ def lfs_upload(
192
+ operation: "CommitOperationAdd",
193
+ lfs_batch_action: Dict,
194
+ token: Optional[str] = None,
195
+ headers: Optional[Dict[str, str]] = None,
196
+ ) -> None:
185
197
  """
186
198
  Handles uploading a given object to the Hub with the LFS protocol.
187
199
 
@@ -193,8 +205,8 @@ def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: O
193
205
  lfs_batch_action (`dict`):
194
206
  Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for
195
207
  more details.
196
- token (`str`, *optional*):
197
- A [user access token](https://hf.co/settings/tokens) to authenticate requests against the Hub
208
+ headers (`dict`, *optional*):
209
+ Headers to include in the request, including authentication and user agent headers.
198
210
 
199
211
  Raises:
200
212
  - `ValueError` if `lfs_batch_action` is improperly formatted
@@ -234,7 +246,7 @@ def lfs_upload(operation: "CommitOperationAdd", lfs_batch_action: Dict, token: O
234
246
  _validate_lfs_action(verify_action)
235
247
  verify_resp = get_session().post(
236
248
  verify_action["href"],
237
- headers=build_hf_headers(token=token or True),
249
+ headers=build_hf_headers(token=token, headers=headers),
238
250
  json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size},
239
251
  )
240
252
  hf_raise_for_status(verify_resp)
@@ -294,6 +294,7 @@ class RepoCard:
294
294
  cls,
295
295
  card_data: CardData,
296
296
  template_path: Optional[str] = None,
297
+ template_str: Optional[str] = None,
297
298
  **template_kwargs,
298
299
  ):
299
300
  """Initialize a RepoCard from a template. By default, it uses the default template.
@@ -322,7 +323,12 @@ class RepoCard:
322
323
 
323
324
  kwargs = card_data.to_dict().copy()
324
325
  kwargs.update(template_kwargs) # Template_kwargs have priority
325
- template = jinja2.Template(Path(template_path or cls.default_template_path).read_text())
326
+
327
+ if template_path is not None:
328
+ template_str = Path(template_path).read_text()
329
+ if template_str is None:
330
+ template_str = Path(cls.default_template_path).read_text()
331
+ template = jinja2.Template(template_str)
326
332
  content = template.render(card_data=card_data.to_yaml(), **kwargs)
327
333
  return cls(content)
328
334
 
@@ -337,6 +343,7 @@ class ModelCard(RepoCard):
337
343
  cls,
338
344
  card_data: ModelCardData,
339
345
  template_path: Optional[str] = None,
346
+ template_str: Optional[str] = None,
340
347
  **template_kwargs,
341
348
  ):
342
349
  """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here:
@@ -404,7 +411,7 @@ class ModelCard(RepoCard):
404
411
 
405
412
  ```
406
413
  """
407
- return super().from_template(card_data, template_path, **template_kwargs)
414
+ return super().from_template(card_data, template_path, template_str, **template_kwargs)
408
415
 
409
416
 
410
417
  class DatasetCard(RepoCard):
@@ -417,6 +424,7 @@ class DatasetCard(RepoCard):
417
424
  cls,
418
425
  card_data: DatasetCardData,
419
426
  template_path: Optional[str] = None,
427
+ template_str: Optional[str] = None,
420
428
  **template_kwargs,
421
429
  ):
422
430
  """Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here:
@@ -468,7 +476,7 @@ class DatasetCard(RepoCard):
468
476
 
469
477
  ```
470
478
  """
471
- return super().from_template(card_data, template_path, **template_kwargs)
479
+ return super().from_template(card_data, template_path, template_str, **template_kwargs)
472
480
 
473
481
 
474
482
  class SpaceCard(RepoCard):
@@ -312,7 +312,7 @@ class ModelCardData(CardData):
312
312
  self.language = language
313
313
  self.license = license
314
314
  self.library_name = library_name
315
- self.tags = tags
315
+ self.tags = _to_unique_list(tags)
316
316
  self.base_model = base_model
317
317
  self.datasets = datasets
318
318
  self.metrics = metrics
@@ -507,7 +507,7 @@ class SpaceCardData(CardData):
507
507
  self.duplicated_from = duplicated_from
508
508
  self.models = models
509
509
  self.datasets = datasets
510
- self.tags = tags
510
+ self.tags = _to_unique_list(tags)
511
511
  super().__init__(**kwargs)
512
512
 
513
513
 
@@ -717,3 +717,13 @@ def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult])
717
717
  }
718
718
  ]
719
719
  return _remove_none(model_index)
720
+
721
+
722
+ def _to_unique_list(tags: Optional[List[str]]) -> Optional[List[str]]:
723
+ if tags is None:
724
+ return tags
725
+ unique_tags = [] # make tags unique + keep order explicitly
726
+ for tag in tags:
727
+ if tag not in unique_tags:
728
+ unique_tags.append(tag)
729
+ return unique_tags
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ruff: noqa: F401
15
15
  """Contains helpers to serialize tensors."""
16
+
16
17
  from ._base import StateDictSplit, split_state_dict_into_shards_factory
17
18
  from ._numpy import split_numpy_state_dict_into_shards
18
19
  from ._tensorflow import split_tf_state_dict_into_shards
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Contains helpers to split tensors into shards."""
15
+
15
16
  from dataclasses import dataclass, field
16
17
  from typing import Any, Callable, Dict, List, Optional, TypeVar
17
18
 
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Contains numpy-specific helpers."""
15
+
15
16
  from typing import TYPE_CHECKING, Dict
16
17
 
17
18
  from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Contains tensorflow-specific helpers."""
15
+
15
16
  import math
16
17
  import re
17
18
  from typing import TYPE_CHECKING, Dict