huggingface-hub 0.35.0rc0 (py3-none-any wheel) → 1.0.0rc0 (py3-none-any wheel)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub has been flagged as potentially problematic; see the release's details page for more information.

Files changed (127)
  1. huggingface_hub/__init__.py +46 -45
  2. huggingface_hub/_commit_api.py +28 -28
  3. huggingface_hub/_commit_scheduler.py +11 -8
  4. huggingface_hub/_inference_endpoints.py +8 -8
  5. huggingface_hub/_jobs_api.py +176 -20
  6. huggingface_hub/_local_folder.py +1 -1
  7. huggingface_hub/_login.py +13 -39
  8. huggingface_hub/_oauth.py +10 -14
  9. huggingface_hub/_snapshot_download.py +14 -28
  10. huggingface_hub/_space_api.py +4 -4
  11. huggingface_hub/_tensorboard_logger.py +13 -14
  12. huggingface_hub/_upload_large_folder.py +120 -13
  13. huggingface_hub/_webhooks_payload.py +3 -3
  14. huggingface_hub/_webhooks_server.py +2 -2
  15. huggingface_hub/cli/_cli_utils.py +2 -2
  16. huggingface_hub/cli/auth.py +8 -6
  17. huggingface_hub/cli/cache.py +18 -20
  18. huggingface_hub/cli/download.py +4 -4
  19. huggingface_hub/cli/hf.py +2 -5
  20. huggingface_hub/cli/jobs.py +599 -22
  21. huggingface_hub/cli/lfs.py +4 -4
  22. huggingface_hub/cli/repo.py +11 -7
  23. huggingface_hub/cli/repo_files.py +2 -2
  24. huggingface_hub/cli/upload.py +4 -4
  25. huggingface_hub/cli/upload_large_folder.py +3 -3
  26. huggingface_hub/commands/_cli_utils.py +2 -2
  27. huggingface_hub/commands/delete_cache.py +13 -13
  28. huggingface_hub/commands/download.py +4 -13
  29. huggingface_hub/commands/lfs.py +4 -4
  30. huggingface_hub/commands/repo_files.py +2 -2
  31. huggingface_hub/commands/scan_cache.py +1 -1
  32. huggingface_hub/commands/tag.py +1 -3
  33. huggingface_hub/commands/upload.py +4 -4
  34. huggingface_hub/commands/upload_large_folder.py +3 -3
  35. huggingface_hub/commands/user.py +4 -5
  36. huggingface_hub/community.py +5 -5
  37. huggingface_hub/constants.py +3 -41
  38. huggingface_hub/dataclasses.py +16 -19
  39. huggingface_hub/errors.py +42 -29
  40. huggingface_hub/fastai_utils.py +8 -9
  41. huggingface_hub/file_download.py +162 -259
  42. huggingface_hub/hf_api.py +841 -616
  43. huggingface_hub/hf_file_system.py +98 -62
  44. huggingface_hub/hub_mixin.py +37 -57
  45. huggingface_hub/inference/_client.py +257 -325
  46. huggingface_hub/inference/_common.py +110 -124
  47. huggingface_hub/inference/_generated/_async_client.py +307 -432
  48. huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
  49. huggingface_hub/inference/_generated/types/base.py +10 -7
  50. huggingface_hub/inference/_generated/types/chat_completion.py +18 -16
  51. huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
  52. huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
  53. huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
  54. huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
  55. huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
  56. huggingface_hub/inference/_generated/types/summarization.py +2 -2
  57. huggingface_hub/inference/_generated/types/table_question_answering.py +4 -4
  58. huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
  59. huggingface_hub/inference/_generated/types/text_generation.py +10 -10
  60. huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
  61. huggingface_hub/inference/_generated/types/token_classification.py +2 -2
  62. huggingface_hub/inference/_generated/types/translation.py +2 -2
  63. huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
  64. huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
  65. huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
  66. huggingface_hub/inference/_mcp/_cli_hacks.py +3 -3
  67. huggingface_hub/inference/_mcp/agent.py +3 -3
  68. huggingface_hub/inference/_mcp/cli.py +1 -1
  69. huggingface_hub/inference/_mcp/constants.py +2 -3
  70. huggingface_hub/inference/_mcp/mcp_client.py +58 -30
  71. huggingface_hub/inference/_mcp/types.py +10 -7
  72. huggingface_hub/inference/_mcp/utils.py +11 -7
  73. huggingface_hub/inference/_providers/__init__.py +4 -2
  74. huggingface_hub/inference/_providers/_common.py +49 -25
  75. huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
  76. huggingface_hub/inference/_providers/cohere.py +3 -3
  77. huggingface_hub/inference/_providers/fal_ai.py +52 -21
  78. huggingface_hub/inference/_providers/featherless_ai.py +4 -4
  79. huggingface_hub/inference/_providers/fireworks_ai.py +3 -3
  80. huggingface_hub/inference/_providers/hf_inference.py +28 -20
  81. huggingface_hub/inference/_providers/hyperbolic.py +4 -4
  82. huggingface_hub/inference/_providers/nebius.py +10 -10
  83. huggingface_hub/inference/_providers/novita.py +5 -5
  84. huggingface_hub/inference/_providers/nscale.py +4 -4
  85. huggingface_hub/inference/_providers/replicate.py +15 -15
  86. huggingface_hub/inference/_providers/sambanova.py +6 -6
  87. huggingface_hub/inference/_providers/together.py +7 -7
  88. huggingface_hub/lfs.py +20 -31
  89. huggingface_hub/repocard.py +18 -18
  90. huggingface_hub/repocard_data.py +56 -56
  91. huggingface_hub/serialization/__init__.py +0 -1
  92. huggingface_hub/serialization/_base.py +9 -9
  93. huggingface_hub/serialization/_dduf.py +7 -7
  94. huggingface_hub/serialization/_torch.py +28 -28
  95. huggingface_hub/utils/__init__.py +10 -4
  96. huggingface_hub/utils/_auth.py +5 -5
  97. huggingface_hub/utils/_cache_manager.py +31 -31
  98. huggingface_hub/utils/_deprecation.py +1 -1
  99. huggingface_hub/utils/_dotenv.py +25 -21
  100. huggingface_hub/utils/_fixes.py +0 -10
  101. huggingface_hub/utils/_git_credential.py +4 -4
  102. huggingface_hub/utils/_headers.py +7 -29
  103. huggingface_hub/utils/_http.py +366 -208
  104. huggingface_hub/utils/_pagination.py +4 -4
  105. huggingface_hub/utils/_paths.py +5 -5
  106. huggingface_hub/utils/_runtime.py +16 -13
  107. huggingface_hub/utils/_safetensors.py +21 -21
  108. huggingface_hub/utils/_subprocess.py +9 -9
  109. huggingface_hub/utils/_telemetry.py +3 -3
  110. huggingface_hub/utils/_typing.py +25 -5
  111. huggingface_hub/utils/_validators.py +53 -72
  112. huggingface_hub/utils/_xet.py +16 -16
  113. huggingface_hub/utils/_xet_progress_reporting.py +32 -11
  114. huggingface_hub/utils/insecure_hashlib.py +3 -9
  115. huggingface_hub/utils/tqdm.py +3 -3
  116. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/METADATA +18 -29
  117. huggingface_hub-1.0.0rc0.dist-info/RECORD +161 -0
  118. huggingface_hub/inference_api.py +0 -217
  119. huggingface_hub/keras_mixin.py +0 -500
  120. huggingface_hub/repository.py +0 -1477
  121. huggingface_hub/serialization/_tensorflow.py +0 -95
  122. huggingface_hub/utils/_hf_folder.py +0 -68
  123. huggingface_hub-0.35.0rc0.dist-info/RECORD +0 -166
  124. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/LICENSE +0 -0
  125. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/WHEEL +0 -0
  126. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/entry_points.txt +0 -0
  127. {huggingface_hub-0.35.0rc0.dist-info → huggingface_hub-1.0.0rc0.dist-info}/top_level.txt +0 -0
@@ -2,24 +2,25 @@ import os
2
2
  import re
3
3
  import tempfile
4
4
  from collections import deque
5
+ from contextlib import ExitStack
5
6
  from dataclasses import dataclass, field
6
7
  from datetime import datetime
7
8
  from itertools import chain
8
9
  from pathlib import Path
9
- from typing import Any, Dict, Iterator, List, NoReturn, Optional, Tuple, Union
10
+ from typing import Any, Iterator, NoReturn, Optional, Union
10
11
  from urllib.parse import quote, unquote
11
12
 
12
13
  import fsspec
14
+ import httpx
13
15
  from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback
14
16
  from fsspec.utils import isfilelike
15
- from requests import Response
16
17
 
17
18
  from . import constants
18
19
  from ._commit_api import CommitOperationCopy, CommitOperationDelete
19
- from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
20
+ from .errors import EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
20
21
  from .file_download import hf_hub_url, http_get
21
22
  from .hf_api import HfApi, LastCommitInfo, RepoFile
22
- from .utils import HFValidationError, hf_raise_for_status, http_backoff
23
+ from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff
23
24
 
24
25
 
25
26
  # Regex used to match special revisions with "/" in them (see #1710)
@@ -113,13 +114,13 @@ class HfFileSystem(fsspec.AbstractFileSystem):
113
114
  # Maps (repo_type, repo_id, revision) to a 2-tuple with:
114
115
  # * the 1st element indicating whether the repositoy and the revision exist
115
116
  # * the 2nd element being the exception raised if the repository or revision doesn't exist
116
- self._repo_and_revision_exists_cache: Dict[
117
- Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
117
+ self._repo_and_revision_exists_cache: dict[
118
+ tuple[str, str, Optional[str]], tuple[bool, Optional[Exception]]
118
119
  ] = {}
119
120
 
120
121
  def _repo_and_revision_exist(
121
122
  self, repo_type: str, repo_id: str, revision: Optional[str]
122
- ) -> Tuple[bool, Optional[Exception]]:
123
+ ) -> tuple[bool, Optional[Exception]]:
123
124
  if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
124
125
  try:
125
126
  self._api.repo_info(
@@ -338,7 +339,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
338
339
 
339
340
  def ls(
340
341
  self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs
341
- ) -> List[Union[str, Dict[str, Any]]]:
342
+ ) -> list[Union[str, dict[str, Any]]]:
342
343
  """
343
344
  List the contents of a directory.
344
345
 
@@ -362,7 +363,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
362
363
  The git revision to list from.
363
364
 
364
365
  Returns:
365
- `List[Union[str, Dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
366
+ `list[Union[str, dict[str, Any]]]`: List of file paths (if detail=False) or list of file information
366
367
  dictionaries (if detail=True).
367
368
  """
368
369
  resolved_path = self.resolve_path(path, revision=revision)
@@ -483,7 +484,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
483
484
  out.append(cache_path_info)
484
485
  return out
485
486
 
486
- def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]:
487
+ def walk(self, path: str, *args, **kwargs) -> Iterator[tuple[str, list[str], list[str]]]:
487
488
  """
488
489
  Return all files below the given path.
489
490
 
@@ -494,12 +495,12 @@ class HfFileSystem(fsspec.AbstractFileSystem):
494
495
  Root path to list files from.
495
496
 
496
497
  Returns:
497
- `Iterator[Tuple[str, List[str], List[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
498
+ `Iterator[tuple[str, list[str], list[str]]]`: An iterator of (path, list of directory names, list of file names) tuples.
498
499
  """
499
500
  path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
500
501
  yield from super().walk(path, *args, **kwargs)
501
502
 
502
- def glob(self, path: str, **kwargs) -> List[str]:
503
+ def glob(self, path: str, **kwargs) -> list[str]:
503
504
  """
504
505
  Find files by glob-matching.
505
506
 
@@ -510,7 +511,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
510
511
  Path pattern to match.
511
512
 
512
513
  Returns:
513
- `List[str]`: List of paths matching the pattern.
514
+ `list[str]`: List of paths matching the pattern.
514
515
  """
515
516
  path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve()
516
517
  return super().glob(path, **kwargs)
@@ -524,7 +525,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
524
525
  refresh: bool = False,
525
526
  revision: Optional[str] = None,
526
527
  **kwargs,
527
- ) -> Union[List[str], Dict[str, Dict[str, Any]]]:
528
+ ) -> Union[list[str], dict[str, dict[str, Any]]]:
528
529
  """
529
530
  List all files below path.
530
531
 
@@ -545,7 +546,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
545
546
  The git revision to list from.
546
547
 
547
548
  Returns:
548
- `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information.
549
+ `Union[list[str], dict[str, dict[str, Any]]]`: List of paths or dict of file information.
549
550
  """
550
551
  if maxdepth:
551
552
  return super().find(
@@ -650,7 +651,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
650
651
  info = self.info(path, **{**kwargs, "expand_info": True})
651
652
  return info["last_commit"]["date"]
652
653
 
653
- def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]:
654
+ def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> dict[str, Any]:
654
655
  """
655
656
  Get information about a file or directory.
656
657
 
@@ -671,7 +672,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
671
672
  The git revision to get info from.
672
673
 
673
674
  Returns:
674
- `Dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).
675
+ `dict[str, Any]`: Dictionary containing file information (type, size, commit info, etc.).
675
676
 
676
677
  """
677
678
  resolved_path = self.resolve_path(path, revision=revision)
@@ -896,7 +897,7 @@ class HfFileSystem(fsspec.AbstractFileSystem):
896
897
  repo_type=resolve_remote_path.repo_type,
897
898
  endpoint=self.endpoint,
898
899
  ),
899
- temp_file=outfile,
900
+ temp_file=outfile, # type: ignore[arg-type]
900
901
  displayed_filename=rpath,
901
902
  expected_size=expected_size,
902
903
  resume_size=0,
@@ -1039,8 +1040,9 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
1039
1040
  super().__init__(
1040
1041
  fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs
1041
1042
  )
1042
- self.response: Optional[Response] = None
1043
+ self.response: Optional[httpx.Response] = None
1043
1044
  self.fs: HfFileSystem
1045
+ self._exit_stack = ExitStack()
1044
1046
 
1045
1047
  def seek(self, loc: int, whence: int = 0):
1046
1048
  if loc == 0 and whence == 1:
@@ -1050,53 +1052,32 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
1050
1052
  raise ValueError("Cannot seek streaming HF file")
1051
1053
 
1052
1054
  def read(self, length: int = -1):
1053
- read_args = (length,) if length >= 0 else ()
1055
+ """Read the remote file.
1056
+
1057
+ If the file is already open, we reuse the connection.
1058
+ Otherwise, open a new connection and read from it.
1059
+
1060
+ If reading the stream fails, we retry with a new connection.
1061
+ """
1054
1062
  if self.response is None:
1055
- url = hf_hub_url(
1056
- repo_id=self.resolved_path.repo_id,
1057
- revision=self.resolved_path.revision,
1058
- filename=self.resolved_path.path_in_repo,
1059
- repo_type=self.resolved_path.repo_type,
1060
- endpoint=self.fs.endpoint,
1061
- )
1062
- self.response = http_backoff(
1063
- "GET",
1064
- url,
1065
- headers=self.fs._api._build_hf_headers(),
1066
- retry_on_status_codes=(500, 502, 503, 504),
1067
- stream=True,
1068
- timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
1069
- )
1070
- hf_raise_for_status(self.response)
1071
- try:
1072
- out = self.response.raw.read(*read_args)
1073
- except Exception:
1074
- self.response.close()
1063
+ self._open_connection()
1075
1064
 
1076
- # Retry by recreating the connection
1077
- url = hf_hub_url(
1078
- repo_id=self.resolved_path.repo_id,
1079
- revision=self.resolved_path.revision,
1080
- filename=self.resolved_path.path_in_repo,
1081
- repo_type=self.resolved_path.repo_type,
1082
- endpoint=self.fs.endpoint,
1083
- )
1084
- self.response = http_backoff(
1085
- "GET",
1086
- url,
1087
- headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()},
1088
- retry_on_status_codes=(500, 502, 503, 504),
1089
- stream=True,
1090
- timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
1091
- )
1092
- hf_raise_for_status(self.response)
1065
+ retried_once = False
1066
+ while True:
1093
1067
  try:
1094
- out = self.response.raw.read(*read_args)
1068
+ if self.response is None:
1069
+ return b"" # Already read the entire file
1070
+ out = _partial_read(self.response, length)
1071
+ self.loc += len(out)
1072
+ return out
1095
1073
  except Exception:
1096
- self.response.close()
1097
- raise
1098
- self.loc += len(out)
1099
- return out
1074
+ if self.response is not None:
1075
+ self.response.close()
1076
+ if retried_once: # Already retried once, give up
1077
+ raise
1078
+ # First failure, retry with range header
1079
+ self._open_connection()
1080
+ retried_once = True
1100
1081
 
1101
1082
  def url(self) -> str:
1102
1083
  return self.fs.url(self.path)
@@ -1105,11 +1086,43 @@ class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile):
1105
1086
  if not hasattr(self, "resolved_path"):
1106
1087
  # Means that the constructor failed. Nothing to do.
1107
1088
  return
1089
+ self._exit_stack.close()
1108
1090
  return super().__del__()
1109
1091
 
1110
1092
  def __reduce__(self):
1111
1093
  return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name)
1112
1094
 
1095
+ def _open_connection(self):
1096
+ """Open a connection to the remote file."""
1097
+ url = hf_hub_url(
1098
+ repo_id=self.resolved_path.repo_id,
1099
+ revision=self.resolved_path.revision,
1100
+ filename=self.resolved_path.path_in_repo,
1101
+ repo_type=self.resolved_path.repo_type,
1102
+ endpoint=self.fs.endpoint,
1103
+ )
1104
+ headers = self.fs._api._build_hf_headers()
1105
+ if self.loc > 0:
1106
+ headers["Range"] = f"bytes={self.loc}-"
1107
+ self.response = self._exit_stack.enter_context(
1108
+ http_stream_backoff(
1109
+ "GET",
1110
+ url,
1111
+ headers=headers,
1112
+ retry_on_status_codes=(500, 502, 503, 504),
1113
+ timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT,
1114
+ )
1115
+ )
1116
+
1117
+ try:
1118
+ hf_raise_for_status(self.response)
1119
+ except HfHubHTTPError as e:
1120
+ if e.response.status_code == 416:
1121
+ # Range not satisfiable => means that we have already read the entire file
1122
+ self.response = None
1123
+ return
1124
+ raise
1125
+
1113
1126
 
1114
1127
  def safe_revision(revision: str) -> str:
1115
1128
  return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)
@@ -1132,3 +1145,26 @@ def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn:
1132
1145
 
1133
1146
  def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str):
1134
1147
  return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type)
1148
+
1149
+
1150
+ def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
1151
+ """
1152
+ Read up to `length` bytes from a streamed response.
1153
+ If length == -1, read until EOF.
1154
+ """
1155
+ buf = bytearray()
1156
+ if length < -1:
1157
+ raise ValueError("length must be -1 or >= 0")
1158
+ if length == 0:
1159
+ return b""
1160
+ if length == -1:
1161
+ for chunk in response.iter_bytes():
1162
+ buf.extend(chunk)
1163
+ return bytes(buf)
1164
+
1165
+ for chunk in response.iter_bytes(chunk_size=length):
1166
+ buf.extend(chunk)
1167
+ if len(buf) >= length:
1168
+ return bytes(buf[:length])
1169
+
1170
+ return bytes(buf) # may be < length if response ended
@@ -3,7 +3,7 @@ import json
3
3
  import os
4
4
  from dataclasses import Field, asdict, dataclass, is_dataclass
5
5
  from pathlib import Path
6
- from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
6
+ from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union
7
7
 
8
8
  import packaging.version
9
9
 
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
38
38
 
39
39
  # Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
40
40
  class DataclassInstance(Protocol):
41
- __dataclass_fields__: ClassVar[Dict[str, Field]]
41
+ __dataclass_fields__: ClassVar[dict[str, Field]]
42
42
 
43
43
 
44
44
  # Generic variable that is either ModelHubMixin or a subclass thereof
@@ -47,7 +47,7 @@ T = TypeVar("T", bound="ModelHubMixin")
47
47
  ARGS_T = TypeVar("ARGS_T")
48
48
  ENCODER_T = Callable[[ARGS_T], Any]
49
49
  DECODER_T = Callable[[Any], ARGS_T]
50
- CODER_T = Tuple[ENCODER_T, DECODER_T]
50
+ CODER_T = tuple[ENCODER_T, DECODER_T]
51
51
 
52
52
 
53
53
  DEFAULT_MODEL_CARD = """
@@ -96,7 +96,7 @@ class ModelHubMixin:
96
96
  URL of the library documentation. Used to generate model card.
97
97
  model_card_template (`str`, *optional*):
98
98
  Template of the model card. Used to generate model card. Defaults to a generic template.
99
- language (`str` or `List[str]`, *optional*):
99
+ language (`str` or `list[str]`, *optional*):
100
100
  Language supported by the library. Used to generate model card.
101
101
  library_name (`str`, *optional*):
102
102
  Name of the library integrating ModelHubMixin. Used to generate model card.
@@ -113,9 +113,9 @@ class ModelHubMixin:
113
113
  E.g: "https://coqui.ai/cpml".
114
114
  pipeline_tag (`str`, *optional*):
115
115
  Tag of the pipeline. Used to generate model card. E.g. "text-classification".
116
- tags (`List[str]`, *optional*):
116
+ tags (`list[str]`, *optional*):
117
117
  Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
118
- coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
118
+ coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
119
119
  Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
120
120
  jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
121
121
 
@@ -145,12 +145,10 @@ class ModelHubMixin:
145
145
  ...
146
146
  ... @classmethod
147
147
  ... def from_pretrained(
148
- ... cls: Type[T],
148
+ ... cls: type[T],
149
149
  ... pretrained_model_name_or_path: Union[str, Path],
150
150
  ... *,
151
151
  ... force_download: bool = False,
152
- ... resume_download: Optional[bool] = None,
153
- ... proxies: Optional[Dict] = None,
154
152
  ... token: Optional[Union[str, bool]] = None,
155
153
  ... cache_dir: Optional[Union[str, Path]] = None,
156
154
  ... local_files_only: bool = False,
@@ -188,10 +186,10 @@ class ModelHubMixin:
188
186
  _hub_mixin_info: MixinInfo
189
187
  # ^ information about the library integrating ModelHubMixin (used to generate model card)
190
188
  _hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not
191
- _hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters
192
- _hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters
193
- _hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded
194
- _hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types
189
+ _hub_mixin_init_parameters: dict[str, inspect.Parameter] # __init__ parameters
190
+ _hub_mixin_jsonable_default_values: dict[str, Any] # default values for __init__ parameters
191
+ _hub_mixin_jsonable_custom_types: tuple[Type, ...] # custom types that can be encoded/decoded
192
+ _hub_mixin_coders: dict[Type, CODER_T] # encoders/decoders for custom types
195
193
  # ^ internal values to handle config
196
194
 
197
195
  def __init_subclass__(
@@ -204,16 +202,16 @@ class ModelHubMixin:
204
202
  # Model card template
205
203
  model_card_template: str = DEFAULT_MODEL_CARD,
206
204
  # Model card metadata
207
- language: Optional[List[str]] = None,
205
+ language: Optional[list[str]] = None,
208
206
  library_name: Optional[str] = None,
209
207
  license: Optional[str] = None,
210
208
  license_name: Optional[str] = None,
211
209
  license_link: Optional[str] = None,
212
210
  pipeline_tag: Optional[str] = None,
213
- tags: Optional[List[str]] = None,
211
+ tags: Optional[list[str]] = None,
214
212
  # How to encode/decode arguments with custom type into a JSON config?
215
213
  coders: Optional[
216
- Dict[Type, CODER_T]
214
+ dict[Type, CODER_T]
217
215
  # Key is a type.
218
216
  # Value is a tuple (encoder, decoder).
219
217
  # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
@@ -266,12 +264,14 @@ class ModelHubMixin:
266
264
  if pipeline_tag is not None:
267
265
  info.model_card_data.pipeline_tag = pipeline_tag
268
266
  if tags is not None:
267
+ normalized_tags = list(tags)
269
268
  if info.model_card_data.tags is not None:
270
- info.model_card_data.tags.extend(tags)
269
+ info.model_card_data.tags.extend(normalized_tags)
271
270
  else:
272
- info.model_card_data.tags = tags
271
+ info.model_card_data.tags = normalized_tags
273
272
 
274
- info.model_card_data.tags = sorted(set(info.model_card_data.tags))
273
+ if info.model_card_data.tags is not None:
274
+ info.model_card_data.tags = sorted(set(info.model_card_data.tags))
275
275
 
276
276
  # Handle encoders/decoders for args
277
277
  cls._hub_mixin_coders = coders or {}
@@ -286,7 +286,7 @@ class ModelHubMixin:
286
286
  }
287
287
  cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
288
288
 
289
- def __new__(cls: Type[T], *args, **kwargs) -> T:
289
+ def __new__(cls: type[T], *args, **kwargs) -> T:
290
290
  """Create a new instance of the class and handle config.
291
291
 
292
292
  3 cases:
@@ -362,7 +362,7 @@ class ModelHubMixin:
362
362
  return arg
363
363
 
364
364
  @classmethod
365
- def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
365
+ def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]:
366
366
  """Decode a JSON serializable value into an argument."""
367
367
  if is_simple_optional_type(expected_type):
368
368
  if value is None:
@@ -385,7 +385,7 @@ class ModelHubMixin:
385
385
  config: Optional[Union[dict, DataclassInstance]] = None,
386
386
  repo_id: Optional[str] = None,
387
387
  push_to_hub: bool = False,
388
- model_card_kwargs: Optional[Dict[str, Any]] = None,
388
+ model_card_kwargs: Optional[dict[str, Any]] = None,
389
389
  **push_to_hub_kwargs,
390
390
  ) -> Optional[str]:
391
391
  """
@@ -401,7 +401,7 @@ class ModelHubMixin:
401
401
  repo_id (`str`, *optional*):
402
402
  ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
403
403
  not provided.
404
- model_card_kwargs (`Dict[str, Any]`, *optional*):
404
+ model_card_kwargs (`dict[str, Any]`, *optional*):
405
405
  Additional arguments passed to the model card template to customize the model card.
406
406
  push_to_hub_kwargs:
407
407
  Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
@@ -460,12 +460,10 @@ class ModelHubMixin:
460
460
  @classmethod
461
461
  @validate_hf_hub_args
462
462
  def from_pretrained(
463
- cls: Type[T],
463
+ cls: type[T],
464
464
  pretrained_model_name_or_path: Union[str, Path],
465
465
  *,
466
466
  force_download: bool = False,
467
- resume_download: Optional[bool] = None,
468
- proxies: Optional[Dict] = None,
469
467
  token: Optional[Union[str, bool]] = None,
470
468
  cache_dir: Optional[Union[str, Path]] = None,
471
469
  local_files_only: bool = False,
@@ -486,9 +484,6 @@ class ModelHubMixin:
486
484
  force_download (`bool`, *optional*, defaults to `False`):
487
485
  Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
488
486
  the existing cache.
489
- proxies (`Dict[str, str]`, *optional*):
490
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
491
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
492
487
  token (`str` or `bool`, *optional*):
493
488
  The token to use as HTTP bearer authorization for remote files. By default, it will use the token
494
489
  cached when running `hf auth login`.
@@ -496,7 +491,7 @@ class ModelHubMixin:
496
491
  Path to the folder where cached files are stored.
497
492
  local_files_only (`bool`, *optional*, defaults to `False`):
498
493
  If `True`, avoid downloading the file and return the path to the local cached file if it exists.
499
- model_kwargs (`Dict`, *optional*):
494
+ model_kwargs (`dict`, *optional*):
500
495
  Additional kwargs to pass to the model during initialization.
501
496
  """
502
497
  model_id = str(pretrained_model_name_or_path)
@@ -514,8 +509,6 @@ class ModelHubMixin:
514
509
  revision=revision,
515
510
  cache_dir=cache_dir,
516
511
  force_download=force_download,
517
- proxies=proxies,
518
- resume_download=resume_download,
519
512
  token=token,
520
513
  local_files_only=local_files_only,
521
514
  )
@@ -555,7 +548,7 @@ class ModelHubMixin:
555
548
  if key not in model_kwargs and key in config:
556
549
  model_kwargs[key] = config[key]
557
550
  elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
558
- for key, value in config.items():
551
+ for key, value in config.items(): # type: ignore[union-attr]
559
552
  if key not in model_kwargs:
560
553
  model_kwargs[key] = value
561
554
 
@@ -568,8 +561,6 @@ class ModelHubMixin:
568
561
  revision=revision,
569
562
  cache_dir=cache_dir,
570
563
  force_download=force_download,
571
- proxies=proxies,
572
- resume_download=resume_download,
573
564
  local_files_only=local_files_only,
574
565
  token=token,
575
566
  **model_kwargs,
@@ -584,14 +575,12 @@ class ModelHubMixin:
584
575
 
585
576
  @classmethod
586
577
  def _from_pretrained(
587
- cls: Type[T],
578
+ cls: type[T],
588
579
  *,
589
580
  model_id: str,
590
581
  revision: Optional[str],
591
582
  cache_dir: Optional[Union[str, Path]],
592
583
  force_download: bool,
593
- proxies: Optional[Dict],
594
- resume_download: Optional[bool],
595
584
  local_files_only: bool,
596
585
  token: Optional[Union[str, bool]],
597
586
  **model_kwargs,
@@ -614,9 +603,6 @@ class ModelHubMixin:
614
603
  force_download (`bool`, *optional*, defaults to `False`):
615
604
  Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
616
605
  the existing cache.
617
- proxies (`Dict[str, str]`, *optional*):
618
- A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
619
- 'http://hostname': 'foo.bar:4012'}`).
620
606
  token (`str` or `bool`, *optional*):
621
607
  The token to use as HTTP bearer authorization for remote files. By default, it will use the token
622
608
  cached when running `hf auth login`.
@@ -640,10 +626,10 @@ class ModelHubMixin:
640
626
  token: Optional[str] = None,
641
627
  branch: Optional[str] = None,
642
628
  create_pr: Optional[bool] = None,
643
- allow_patterns: Optional[Union[List[str], str]] = None,
644
- ignore_patterns: Optional[Union[List[str], str]] = None,
645
- delete_patterns: Optional[Union[List[str], str]] = None,
646
- model_card_kwargs: Optional[Dict[str, Any]] = None,
629
+ allow_patterns: Optional[Union[list[str], str]] = None,
630
+ ignore_patterns: Optional[Union[list[str], str]] = None,
631
+ delete_patterns: Optional[Union[list[str], str]] = None,
632
+ model_card_kwargs: Optional[dict[str, Any]] = None,
647
633
  ) -> str:
648
634
  """
649
635
  Upload model checkpoint to the Hub.
@@ -669,13 +655,13 @@ class ModelHubMixin:
669
655
  The git branch on which to push the model. This defaults to `"main"`.
670
656
  create_pr (`boolean`, *optional*):
671
657
  Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
672
- allow_patterns (`List[str]` or `str`, *optional*):
658
+ allow_patterns (`list[str]` or `str`, *optional*):
673
659
  If provided, only files matching at least one pattern are pushed.
674
- ignore_patterns (`List[str]` or `str`, *optional*):
660
+ ignore_patterns (`list[str]` or `str`, *optional*):
675
661
  If provided, files matching any of the patterns are not pushed.
676
- delete_patterns (`List[str]` or `str`, *optional*):
662
+ delete_patterns (`list[str]` or `str`, *optional*):
677
663
  If provided, remote files matching any of the patterns will be deleted from the repo.
678
- model_card_kwargs (`Dict[str, Any]`, *optional*):
664
+ model_card_kwargs (`dict[str, Any]`, *optional*):
679
665
  Additional arguments passed to the model card template to customize the model card.
680
666
 
681
667
  Returns:
@@ -758,7 +744,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
758
744
  ```
759
745
  """
760
746
 
761
- def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
747
+ def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None:
762
748
  tags = tags or []
763
749
  tags.append("pytorch_model_hub_mixin")
764
750
  kwargs["tags"] = tags
@@ -777,8 +763,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
777
763
  revision: Optional[str],
778
764
  cache_dir: Optional[Union[str, Path]],
779
765
  force_download: bool,
780
- proxies: Optional[Dict],
781
- resume_download: Optional[bool],
782
766
  local_files_only: bool,
783
767
  token: Union[str, bool, None],
784
768
  map_location: str = "cpu",
@@ -799,8 +783,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
799
783
  revision=revision,
800
784
  cache_dir=cache_dir,
801
785
  force_download=force_download,
802
- proxies=proxies,
803
- resume_download=resume_download,
804
786
  token=token,
805
787
  local_files_only=local_files_only,
806
788
  )
@@ -812,8 +794,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
812
794
  revision=revision,
813
795
  cache_dir=cache_dir,
814
796
  force_download=force_download,
815
- proxies=proxies,
816
- resume_download=resume_download,
817
797
  token=token,
818
798
  local_files_only=local_files_only,
819
799
  )
@@ -843,7 +823,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
843
823
  return model
844
824
 
845
825
 
846
- def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
826
+ def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
847
827
  """Load a dataclass instance from a dictionary.
848
828
 
849
829
  Fields not expected by the dataclass are ignored.