huggingface-hub 0.31.0rc0__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +145 -46
- huggingface_hub/_commit_api.py +168 -119
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +15 -12
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +239 -80
- huggingface_hub/_space_api.py +5 -5
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +172 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +13 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +38 -53
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +80 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +435 -351
- huggingface_hub/hf_api.py +2050 -1124
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +43 -63
- huggingface_hub/inference/_client.py +347 -434
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +397 -541
- huggingface_hub/inference/_generated/types/__init__.py +5 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +3 -3
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +59 -23
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +6 -2
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +10 -10
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +82 -7
- huggingface_hub/inference/_providers/_common.py +129 -27
- huggingface_hub/inference/_providers/black_forest_labs.py +6 -6
- huggingface_hub/inference/_providers/cerebras.py +1 -1
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +20 -3
- huggingface_hub/inference/_providers/fal_ai.py +183 -56
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +18 -0
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +69 -30
- huggingface_hub/inference/_providers/hyperbolic.py +4 -4
- huggingface_hub/inference/_providers/nebius.py +33 -5
- huggingface_hub/inference/_providers/novita.py +5 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +3 -1
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +31 -13
- huggingface_hub/inference/_providers/sambanova.py +18 -4
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +20 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +57 -57
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +19 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +398 -239
- huggingface_hub/utils/_pagination.py +4 -4
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +61 -24
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +9 -9
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +64 -17
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +5 -4
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -85
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -474
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -314
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.31.0rc0.dist-info/RECORD +0 -135
- huggingface_hub-0.31.0rc0.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.31.0rc0.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
huggingface_hub/hub_mixin.py
CHANGED
|
@@ -3,7 +3,7 @@ import json
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import Field, asdict, dataclass, is_dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Callable, ClassVar,
|
|
6
|
+
from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union
|
|
7
7
|
|
|
8
8
|
import packaging.version
|
|
9
9
|
|
|
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
|
|
|
38
38
|
|
|
39
39
|
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
|
|
40
40
|
class DataclassInstance(Protocol):
|
|
41
|
-
__dataclass_fields__: ClassVar[
|
|
41
|
+
__dataclass_fields__: ClassVar[dict[str, Field]]
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
# Generic variable that is either ModelHubMixin or a subclass thereof
|
|
@@ -47,7 +47,7 @@ T = TypeVar("T", bound="ModelHubMixin")
|
|
|
47
47
|
ARGS_T = TypeVar("ARGS_T")
|
|
48
48
|
ENCODER_T = Callable[[ARGS_T], Any]
|
|
49
49
|
DECODER_T = Callable[[Any], ARGS_T]
|
|
50
|
-
CODER_T =
|
|
50
|
+
CODER_T = tuple[ENCODER_T, DECODER_T]
|
|
51
51
|
|
|
52
52
|
|
|
53
53
|
DEFAULT_MODEL_CARD = """
|
|
@@ -96,7 +96,7 @@ class ModelHubMixin:
|
|
|
96
96
|
URL of the library documentation. Used to generate model card.
|
|
97
97
|
model_card_template (`str`, *optional*):
|
|
98
98
|
Template of the model card. Used to generate model card. Defaults to a generic template.
|
|
99
|
-
language (`str` or `
|
|
99
|
+
language (`str` or `list[str]`, *optional*):
|
|
100
100
|
Language supported by the library. Used to generate model card.
|
|
101
101
|
library_name (`str`, *optional*):
|
|
102
102
|
Name of the library integrating ModelHubMixin. Used to generate model card.
|
|
@@ -113,11 +113,11 @@ class ModelHubMixin:
|
|
|
113
113
|
E.g: "https://coqui.ai/cpml".
|
|
114
114
|
pipeline_tag (`str`, *optional*):
|
|
115
115
|
Tag of the pipeline. Used to generate model card. E.g. "text-classification".
|
|
116
|
-
tags (`
|
|
116
|
+
tags (`list[str]`, *optional*):
|
|
117
117
|
Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
|
|
118
|
-
coders (`
|
|
118
|
+
coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
|
|
119
119
|
Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
|
|
120
|
-
jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
|
|
120
|
+
jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc.
|
|
121
121
|
|
|
122
122
|
Example:
|
|
123
123
|
|
|
@@ -145,12 +145,10 @@ class ModelHubMixin:
|
|
|
145
145
|
...
|
|
146
146
|
... @classmethod
|
|
147
147
|
... def from_pretrained(
|
|
148
|
-
... cls:
|
|
148
|
+
... cls: type[T],
|
|
149
149
|
... pretrained_model_name_or_path: Union[str, Path],
|
|
150
150
|
... *,
|
|
151
151
|
... force_download: bool = False,
|
|
152
|
-
... resume_download: Optional[bool] = None,
|
|
153
|
-
... proxies: Optional[Dict] = None,
|
|
154
152
|
... token: Optional[Union[str, bool]] = None,
|
|
155
153
|
... cache_dir: Optional[Union[str, Path]] = None,
|
|
156
154
|
... local_files_only: bool = False,
|
|
@@ -188,10 +186,10 @@ class ModelHubMixin:
|
|
|
188
186
|
_hub_mixin_info: MixinInfo
|
|
189
187
|
# ^ information about the library integrating ModelHubMixin (used to generate model card)
|
|
190
188
|
_hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not
|
|
191
|
-
_hub_mixin_init_parameters:
|
|
192
|
-
_hub_mixin_jsonable_default_values:
|
|
193
|
-
_hub_mixin_jsonable_custom_types:
|
|
194
|
-
_hub_mixin_coders:
|
|
189
|
+
_hub_mixin_init_parameters: dict[str, inspect.Parameter] # __init__ parameters
|
|
190
|
+
_hub_mixin_jsonable_default_values: dict[str, Any] # default values for __init__ parameters
|
|
191
|
+
_hub_mixin_jsonable_custom_types: tuple[Type, ...] # custom types that can be encoded/decoded
|
|
192
|
+
_hub_mixin_coders: dict[Type, CODER_T] # encoders/decoders for custom types
|
|
195
193
|
# ^ internal values to handle config
|
|
196
194
|
|
|
197
195
|
def __init_subclass__(
|
|
@@ -204,16 +202,16 @@ class ModelHubMixin:
|
|
|
204
202
|
# Model card template
|
|
205
203
|
model_card_template: str = DEFAULT_MODEL_CARD,
|
|
206
204
|
# Model card metadata
|
|
207
|
-
language: Optional[
|
|
205
|
+
language: Optional[list[str]] = None,
|
|
208
206
|
library_name: Optional[str] = None,
|
|
209
207
|
license: Optional[str] = None,
|
|
210
208
|
license_name: Optional[str] = None,
|
|
211
209
|
license_link: Optional[str] = None,
|
|
212
210
|
pipeline_tag: Optional[str] = None,
|
|
213
|
-
tags: Optional[
|
|
211
|
+
tags: Optional[list[str]] = None,
|
|
214
212
|
# How to encode/decode arguments with custom type into a JSON config?
|
|
215
213
|
coders: Optional[
|
|
216
|
-
|
|
214
|
+
dict[Type, CODER_T]
|
|
217
215
|
# Key is a type.
|
|
218
216
|
# Value is a tuple (encoder, decoder).
|
|
219
217
|
# Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
|
|
@@ -266,12 +264,14 @@ class ModelHubMixin:
|
|
|
266
264
|
if pipeline_tag is not None:
|
|
267
265
|
info.model_card_data.pipeline_tag = pipeline_tag
|
|
268
266
|
if tags is not None:
|
|
267
|
+
normalized_tags = list(tags)
|
|
269
268
|
if info.model_card_data.tags is not None:
|
|
270
|
-
info.model_card_data.tags.extend(
|
|
269
|
+
info.model_card_data.tags.extend(normalized_tags)
|
|
271
270
|
else:
|
|
272
|
-
info.model_card_data.tags =
|
|
271
|
+
info.model_card_data.tags = normalized_tags
|
|
273
272
|
|
|
274
|
-
info.model_card_data.tags
|
|
273
|
+
if info.model_card_data.tags is not None:
|
|
274
|
+
info.model_card_data.tags = sorted(set(info.model_card_data.tags))
|
|
275
275
|
|
|
276
276
|
# Handle encoders/decoders for args
|
|
277
277
|
cls._hub_mixin_coders = coders or {}
|
|
@@ -286,7 +286,7 @@ class ModelHubMixin:
|
|
|
286
286
|
}
|
|
287
287
|
cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
|
|
288
288
|
|
|
289
|
-
def __new__(cls:
|
|
289
|
+
def __new__(cls: type[T], *args, **kwargs) -> T:
|
|
290
290
|
"""Create a new instance of the class and handle config.
|
|
291
291
|
|
|
292
292
|
3 cases:
|
|
@@ -353,7 +353,7 @@ class ModelHubMixin:
|
|
|
353
353
|
def _encode_arg(cls, arg: Any) -> Any:
|
|
354
354
|
"""Encode an argument into a JSON serializable format."""
|
|
355
355
|
if is_dataclass(arg):
|
|
356
|
-
return asdict(arg)
|
|
356
|
+
return asdict(arg) # type: ignore[arg-type]
|
|
357
357
|
for type_, (encoder, _) in cls._hub_mixin_coders.items():
|
|
358
358
|
if isinstance(arg, type_):
|
|
359
359
|
if arg is None:
|
|
@@ -362,7 +362,7 @@ class ModelHubMixin:
|
|
|
362
362
|
return arg
|
|
363
363
|
|
|
364
364
|
@classmethod
|
|
365
|
-
def _decode_arg(cls, expected_type:
|
|
365
|
+
def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]:
|
|
366
366
|
"""Decode a JSON serializable value into an argument."""
|
|
367
367
|
if is_simple_optional_type(expected_type):
|
|
368
368
|
if value is None:
|
|
@@ -385,7 +385,7 @@ class ModelHubMixin:
|
|
|
385
385
|
config: Optional[Union[dict, DataclassInstance]] = None,
|
|
386
386
|
repo_id: Optional[str] = None,
|
|
387
387
|
push_to_hub: bool = False,
|
|
388
|
-
model_card_kwargs: Optional[
|
|
388
|
+
model_card_kwargs: Optional[dict[str, Any]] = None,
|
|
389
389
|
**push_to_hub_kwargs,
|
|
390
390
|
) -> Optional[str]:
|
|
391
391
|
"""
|
|
@@ -401,7 +401,7 @@ class ModelHubMixin:
|
|
|
401
401
|
repo_id (`str`, *optional*):
|
|
402
402
|
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
|
|
403
403
|
not provided.
|
|
404
|
-
model_card_kwargs (`
|
|
404
|
+
model_card_kwargs (`dict[str, Any]`, *optional*):
|
|
405
405
|
Additional arguments passed to the model card template to customize the model card.
|
|
406
406
|
push_to_hub_kwargs:
|
|
407
407
|
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
|
|
@@ -460,12 +460,10 @@ class ModelHubMixin:
|
|
|
460
460
|
@classmethod
|
|
461
461
|
@validate_hf_hub_args
|
|
462
462
|
def from_pretrained(
|
|
463
|
-
cls:
|
|
463
|
+
cls: type[T],
|
|
464
464
|
pretrained_model_name_or_path: Union[str, Path],
|
|
465
465
|
*,
|
|
466
466
|
force_download: bool = False,
|
|
467
|
-
resume_download: Optional[bool] = None,
|
|
468
|
-
proxies: Optional[Dict] = None,
|
|
469
467
|
token: Optional[Union[str, bool]] = None,
|
|
470
468
|
cache_dir: Optional[Union[str, Path]] = None,
|
|
471
469
|
local_files_only: bool = False,
|
|
@@ -486,17 +484,14 @@ class ModelHubMixin:
|
|
|
486
484
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
487
485
|
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
|
|
488
486
|
the existing cache.
|
|
489
|
-
proxies (`Dict[str, str]`, *optional*):
|
|
490
|
-
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
|
491
|
-
'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
|
|
492
487
|
token (`str` or `bool`, *optional*):
|
|
493
488
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
494
|
-
cached when running `
|
|
489
|
+
cached when running `hf auth login`.
|
|
495
490
|
cache_dir (`str`, `Path`, *optional*):
|
|
496
491
|
Path to the folder where cached files are stored.
|
|
497
492
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
498
493
|
If `True`, avoid downloading the file and return the path to the local cached file if it exists.
|
|
499
|
-
model_kwargs (`
|
|
494
|
+
model_kwargs (`dict`, *optional*):
|
|
500
495
|
Additional kwargs to pass to the model during initialization.
|
|
501
496
|
"""
|
|
502
497
|
model_id = str(pretrained_model_name_or_path)
|
|
@@ -514,8 +509,6 @@ class ModelHubMixin:
|
|
|
514
509
|
revision=revision,
|
|
515
510
|
cache_dir=cache_dir,
|
|
516
511
|
force_download=force_download,
|
|
517
|
-
proxies=proxies,
|
|
518
|
-
resume_download=resume_download,
|
|
519
512
|
token=token,
|
|
520
513
|
local_files_only=local_files_only,
|
|
521
514
|
)
|
|
@@ -555,7 +548,7 @@ class ModelHubMixin:
|
|
|
555
548
|
if key not in model_kwargs and key in config:
|
|
556
549
|
model_kwargs[key] = config[key]
|
|
557
550
|
elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
|
|
558
|
-
for key, value in config.items():
|
|
551
|
+
for key, value in config.items(): # type: ignore[union-attr]
|
|
559
552
|
if key not in model_kwargs:
|
|
560
553
|
model_kwargs[key] = value
|
|
561
554
|
|
|
@@ -568,8 +561,6 @@ class ModelHubMixin:
|
|
|
568
561
|
revision=revision,
|
|
569
562
|
cache_dir=cache_dir,
|
|
570
563
|
force_download=force_download,
|
|
571
|
-
proxies=proxies,
|
|
572
|
-
resume_download=resume_download,
|
|
573
564
|
local_files_only=local_files_only,
|
|
574
565
|
token=token,
|
|
575
566
|
**model_kwargs,
|
|
@@ -584,14 +575,12 @@ class ModelHubMixin:
|
|
|
584
575
|
|
|
585
576
|
@classmethod
|
|
586
577
|
def _from_pretrained(
|
|
587
|
-
cls:
|
|
578
|
+
cls: type[T],
|
|
588
579
|
*,
|
|
589
580
|
model_id: str,
|
|
590
581
|
revision: Optional[str],
|
|
591
582
|
cache_dir: Optional[Union[str, Path]],
|
|
592
583
|
force_download: bool,
|
|
593
|
-
proxies: Optional[Dict],
|
|
594
|
-
resume_download: Optional[bool],
|
|
595
584
|
local_files_only: bool,
|
|
596
585
|
token: Optional[Union[str, bool]],
|
|
597
586
|
**model_kwargs,
|
|
@@ -614,12 +603,9 @@ class ModelHubMixin:
|
|
|
614
603
|
force_download (`bool`, *optional*, defaults to `False`):
|
|
615
604
|
Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
|
|
616
605
|
the existing cache.
|
|
617
|
-
proxies (`Dict[str, str]`, *optional*):
|
|
618
|
-
A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
|
|
619
|
-
'http://hostname': 'foo.bar:4012'}`).
|
|
620
606
|
token (`str` or `bool`, *optional*):
|
|
621
607
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
622
|
-
cached when running `
|
|
608
|
+
cached when running `hf auth login`.
|
|
623
609
|
cache_dir (`str`, `Path`, *optional*):
|
|
624
610
|
Path to the folder where cached files are stored.
|
|
625
611
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
@@ -640,10 +626,10 @@ class ModelHubMixin:
|
|
|
640
626
|
token: Optional[str] = None,
|
|
641
627
|
branch: Optional[str] = None,
|
|
642
628
|
create_pr: Optional[bool] = None,
|
|
643
|
-
allow_patterns: Optional[Union[
|
|
644
|
-
ignore_patterns: Optional[Union[
|
|
645
|
-
delete_patterns: Optional[Union[
|
|
646
|
-
model_card_kwargs: Optional[
|
|
629
|
+
allow_patterns: Optional[Union[list[str], str]] = None,
|
|
630
|
+
ignore_patterns: Optional[Union[list[str], str]] = None,
|
|
631
|
+
delete_patterns: Optional[Union[list[str], str]] = None,
|
|
632
|
+
model_card_kwargs: Optional[dict[str, Any]] = None,
|
|
647
633
|
) -> str:
|
|
648
634
|
"""
|
|
649
635
|
Upload model checkpoint to the Hub.
|
|
@@ -664,18 +650,18 @@ class ModelHubMixin:
|
|
|
664
650
|
If `None` (default), the repo will be public unless the organization's default is private.
|
|
665
651
|
token (`str`, *optional*):
|
|
666
652
|
The token to use as HTTP bearer authorization for remote files. By default, it will use the token
|
|
667
|
-
cached when running `
|
|
653
|
+
cached when running `hf auth login`.
|
|
668
654
|
branch (`str`, *optional*):
|
|
669
655
|
The git branch on which to push the model. This defaults to `"main"`.
|
|
670
656
|
create_pr (`boolean`, *optional*):
|
|
671
657
|
Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
|
|
672
|
-
allow_patterns (`
|
|
658
|
+
allow_patterns (`list[str]` or `str`, *optional*):
|
|
673
659
|
If provided, only files matching at least one pattern are pushed.
|
|
674
|
-
ignore_patterns (`
|
|
660
|
+
ignore_patterns (`list[str]` or `str`, *optional*):
|
|
675
661
|
If provided, files matching any of the patterns are not pushed.
|
|
676
|
-
delete_patterns (`
|
|
662
|
+
delete_patterns (`list[str]` or `str`, *optional*):
|
|
677
663
|
If provided, remote files matching any of the patterns will be deleted from the repo.
|
|
678
|
-
model_card_kwargs (`
|
|
664
|
+
model_card_kwargs (`dict[str, Any]`, *optional*):
|
|
679
665
|
Additional arguments passed to the model card template to customize the model card.
|
|
680
666
|
|
|
681
667
|
Returns:
|
|
@@ -758,7 +744,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
758
744
|
```
|
|
759
745
|
"""
|
|
760
746
|
|
|
761
|
-
def __init_subclass__(cls, *args, tags: Optional[
|
|
747
|
+
def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None:
|
|
762
748
|
tags = tags or []
|
|
763
749
|
tags.append("pytorch_model_hub_mixin")
|
|
764
750
|
kwargs["tags"] = tags
|
|
@@ -767,7 +753,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
767
753
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
768
754
|
"""Save weights from a Pytorch model to a local directory."""
|
|
769
755
|
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
|
|
770
|
-
save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))
|
|
756
|
+
save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE)) # type: ignore [arg-type]
|
|
771
757
|
|
|
772
758
|
@classmethod
|
|
773
759
|
def _from_pretrained(
|
|
@@ -777,8 +763,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
777
763
|
revision: Optional[str],
|
|
778
764
|
cache_dir: Optional[Union[str, Path]],
|
|
779
765
|
force_download: bool,
|
|
780
|
-
proxies: Optional[Dict],
|
|
781
|
-
resume_download: Optional[bool],
|
|
782
766
|
local_files_only: bool,
|
|
783
767
|
token: Union[str, bool, None],
|
|
784
768
|
map_location: str = "cpu",
|
|
@@ -799,8 +783,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
799
783
|
revision=revision,
|
|
800
784
|
cache_dir=cache_dir,
|
|
801
785
|
force_download=force_download,
|
|
802
|
-
proxies=proxies,
|
|
803
|
-
resume_download=resume_download,
|
|
804
786
|
token=token,
|
|
805
787
|
local_files_only=local_files_only,
|
|
806
788
|
)
|
|
@@ -812,8 +794,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
812
794
|
revision=revision,
|
|
813
795
|
cache_dir=cache_dir,
|
|
814
796
|
force_download=force_download,
|
|
815
|
-
proxies=proxies,
|
|
816
|
-
resume_download=resume_download,
|
|
817
797
|
token=token,
|
|
818
798
|
local_files_only=local_files_only,
|
|
819
799
|
)
|
|
@@ -843,7 +823,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
|
|
|
843
823
|
return model
|
|
844
824
|
|
|
845
825
|
|
|
846
|
-
def _load_dataclass(datacls:
|
|
826
|
+
def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
|
|
847
827
|
"""Load a dataclass instance from a dictionary.
|
|
848
828
|
|
|
849
829
|
Fields not expected by the dataclass are ignored.
|