huggingface-hub 0.23.3__py3-none-any.whl → 0.24.0rc0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic; see the registry's advisory page for more details.

Files changed (44):
  1. huggingface_hub/__init__.py +47 -15
  2. huggingface_hub/_commit_api.py +38 -8
  3. huggingface_hub/_inference_endpoints.py +11 -4
  4. huggingface_hub/_local_folder.py +22 -13
  5. huggingface_hub/_snapshot_download.py +12 -7
  6. huggingface_hub/_webhooks_server.py +3 -1
  7. huggingface_hub/commands/huggingface_cli.py +4 -3
  8. huggingface_hub/commands/repo_files.py +128 -0
  9. huggingface_hub/constants.py +12 -0
  10. huggingface_hub/file_download.py +127 -91
  11. huggingface_hub/hf_api.py +979 -341
  12. huggingface_hub/hf_file_system.py +30 -3
  13. huggingface_hub/hub_mixin.py +103 -41
  14. huggingface_hub/inference/_client.py +373 -42
  15. huggingface_hub/inference/_common.py +0 -2
  16. huggingface_hub/inference/_generated/_async_client.py +390 -48
  17. huggingface_hub/inference/_generated/types/__init__.py +4 -1
  18. huggingface_hub/inference/_generated/types/chat_completion.py +41 -21
  19. huggingface_hub/inference/_generated/types/feature_extraction.py +23 -5
  20. huggingface_hub/inference/_generated/types/text_generation.py +29 -0
  21. huggingface_hub/lfs.py +11 -6
  22. huggingface_hub/repocard_data.py +41 -29
  23. huggingface_hub/repository.py +6 -6
  24. huggingface_hub/serialization/__init__.py +8 -3
  25. huggingface_hub/serialization/_base.py +13 -16
  26. huggingface_hub/serialization/_tensorflow.py +4 -3
  27. huggingface_hub/serialization/_torch.py +399 -22
  28. huggingface_hub/utils/__init__.py +1 -2
  29. huggingface_hub/utils/_errors.py +1 -1
  30. huggingface_hub/utils/_fixes.py +14 -3
  31. huggingface_hub/utils/_paths.py +17 -6
  32. huggingface_hub/utils/_subprocess.py +0 -1
  33. huggingface_hub/utils/_telemetry.py +9 -1
  34. huggingface_hub/utils/_typing.py +26 -1
  35. huggingface_hub/utils/endpoint_helpers.py +2 -186
  36. huggingface_hub/utils/sha.py +36 -1
  37. huggingface_hub/utils/tqdm.py +0 -1
  38. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/METADATA +12 -9
  39. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/RECORD +43 -43
  40. huggingface_hub/serialization/_numpy.py +0 -68
  41. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/LICENSE +0 -0
  42. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/WHEEL +0 -0
  43. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/entry_points.txt +0 -0
  44. {huggingface_hub-0.23.3.dist-info → huggingface_hub-0.24.0rc0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import os
2
3
  import re
3
4
  import tempfile
@@ -74,8 +75,11 @@ class HfFileSystem(fsspec.AbstractFileSystem):
74
75
  Access a remote Hugging Face Hub repository as if were a local file system.
75
76
 
76
77
  Args:
77
- token (`str`, *optional*):
78
- Authentication token, obtained with [`HfApi.login`] method. Will default to the stored token.
78
+ token (`str` or `bool`, *optional*):
79
+ A valid user access token (string). Defaults to the locally saved
80
+ token, which is the recommended method for authentication (see
81
+ https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
82
+ To disable authentication, pass `False`.
79
83
 
80
84
  Usage:
81
85
 
@@ -105,7 +109,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
105
109
  self,
106
110
  *args,
107
111
  endpoint: Optional[str] = None,
108
- token: Optional[str] = None,
112
+ token: Union[bool, str, None] = None,
109
113
  **storage_options,
110
114
  ):
111
115
  super().__init__(*args, **storage_options)
@@ -400,6 +404,12 @@ class HfFileSystem(fsspec.AbstractFileSystem):
400
404
  out.append(cache_path_info)
401
405
  return out
402
406
 
407
+ def walk(self, path, *args, **kwargs):
408
+ # Set expand_info=False by default to get a x10 speed boost
409
+ kwargs = {"expand_info": kwargs.get("detail", False), **kwargs}
410
+ path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
411
+ yield from super().walk(path, *args, **kwargs)
412
+
403
413
  def glob(self, path, **kwargs):
404
414
  # Set expand_info=False by default to get a x10 speed boost
405
415
  kwargs = {"expand_info": kwargs.get("detail", False), **kwargs}
@@ -880,3 +890,20 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
880
890
 
881
891
  def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
882
892
  return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
893
+
894
+
895
+ # Add docstrings to the methods of HfFileSystem from fsspec.AbstractFileSystem
896
+ for name, function in inspect.getmembers(HfFileSystem, predicate=inspect.isfunction):
897
+ parent = getattr(fsspec.AbstractFileSystem, name, None)
898
+ if parent is not None and parent.__doc__ is not None:
899
+ parent_doc = parent.__doc__
900
+ parent_doc = parent_doc.replace("Parameters\n ----------\n", "Args:\n")
901
+ parent_doc = parent_doc.replace("Returns\n -------\n", "Return:\n")
902
+ function.__doc__ = (
903
+ (
904
+ "\n_Docstring taken from "
905
+ f"[fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.{name})._"
906
+ )
907
+ + "\n\n"
908
+ + parent_doc
909
+ )
@@ -1,9 +1,21 @@
1
1
  import inspect
2
2
  import json
3
3
  import os
4
+ import warnings
4
5
  from dataclasses import asdict, dataclass, is_dataclass
5
6
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Callable,
11
+ Dict,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ Type,
16
+ TypeVar,
17
+ Union,
18
+ )
7
19
 
8
20
  from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
9
21
  from .file_download import hf_hub_download
@@ -15,8 +27,10 @@ from .utils import (
15
27
  SoftTemporaryDirectory,
16
28
  is_jsonable,
17
29
  is_safetensors_available,
30
+ is_simple_optional_type,
18
31
  is_torch_available,
19
32
  logging,
33
+ unwrap_simple_optional_type,
20
34
  validate_hf_hub_args,
21
35
  )
22
36
 
@@ -85,8 +99,8 @@ class ModelHubMixin:
85
99
  URL of the library documentation. Used to generate model card.
86
100
  model_card_template (`str`, *optional*):
87
101
  Template of the model card. Used to generate model card. Defaults to a generic template.
88
- languages (`List[str]`, *optional*):
89
- Languages supported by the library. Used to generate model card.
102
+ language (`str` or `List[str]`, *optional*):
103
+ Language supported by the library. Used to generate model card.
90
104
  library_name (`str`, *optional*):
91
105
  Name of the library integrating ModelHubMixin. Used to generate model card.
92
106
  license (`str`, *optional*):
@@ -191,7 +205,7 @@ class ModelHubMixin:
191
205
  # Model card template
192
206
  model_card_template: str = DEFAULT_MODEL_CARD,
193
207
  # Model card metadata
194
- languages: Optional[List[str]] = None,
208
+ language: Optional[List[str]] = None,
195
209
  library_name: Optional[str] = None,
196
210
  license: Optional[str] = None,
197
211
  license_name: Optional[str] = None,
@@ -205,6 +219,8 @@ class ModelHubMixin:
205
219
  # Value is a tuple (encoder, decoder).
206
220
  # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
207
221
  ] = None,
222
+ # Deprecated arguments
223
+ languages: Optional[List[str]] = None,
208
224
  ) -> None:
209
225
  """Inspect __init__ signature only once when subclassing + handle modelcard."""
210
226
  super().__init_subclass__()
@@ -212,20 +228,57 @@ class ModelHubMixin:
212
228
  # Will be reused when creating modelcard
213
229
  tags = tags or []
214
230
  tags.append("model_hub_mixin")
215
- cls._hub_mixin_info = MixinInfo(
216
- model_card_template=model_card_template,
217
- repo_url=repo_url,
218
- docs_url=docs_url,
219
- model_card_data=ModelCardData(
220
- languages=languages,
221
- library_name=library_name,
222
- license=license,
223
- license_name=license_name,
224
- license_link=license_link,
225
- pipeline_tag=pipeline_tag,
226
- tags=tags,
227
- ),
228
- )
231
+
232
+ # Initialize MixinInfo if not existent
233
+ info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())
234
+
235
+ # If parent class has a MixinInfo, inherit from it as a copy
236
+ if hasattr(cls, "_hub_mixin_info"):
237
+ # Inherit model card template from parent class if not explicitly set
238
+ if model_card_template == DEFAULT_MODEL_CARD:
239
+ info.model_card_template = cls._hub_mixin_info.model_card_template
240
+
241
+ # Inherit from parent model card data
242
+ info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())
243
+
244
+ # Inherit other info
245
+ info.docs_url = cls._hub_mixin_info.docs_url
246
+ info.repo_url = cls._hub_mixin_info.repo_url
247
+ cls._hub_mixin_info = info
248
+
249
+ if languages is not None:
250
+ warnings.warn(
251
+ "The `languages` argument is deprecated. Use `language` instead. This will be removed in `huggingface_hub>=0.27.0`.",
252
+ DeprecationWarning,
253
+ )
254
+ language = languages
255
+
256
+ # Update MixinInfo with metadata
257
+ if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
258
+ info.model_card_template = model_card_template
259
+ if repo_url is not None:
260
+ info.repo_url = repo_url
261
+ if docs_url is not None:
262
+ info.docs_url = docs_url
263
+ if language is not None:
264
+ info.model_card_data.language = language
265
+ if library_name is not None:
266
+ info.model_card_data.library_name = library_name
267
+ if license is not None:
268
+ info.model_card_data.license = license
269
+ if license_name is not None:
270
+ info.model_card_data.license_name = license_name
271
+ if license_link is not None:
272
+ info.model_card_data.license_link = license_link
273
+ if pipeline_tag is not None:
274
+ info.model_card_data.pipeline_tag = pipeline_tag
275
+ if tags is not None:
276
+ if info.model_card_data.tags is not None:
277
+ info.model_card_data.tags.extend(tags)
278
+ else:
279
+ info.model_card_data.tags = tags
280
+
281
+ info.model_card_data.tags = sorted(set(info.model_card_data.tags))
229
282
 
230
283
  # Handle encoders/decoders for args
231
284
  cls._hub_mixin_coders = coders or {}
@@ -283,12 +336,11 @@ class ModelHubMixin:
283
336
  if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder
284
337
  },
285
338
  }
286
- init_config.pop("config", {})
339
+ passed_config = init_config.pop("config", {})
287
340
 
288
341
  # Populate `init_config` with provided config
289
- provided_config = passed_values.get("config")
290
- if isinstance(provided_config, dict):
291
- init_config.update(provided_config)
342
+ if isinstance(passed_config, dict):
343
+ init_config.update(passed_config)
292
344
 
293
345
  # Set `config` attribute and return
294
346
  if init_config != {}:
@@ -307,15 +359,26 @@ class ModelHubMixin:
307
359
  """Encode an argument into a JSON serializable format."""
308
360
  for type_, (encoder, _) in cls._hub_mixin_coders.items():
309
361
  if isinstance(arg, type_):
362
+ if arg is None:
363
+ return None
310
364
  return encoder(arg)
311
365
  return arg
312
366
 
313
367
  @classmethod
314
- def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> ARGS_T:
368
+ def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
315
369
  """Decode a JSON serializable value into an argument."""
370
+ if is_simple_optional_type(expected_type):
371
+ if value is None:
372
+ return None
373
+ expected_type = unwrap_simple_optional_type(expected_type)
374
+ # Dataclass => handle it
375
+ if is_dataclass(expected_type):
376
+ return _load_dataclass(expected_type, value) # type: ignore[return-value]
377
+ # Otherwise => check custom decoders
316
378
  for type_, (_, decoder) in cls._hub_mixin_coders.items():
317
- if issubclass(expected_type, type_):
379
+ if inspect.isclass(expected_type) and issubclass(expected_type, type_):
318
380
  return decoder(value)
381
+ # Otherwise => don't decode
319
382
  return value
320
383
 
321
384
  def save_pretrained(
@@ -325,6 +388,7 @@ class ModelHubMixin:
325
388
  config: Optional[Union[dict, "DataclassInstance"]] = None,
326
389
  repo_id: Optional[str] = None,
327
390
  push_to_hub: bool = False,
391
+ model_card_kwargs: Optional[Dict[str, Any]] = None,
328
392
  **push_to_hub_kwargs,
329
393
  ) -> Optional[str]:
330
394
  """
@@ -340,7 +404,9 @@ class ModelHubMixin:
340
404
  repo_id (`str`, *optional*):
341
405
  ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
342
406
  not provided.
343
- kwargs:
407
+ model_card_kwargs (`Dict[str, Any]`, *optional*):
408
+ Additional arguments passed to the model card template to customize the model card.
409
+ push_to_hub_kwargs:
344
410
  Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
345
411
  Returns:
346
412
  `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
@@ -369,8 +435,9 @@ class ModelHubMixin:
369
435
 
370
436
  # save model card
371
437
  model_card_path = save_directory / "README.md"
438
+ model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
372
439
  if not model_card_path.exists(): # do not overwrite if already exists
373
- self.generate_model_card().save(save_directory / "README.md")
440
+ self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")
374
441
 
375
442
  # push to the Hub if required
376
443
  if push_to_hub:
@@ -379,7 +446,7 @@ class ModelHubMixin:
379
446
  kwargs["config"] = config
380
447
  if repo_id is None:
381
448
  repo_id = save_directory.name # Defaults to `save_directory` name
382
- return self.push_to_hub(repo_id=repo_id, **kwargs)
449
+ return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
383
450
  return None
384
451
 
385
452
  def _save_pretrained(self, save_directory: Path) -> None:
@@ -477,19 +544,10 @@ class ModelHubMixin:
477
544
  model_kwargs[param.name] = config[param.name]
478
545
 
479
546
  # Check if `config` argument was passed at init
480
- if "config" in cls._hub_mixin_init_parameters:
481
- # Check if `config` argument is a dataclass
547
+ if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
548
+ # Decode `config` argument if it was passed
482
549
  config_annotation = cls._hub_mixin_init_parameters["config"].annotation
483
- if config_annotation is inspect.Parameter.empty:
484
- pass # no annotation
485
- elif is_dataclass(config_annotation):
486
- config = _load_dataclass(config_annotation, config)
487
- else:
488
- # if Optional/Union annotation => check if a dataclass is in the Union
489
- for _sub_annotation in get_args(config_annotation):
490
- if is_dataclass(_sub_annotation):
491
- config = _load_dataclass(_sub_annotation, config)
492
- break
550
+ config = cls._decode_arg(config_annotation, config)
493
551
 
494
552
  # Forward config to model initialization
495
553
  model_kwargs["config"] = config
@@ -505,7 +563,7 @@ class ModelHubMixin:
505
563
  model_kwargs[key] = value
506
564
 
507
565
  # Finally, also inject if `_from_pretrained` expects it
508
- if cls._hub_mixin_inject_config:
566
+ if cls._hub_mixin_inject_config and "config" not in model_kwargs:
509
567
  model_kwargs["config"] = config
510
568
 
511
569
  instance = cls._from_pretrained(
@@ -588,6 +646,7 @@ class ModelHubMixin:
588
646
  allow_patterns: Optional[Union[List[str], str]] = None,
589
647
  ignore_patterns: Optional[Union[List[str], str]] = None,
590
648
  delete_patterns: Optional[Union[List[str], str]] = None,
649
+ model_card_kwargs: Optional[Dict[str, Any]] = None,
591
650
  ) -> str:
592
651
  """
593
652
  Upload model checkpoint to the Hub.
@@ -618,6 +677,8 @@ class ModelHubMixin:
618
677
  If provided, files matching any of the patterns are not pushed.
619
678
  delete_patterns (`List[str]` or `str`, *optional*):
620
679
  If provided, remote files matching any of the patterns will be deleted from the repo.
680
+ model_card_kwargs (`Dict[str, Any]`, *optional*):
681
+ Additional arguments passed to the model card template to customize the model card.
621
682
 
622
683
  Returns:
623
684
  The url of the commit of your model in the given repository.
@@ -628,7 +689,7 @@ class ModelHubMixin:
628
689
  # Push the files to the repo in a single commit
629
690
  with SoftTemporaryDirectory() as tmp:
630
691
  saved_path = Path(tmp) / repo_id
631
- self.save_pretrained(saved_path, config=config)
692
+ self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
632
693
  return api.upload_folder(
633
694
  repo_id=repo_id,
634
695
  repo_type="model",
@@ -647,6 +708,7 @@ class ModelHubMixin:
647
708
  template_str=self._hub_mixin_info.model_card_template,
648
709
  repo_url=self._hub_mixin_info.repo_url,
649
710
  docs_url=self._hub_mixin_info.docs_url,
711
+ **kwargs,
650
712
  )
651
713
  return card
652
714