huggingface-hub: 0.35.1 (py3-none-any.whl) → 1.0.0rc1 (py3-none-any.whl)

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of huggingface-hub might be problematic.

Files changed (127)
  1. huggingface_hub/__init__.py +28 -45
  2. huggingface_hub/_commit_api.py +28 -28
  3. huggingface_hub/_commit_scheduler.py +11 -8
  4. huggingface_hub/_inference_endpoints.py +8 -8
  5. huggingface_hub/_jobs_api.py +20 -20
  6. huggingface_hub/_login.py +13 -39
  7. huggingface_hub/_oauth.py +8 -8
  8. huggingface_hub/_snapshot_download.py +14 -28
  9. huggingface_hub/_space_api.py +4 -4
  10. huggingface_hub/_tensorboard_logger.py +5 -5
  11. huggingface_hub/_upload_large_folder.py +15 -15
  12. huggingface_hub/_webhooks_payload.py +3 -3
  13. huggingface_hub/_webhooks_server.py +2 -2
  14. huggingface_hub/cli/__init__.py +0 -14
  15. huggingface_hub/cli/_cli_utils.py +80 -3
  16. huggingface_hub/cli/auth.py +104 -150
  17. huggingface_hub/cli/cache.py +102 -126
  18. huggingface_hub/cli/download.py +93 -110
  19. huggingface_hub/cli/hf.py +37 -41
  20. huggingface_hub/cli/jobs.py +689 -1017
  21. huggingface_hub/cli/lfs.py +120 -143
  22. huggingface_hub/cli/repo.py +158 -216
  23. huggingface_hub/cli/repo_files.py +50 -84
  24. huggingface_hub/cli/system.py +6 -25
  25. huggingface_hub/cli/upload.py +198 -212
  26. huggingface_hub/cli/upload_large_folder.py +90 -105
  27. huggingface_hub/commands/_cli_utils.py +2 -2
  28. huggingface_hub/commands/delete_cache.py +11 -11
  29. huggingface_hub/commands/download.py +4 -13
  30. huggingface_hub/commands/lfs.py +4 -4
  31. huggingface_hub/commands/repo_files.py +2 -2
  32. huggingface_hub/commands/tag.py +1 -3
  33. huggingface_hub/commands/upload.py +4 -4
  34. huggingface_hub/commands/upload_large_folder.py +3 -3
  35. huggingface_hub/commands/user.py +4 -5
  36. huggingface_hub/community.py +5 -5
  37. huggingface_hub/constants.py +3 -41
  38. huggingface_hub/dataclasses.py +16 -22
  39. huggingface_hub/errors.py +43 -30
  40. huggingface_hub/fastai_utils.py +8 -9
  41. huggingface_hub/file_download.py +154 -253
  42. huggingface_hub/hf_api.py +329 -558
  43. huggingface_hub/hf_file_system.py +104 -62
  44. huggingface_hub/hub_mixin.py +32 -54
  45. huggingface_hub/inference/_client.py +178 -163
  46. huggingface_hub/inference/_common.py +38 -54
  47. huggingface_hub/inference/_generated/_async_client.py +219 -259
  48. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
  49. huggingface_hub/inference/_generated/types/base.py +10 -7
  50. huggingface_hub/inference/_generated/types/chat_completion.py +16 -16
  51. huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
  52. huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
  53. huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
  54. huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
  55. huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
  56. huggingface_hub/inference/_generated/types/summarization.py +2 -2
  57. huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
  58. huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
  59. huggingface_hub/inference/_generated/types/text_generation.py +10 -10
  60. huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
  61. huggingface_hub/inference/_generated/types/token_classification.py +2 -2
  62. huggingface_hub/inference/_generated/types/translation.py +2 -2
  63. huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
  64. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
  65. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
  66. huggingface_hub/inference/_mcp/agent.py +3 -3
  67. huggingface_hub/inference/_mcp/constants.py +1 -2
  68. huggingface_hub/inference/_mcp/mcp_client.py +33 -22
  69. huggingface_hub/inference/_mcp/types.py +10 -10
  70. huggingface_hub/inference/_mcp/utils.py +4 -4
  71. huggingface_hub/inference/_providers/__init__.py +2 -13
  72. huggingface_hub/inference/_providers/_common.py +24 -25
  73. huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
  74. huggingface_hub/inference/_providers/cohere.py +3 -3
  75. huggingface_hub/inference/_providers/fal_ai.py +25 -25
  76. huggingface_hub/inference/_providers/featherless_ai.py +4 -4
  77. huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
  78. huggingface_hub/inference/_providers/hf_inference.py +13 -13
  79. huggingface_hub/inference/_providers/hyperbolic.py +4 -4
  80. huggingface_hub/inference/_providers/nebius.py +10 -10
  81. huggingface_hub/inference/_providers/novita.py +5 -5
  82. huggingface_hub/inference/_providers/nscale.py +4 -4
  83. huggingface_hub/inference/_providers/replicate.py +15 -15
  84. huggingface_hub/inference/_providers/sambanova.py +6 -6
  85. huggingface_hub/inference/_providers/together.py +7 -7
  86. huggingface_hub/lfs.py +24 -33
  87. huggingface_hub/repocard.py +16 -17
  88. huggingface_hub/repocard_data.py +56 -56
  89. huggingface_hub/serialization/__init__.py +0 -1
  90. huggingface_hub/serialization/_base.py +9 -9
  91. huggingface_hub/serialization/_dduf.py +7 -7
  92. huggingface_hub/serialization/_torch.py +28 -28
  93. huggingface_hub/utils/__init__.py +10 -4
  94. huggingface_hub/utils/_auth.py +5 -5
  95. huggingface_hub/utils/_cache_manager.py +31 -31
  96. huggingface_hub/utils/_deprecation.py +1 -1
  97. huggingface_hub/utils/_dotenv.py +3 -3
  98. huggingface_hub/utils/_fixes.py +0 -10
  99. huggingface_hub/utils/_git_credential.py +3 -3
  100. huggingface_hub/utils/_headers.py +7 -29
  101. huggingface_hub/utils/_http.py +369 -209
  102. huggingface_hub/utils/_pagination.py +4 -4
  103. huggingface_hub/utils/_paths.py +5 -5
  104. huggingface_hub/utils/_runtime.py +15 -13
  105. huggingface_hub/utils/_safetensors.py +21 -21
  106. huggingface_hub/utils/_subprocess.py +9 -9
  107. huggingface_hub/utils/_telemetry.py +3 -3
  108. huggingface_hub/utils/_typing.py +3 -3
  109. huggingface_hub/utils/_validators.py +53 -72
  110. huggingface_hub/utils/_xet.py +16 -16
  111. huggingface_hub/utils/_xet_progress_reporting.py +1 -1
  112. huggingface_hub/utils/insecure_hashlib.py +3 -9
  113. huggingface_hub/utils/tqdm.py +3 -3
  114. {huggingface_hub-0.35.1.dist-info → huggingface_hub-1.0.0rc1.dist-info}/METADATA +17 -26
  115. huggingface_hub-1.0.0rc1.dist-info/RECORD +161 -0
  116. huggingface_hub/inference/_providers/publicai.py +0 -6
  117. huggingface_hub/inference/_providers/scaleway.py +0 -28
  118. huggingface_hub/inference_api.py +0 -217
  119. huggingface_hub/keras_mixin.py +0 -500
  120. huggingface_hub/repository.py +0 -1477
  121. huggingface_hub/serialization/_tensorflow.py +0 -95
  122. huggingface_hub/utils/_hf_folder.py +0 -68
  123. huggingface_hub-0.35.1.dist-info/RECORD +0 -168
  124. {huggingface_hub-0.35.1.dist-info → huggingface_hub-1.0.0rc1.dist-info}/LICENSE +0 -0
  125. {huggingface_hub-0.35.1.dist-info → huggingface_hub-1.0.0rc1.dist-info}/WHEEL +0 -0
  126. {huggingface_hub-0.35.1.dist-info → huggingface_hub-1.0.0rc1.dist-info}/entry_points.txt +0 -0
  127. {huggingface_hub-0.35.1.dist-info → huggingface_hub-1.0.0rc1.dist-info}/top_level.txt +0 -0

huggingface_hub/hf_file_system.py
@@ -2,24 +2,25 @@ import os
  import re
  import tempfile
  from collections import deque
+ from contextlib import ExitStack
  from dataclasses import dataclass, field
  from datetime import datetime
  from itertools import chain
  from pathlib import Path
- from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union
+ from typing import Any, Iterator, NoReturn, Optional, Union
  from urllib.parse import quote, unquote

  import fsspec
+ import httpx
  from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
  from fsspec.utils import isfilelike
- from requests import Response

  from . import constants
  from ._commit_api import CommitOperationCopy, CommitOperationDelete
- from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+ from .errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
  from .file_download import hf_hub_url, http_get
  from .hf_api import HfApi, LastCommitInfo, RepoFile
- from .utils import HFValidationError, hf_raise_for_status, http_backoff
+ from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff


  # Regex used to match special revisions with "/" in them (see #1710)
@@ -113,13 +114,13 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  # Maps (repo_type, repo_id, revision) to a 2-tuple with:
  # * the 1st element indicating whether the repositoy and the revision exist
  # * the 2nd element being the exception raised if the repository or revision doesn't exist
- self._repo_and_revision_exists_cache: Dict[
- Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
+ self._repo_and_revision_exists_cache: dict[
+ tuple[str, str, Optional[str]], tuple[bool, Optional[Exception]]
  ] = {}

  def _repo_and_revision_exist(
  self, repo_type: str, repo_id: str, revision: Optional[str]
- ) -> Tuple[bool, Optional[Exception]]:
+ ) -> tuple[bool, Optional[Exception]]:
  if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
  try:
  self._api.repo_info(
@@ -338,7 +339,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):

  def ls(
  self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
- ) -> List[Union[str, Dict[str, Any]]]:
+ ) -> list[Union[str, dict[str, Any]]]:
  """
  List the contents of a directory.

@@ -362,7 +363,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  The git revision to list from.

  Returns:
- `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
+ `list[Union[str, dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
  dictionaries (if detail=True).
  """
  resolved_path = self.resolve_path(path, revision=revision)
@@ -483,7 +484,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  out.append(cache_path_info)
  return out

- def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]:
+ def walk(self, path: str, *args, **kwargs) -> Iterator[tuple[str, list[str], list[str]]]:
  """
  Return all files below the given path.

@@ -494,12 +495,12 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  Root path to list files from.

  Returns:
- `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
+ `Iterator[tuple[str, list[str], list[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
  """
  path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
  yield from super().walk(path, *args, **kwargs)

- def glob(self, path: str, **kwargs) -> List[str]:
+ def glob(self, path: str, **kwargs) -> list[str]:
  """
  Find files by glob-matching.

@@ -510,7 +511,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  Path pattern to match.

  Returns:
- `List[str]`: List of paths matching the pattern.
+ `list[str]`: List of paths matching the pattern.
  """
  path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
  return super().glob(path, **kwargs)
@@ -524,7 +525,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  refresh: bool = False,
  revision: Optional[str] = None,
  **kwargs,
- ) -> Union[List[str], Dict[str, Dict[str, Any]]]:
+ ) -> Union[list[str], dict[str, dict[str, Any]]]:
  """
  List all files below path.

@@ -545,7 +546,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  The git revision to list from.

  Returns:
- `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information.
+ `Union[list[str], dict[str, dict[str, Any]]]`: List of paths or dict of file information.
  """
  if maxdepth:
  return super().find(
@@ -650,7 +651,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  info = self.info(path, **{**kwargs, "expand_info": True})
  return info["last_commit"]["date"]

- def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]:
+ def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> dict[str, Any]:
  """
  Get information about a file or directory.

@@ -671,7 +672,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
  The git revision to get info from.

  Returns:
- `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).
+ `dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).

  """
  resolved_path = self.resolve_path(path, revision=revision)
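Most of the hunks in this file (and in `hub_mixin.py` below) follow the same pattern: annotations using `typing.Dict`, `typing.List`, and `typing.Tuple` are replaced by the builtin generics from PEP 585, which are subscriptable at runtime on Python 3.9+. A minimal before/after sketch (the signature is paraphrased from `HfFileSystem.ls` above, not copied verbatim):

```python
from typing import Any, Optional, Union

# 0.35.x style (removed in these hunks): generics imported from `typing`
# def ls(path: str, detail: bool = True) -> List[Union[str, Dict[str, Any]]]: ...

# 1.0 style: builtin generics per PEP 585; Optional/Union still come from `typing`
def ls(path: str, detail: bool = True, revision: Optional[str] = None) -> list[Union[str, dict[str, Any]]]:
    """Toy stand-in mirroring the annotation change; not the real HfFileSystem.ls."""
    return []
```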
@@ -958,7 +959,13 @@ class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
  repo_type=self.resolved_path.repo_type,
  endpoint=self.fs.endpoint,
  )
- r = http_backoff("GET", url, headers=headers, timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT)
+ r = http_backoff(
+ "GET",
+ url,
+ headers=headers,
+ retry_on_status_codes=(500, 502, 503, 504),
+ timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
+ )
  hf_raise_for_status(r)
  return r.content

@@ -1033,8 +1040,9 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
  super().__init__(
  fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
  )
- self.response: Optional[Response] = None
+ self.response: Optional[httpx.Response] = None
  self.fs: HfFileSystem
+ self._exit_stack = ExitStack()

  def seek(self, loc: int, whence: int = 0):
  if loc == 0 and whence == 1:
@@ -1044,53 +1052,32 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
  raise ValueError("Cannot seek streaming HF file")

  def read(self, length: int = -1):
- read_args = (length,) if length >= 0 else ()
+ """Read the remote file.
+
+ If the file is already open, we reuse the connection.
+ Otherwise, open a new connection and read from it.
+
+ If reading the stream fails, we retry with a new connection.
+ """
  if self.response is None:
- url = hf_hub_url(
- repo_id=self.resolved_path.repo_id,
- revision=self.resolved_path.revision,
- filename=self.resolved_path.path_in_repo,
- repo_type=self.resolved_path.repo_type,
- endpoint=self.fs.endpoint,
- )
- self.response = http_backoff(
- "GET",
- url,
- headers=self.fs._api._build_hf_headers(),
- stream=True,
- timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
- )
- hf_raise_for_status(self.response)
- try:
- self.response.raw.decode_content = True
- out = self.response.raw.read(*read_args)
- except Exception:
- self.response.close()
+ self._open_connection()

- # Retry by recreating the connection
- url = hf_hub_url(
- repo_id=self.resolved_path.repo_id,
- revision=self.resolved_path.revision,
- filename=self.resolved_path.path_in_repo,
- repo_type=self.resolved_path.repo_type,
- endpoint=self.fs.endpoint,
- )
- self.response = http_backoff(
- "GET",
- url,
- headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
- stream=True,
- timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
- )
- hf_raise_for_status(self.response)
+ retried_once = False
+ while True:
  try:
- self.response.raw.decode_content = True
- out = self.response.raw.read(*read_args)
+ if self.response is None:
+ return b"" # Already read the entire file
+ out = _partial_read(self.response, length)
+ self.loc += len(out)
+ return out
  except Exception:
- self.response.close()
- raise
- self.loc += len(out)
- return out
+ if self.response is not None:
+ self.response.close()
+ if retried_once: # Already retried once, give up
+ raise
+ # First failure, retry with range header
+ self._open_connection()
+ retried_once = True

  def url(self) -> str:
  return self.fs.url(self.path)
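The rewritten `read()` above boils down to a retry-once loop around a streamed response: reuse the open connection if there is one, and on any failure close it, reopen from the current offset via a `Range` header (`_open_connection`), and try again before giving up. A standalone sketch of that pattern with plain `httpx`, not the library's own `http_stream_backoff` helper:

```python
import httpx

def read_with_single_retry(url: str, offset: int = 0, length: int = -1) -> bytes:
    """Stream up to `length` bytes from `url` starting at `offset`, retrying once on failure."""
    retried_once = False
    while True:
        headers = {"Range": f"bytes={offset}-"} if offset > 0 else {}
        try:
            with httpx.stream("GET", url, headers=headers) as response:
                response.raise_for_status()
                buf = bytearray()
                for chunk in response.iter_bytes():
                    buf.extend(chunk)
                    if 0 <= length <= len(buf):
                        return bytes(buf[:length])
                return bytes(buf)  # reached EOF before `length` bytes
        except Exception:
            if retried_once:
                raise  # second consecutive failure: give up, as in the diff
            retried_once = True  # first failure: loop and reopen the connection
```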
@@ -1099,11 +1086,43 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
  if not hasattr(self, "resolved_path"):
  # Means that the constructor failed. Nothing to do.
  return
+ self._exit_stack.close()
  return super().__del__()

  def __reduce__(self):
  return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)

+ def _open_connection(self):
+ """Open a connection to the remote file."""
+ url = hf_hub_url(
+ repo_id=self.resolved_path.repo_id,
+ revision=self.resolved_path.revision,
+ filename=self.resolved_path.path_in_repo,
+ repo_type=self.resolved_path.repo_type,
+ endpoint=self.fs.endpoint,
+ )
+ headers = self.fs._api._build_hf_headers()
+ if self.loc > 0:
+ headers["Range"] = f"bytes={self.loc}-"
+ self.response = self._exit_stack.enter_context(
+ http_stream_backoff(
+ "GET",
+ url,
+ headers=headers,
+ retry_on_status_codes=(500, 502, 503, 504),
+ timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
+ )
+ )
+
+ try:
+ hf_raise_for_status(self.response)
+ except HfHubHTTPError as e:
+ if e.response.status_code == 416:
+ # Range not satisfiable => means that we have already read the entire file
+ self.response = None
+ return
+ raise
+

  def safe_revision(revision: str) -> str:
  return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
@@ -1126,3 +1145,26 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:

  def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
  return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
+
+
+ def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
+ """
+ Read up to `length` bytes from a streamed response.
+ If length == -1, read until EOF.
+ """
+ buf = bytearray()
+ if length < -1:
+ raise ValueError("length must be -1 or >= 0")
+ if length == 0:
+ return b""
+ if length == -1:
+ for chunk in response.iter_bytes():
+ buf.extend(chunk)
+ return bytes(buf)
+
+ for chunk in response.iter_bytes(chunk_size=length):
+ buf.extend(chunk)
+ if len(buf) >= length:
+ return bytes(buf[:length])
+
+ return bytes(buf) # may be < length if response ended
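`_partial_read` is the helper the streaming file class now uses to pull bytes off the open `httpx` response. At the user level the streaming path is reached through `HfFileSystem`; a small usage sketch, assuming (as in current releases) that opening in `"rb"` mode with `block_size=0` returns the streaming file object, and with a hypothetical repo path:

```python
from huggingface_hub import HfFileSystem

fs = HfFileSystem()  # optionally HfFileSystem(token=...)
# block_size=0 requests the non-seekable streaming file instead of the buffered one.
with fs.open("datasets/username/my-dataset/data.csv", "rb", block_size=0) as f:
    header = f.read(1024)  # first KiB, served from the open stream
    rest = f.read()        # keeps reading on the same connection, retrying once if it drops
```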

huggingface_hub/hub_mixin.py
@@ -3,7 +3,7 @@ import json
  import os
  from dataclasses import Field, asdict, dataclass, is_dataclass
  from pathlib import Path
- from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
+ from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union

  import packaging.version

@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)

  # Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
  class DataclassInstance(Protocol):
- __dataclass_fields__: ClassVar[Dict[str, Field]]
+ __dataclass_fields__: ClassVar[dict[str, Field]]


  # Generic variable that is either ModelHubMixin or a subclass thereof
@@ -47,7 +47,7 @@ T = TypeVar("T", bound="ModelHubMixin")
  ARGS_T = TypeVar("ARGS_T")
  ENCODER_T = Callable[[ARGS_T], Any]
  DECODER_T = Callable[[Any], ARGS_T]
- CODER_T = Tuple[ENCODER_T, DECODER_T]
+ CODER_T = tuple[ENCODER_T, DECODER_T]


  DEFAULT_MODEL_CARD = """
@@ -96,7 +96,7 @@ class ModelHubMixin:
  URL of the library documentation. Used to generate model card.
  model_card_template (`str`, *optional*):
  Template of the model card. Used to generate model card. Defaults to a generic template.
- language (`str` or `List[str]`, *optional*):
+ language (`str` or `list[str]`, *optional*):
  Language supported by the library. Used to generate model card.
  library_name (`str`, *optional*):
  Name of the library integrating ModelHubMixin. Used to generate model card.
@@ -113,9 +113,9 @@ class ModelHubMixin:
  E.g: "https://coqui.ai/cpml".
  pipeline_tag (`str`, *optional*):
  Tag of the pipeline. Used to generate model card. E.g. "text-classification".
- tags (`List[str]`, *optional*):
+ tags (`list[str]`, *optional*):
  Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
- coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
+ coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
  Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
  jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.

@@ -145,12 +145,10 @@ class ModelHubMixin:
  ...
  ... @classmethod
  ... def from_pretrained(
- ... cls: Type[T],
+ ... cls: type[T],
  ... pretrained_model_name_or_path: Union[str, Path],
  ... *,
  ... force_download: bool = False,
- ... resume_download: Optional[bool] = None,
- ... proxies: Optional[Dict] = None,
  ... token: Optional[Union[str, bool]] = None,
  ... cache_dir: Optional[Union[str, Path]] = None,
  ... local_files_only: bool = False,
@@ -188,10 +186,10 @@ class ModelHubMixin:
  _hub_mixin_info: MixinInfo
  # ^ information about the library integrating ModelHubMixin (used to generate model card)
  _hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not
- _hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters
- _hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters
- _hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded
- _hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types
+ _hub_mixin_init_parameters: dict[str, inspect.Parameter] # __init__ parameters
+ _hub_mixin_jsonable_default_values: dict[str, Any] # default values for __init__ parameters
+ _hub_mixin_jsonable_custom_types: tuple[Type, ...] # custom types that can be encoded/decoded
+ _hub_mixin_coders: dict[Type, CODER_T] # encoders/decoders for custom types
  # ^ internal values to handle config

  def __init_subclass__(
@@ -204,16 +202,16 @@ class ModelHubMixin:
  # Model card template
  model_card_template: str = DEFAULT_MODEL_CARD,
  # Model card metadata
- language: Optional[List[str]] = None,
+ language: Optional[list[str]] = None,
  library_name: Optional[str] = None,
  license: Optional[str] = None,
  license_name: Optional[str] = None,
  license_link: Optional[str] = None,
  pipeline_tag: Optional[str] = None,
- tags: Optional[List[str]] = None,
+ tags: Optional[list[str]] = None,
  # How to encode/decode arguments with custom type into a JSON config?
  coders: Optional[
- Dict[Type, CODER_T]
+ dict[Type, CODER_T]
  # Key is a type.
  # Value is a tuple (encoder, decoder).
  # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
@@ -288,7 +286,7 @@ class ModelHubMixin:
  }
  cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters

- def __new__(cls: Type[T], *args, **kwargs) -> T:
+ def __new__(cls: type[T], *args, **kwargs) -> T:
  """Create a new instance of the class and handle config.

  3 cases:
@@ -364,7 +362,7 @@ class ModelHubMixin:
  return arg

  @classmethod
- def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
+ def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]:
  """Decode a JSON serializable value into an argument."""
  if is_simple_optional_type(expected_type):
  if value is None:
@@ -387,7 +385,7 @@ class ModelHubMixin:
  config: Optional[Union[dict, DataclassInstance]] = None,
  repo_id: Optional[str] = None,
  push_to_hub: bool = False,
- model_card_kwargs: Optional[Dict[str, Any]] = None,
+ model_card_kwargs: Optional[dict[str, Any]] = None,
  **push_to_hub_kwargs,
  ) -> Optional[str]:
  """
@@ -403,7 +401,7 @@ class ModelHubMixin:
  repo_id (`str`, *optional*):
  ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
  not provided.
- model_card_kwargs (`Dict[str, Any]`, *optional*):
+ model_card_kwargs (`dict[str, Any]`, *optional*):
  Additional arguments passed to the model card template to customize the model card.
  push_to_hub_kwargs:
  Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
@@ -462,12 +460,10 @@ class ModelHubMixin:
  @classmethod
  @validate_hf_hub_args
  def from_pretrained(
- cls: Type[T],
+ cls: type[T],
  pretrained_model_name_or_path: Union[str, Path],
  *,
  force_download: bool = False,
- resume_download: Optional[bool] = None,
- proxies: Optional[Dict] = None,
  token: Optional[Union[str, bool]] = None,
  cache_dir: Optional[Union[str, Path]] = None,
  local_files_only: bool = False,
@@ -488,9 +484,6 @@ class ModelHubMixin:
  force_download (`bool`, *optional*, defaults to `False`):
  Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
  the existing cache.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
  token (`str` or `bool`, *optional*):
  The token to use as HTTP bearer authorization for remote files. By default, it will use the token
  cached when running `hf auth login`.
@@ -498,7 +491,7 @@ class ModelHubMixin:
  Path to the folder where cached files are stored.
  local_files_only (`bool`, *optional*, defaults to `False`):
  If `True`, avoid downloading the file and return the path to the local cached file if it exists.
- model_kwargs (`Dict`, *optional*):
+ model_kwargs (`dict`, *optional*):
  Additional kwargs to pass to the model during initialization.
  """
  model_id = str(pretrained_model_name_or_path)
@@ -516,8 +509,6 @@ class ModelHubMixin:
  revision=revision,
  cache_dir=cache_dir,
  force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
  token=token,
  local_files_only=local_files_only,
  )
@@ -557,7 +548,7 @@ class ModelHubMixin:
  if key not in model_kwargs and key in config:
  model_kwargs[key] = config[key]
  elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
- for key, value in config.items():
+ for key, value in config.items(): # type: ignore[union-attr]
  if key not in model_kwargs:
  model_kwargs[key] = value

@@ -570,8 +561,6 @@ class ModelHubMixin:
  revision=revision,
  cache_dir=cache_dir,
  force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
  local_files_only=local_files_only,
  token=token,
  **model_kwargs,
@@ -586,14 +575,12 @@ class ModelHubMixin:

  @classmethod
  def _from_pretrained(
- cls: Type[T],
+ cls: type[T],
  *,
  model_id: str,
  revision: Optional[str],
  cache_dir: Optional[Union[str, Path]],
  force_download: bool,
- proxies: Optional[Dict],
- resume_download: Optional[bool],
  local_files_only: bool,
  token: Optional[Union[str, bool]],
  **model_kwargs,
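These signature changes matter for downstream libraries that subclass `ModelHubMixin`: a custom `_from_pretrained` no longer receives `proxies` or `resume_download`. A hedged sketch of a 1.0-compatible override; the class, file name, and weight-loading logic are illustrative only:

```python
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import ModelHubMixin, hf_hub_download

class MyModel(ModelHubMixin):
    def __init__(self, hidden_size: int = 16):
        self.hidden_size = hidden_size

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> "MyModel":
        # Fetch a (hypothetical) weights file with the same reduced option set the mixin now forwards.
        weights_path = hf_hub_download(
            repo_id=model_id,
            filename="weights.bin",  # hypothetical file name
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
        )
        model = cls(**model_kwargs)
        model.weights_path = weights_path  # illustration only; load the real weights here
        return model
```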
@@ -616,9 +603,6 @@ class ModelHubMixin:
  force_download (`bool`, *optional*, defaults to `False`):
  Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
  the existing cache.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`).
  token (`str` or `bool`, *optional*):
  The token to use as HTTP bearer authorization for remote files. By default, it will use the token
  cached when running `hf auth login`.
@@ -642,10 +626,10 @@ class ModelHubMixin:
  token: Optional[str] = None,
  branch: Optional[str] = None,
  create_pr: Optional[bool] = None,
- allow_patterns: Optional[Union[List[str], str]] = None,
- ignore_patterns: Optional[Union[List[str], str]] = None,
- delete_patterns: Optional[Union[List[str], str]] = None,
- model_card_kwargs: Optional[Dict[str, Any]] = None,
+ allow_patterns: Optional[Union[list[str], str]] = None,
+ ignore_patterns: Optional[Union[list[str], str]] = None,
+ delete_patterns: Optional[Union[list[str], str]] = None,
+ model_card_kwargs: Optional[dict[str, Any]] = None,
  ) -> str:
  """
  Upload model checkpoint to the Hub.
@@ -671,13 +655,13 @@ class ModelHubMixin:
  The git branch on which to push the model. This defaults to `"main"`.
  create_pr (`boolean`, *optional*):
  Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
- allow_patterns (`List[str]` or `str`, *optional*):
+ allow_patterns (`list[str]` or `str`, *optional*):
  If provided, only files matching at least one pattern are pushed.
- ignore_patterns (`List[str]` or `str`, *optional*):
+ ignore_patterns (`list[str]` or `str`, *optional*):
  If provided, files matching any of the patterns are not pushed.
- delete_patterns (`List[str]` or `str`, *optional*):
+ delete_patterns (`list[str]` or `str`, *optional*):
  If provided, remote files matching any of the patterns will be deleted from the repo.
- model_card_kwargs (`Dict[str, Any]`, *optional*):
+ model_card_kwargs (`dict[str, Any]`, *optional*):
  Additional arguments passed to the model card template to customize the model card.

  Returns:
@@ -760,7 +744,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
  ```
  """

- def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
+ def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None:
  tags = tags or []
  tags.append("pytorch_model_hub_mixin")
  kwargs["tags"] = tags
@@ -779,8 +763,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
  revision: Optional[str],
  cache_dir: Optional[Union[str, Path]],
  force_download: bool,
- proxies: Optional[Dict],
- resume_download: Optional[bool],
  local_files_only: bool,
  token: Union[str, bool, None],
  map_location: str = "cpu",
@@ -801,8 +783,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
  revision=revision,
  cache_dir=cache_dir,
  force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
  token=token,
  local_files_only=local_files_only,
  )
@@ -814,8 +794,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
  revision=revision,
  cache_dir=cache_dir,
  force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
  token=token,
  local_files_only=local_files_only,
  )
@@ -845,7 +823,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
  return model


- def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
+ def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
  """Load a dataclass instance from a dictionary.

  Fields not expected by the dataclass are ignored.
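
The final hunks apply the same cleanup to `PyTorchModelHubMixin`, whose two download call sites above lose the `proxies`/`resume_download` pass-through. The public usage pattern itself is unchanged; a brief sketch with a hypothetical model class and save directory:

```python
import torch
from torch import nn

from huggingface_hub import PyTorchModelHubMixin

class MyTorchModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

model = MyTorchModel(hidden_size=32)
model.save_pretrained("my-torch-model")                    # local save (directory is hypothetical)
# model.push_to_hub("username/my-torch-model")             # optional upload to the Hub
reloaded = MyTorchModel.from_pretrained("my-torch-model")  # no proxies/resume_download kwargs in 1.0
```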