huggingface-hub 0.29.0rc2__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +160 -46
- huggingface_hub/_commit_api.py +277 -71
- huggingface_hub/_commit_scheduler.py +15 -15
- huggingface_hub/_inference_endpoints.py +33 -22
- huggingface_hub/_jobs_api.py +301 -0
- huggingface_hub/_local_folder.py +18 -3
- huggingface_hub/_login.py +31 -63
- huggingface_hub/_oauth.py +460 -0
- huggingface_hub/_snapshot_download.py +241 -81
- huggingface_hub/_space_api.py +18 -10
- huggingface_hub/_tensorboard_logger.py +15 -19
- huggingface_hub/_upload_large_folder.py +196 -76
- huggingface_hub/_webhooks_payload.py +3 -3
- huggingface_hub/_webhooks_server.py +15 -25
- huggingface_hub/{commands → cli}/__init__.py +1 -15
- huggingface_hub/cli/_cli_utils.py +173 -0
- huggingface_hub/cli/auth.py +147 -0
- huggingface_hub/cli/cache.py +841 -0
- huggingface_hub/cli/download.py +189 -0
- huggingface_hub/cli/hf.py +60 -0
- huggingface_hub/cli/inference_endpoints.py +377 -0
- huggingface_hub/cli/jobs.py +772 -0
- huggingface_hub/cli/lfs.py +175 -0
- huggingface_hub/cli/repo.py +315 -0
- huggingface_hub/cli/repo_files.py +94 -0
- huggingface_hub/{commands/env.py → cli/system.py} +10 -13
- huggingface_hub/cli/upload.py +294 -0
- huggingface_hub/cli/upload_large_folder.py +117 -0
- huggingface_hub/community.py +20 -12
- huggingface_hub/constants.py +83 -59
- huggingface_hub/dataclasses.py +609 -0
- huggingface_hub/errors.py +99 -30
- huggingface_hub/fastai_utils.py +30 -41
- huggingface_hub/file_download.py +606 -346
- huggingface_hub/hf_api.py +2445 -1132
- huggingface_hub/hf_file_system.py +269 -152
- huggingface_hub/hub_mixin.py +61 -66
- huggingface_hub/inference/_client.py +501 -630
- huggingface_hub/inference/_common.py +133 -121
- huggingface_hub/inference/_generated/_async_client.py +536 -722
- huggingface_hub/inference/_generated/types/__init__.py +6 -1
- huggingface_hub/inference/_generated/types/automatic_speech_recognition.py +5 -6
- huggingface_hub/inference/_generated/types/base.py +10 -7
- huggingface_hub/inference/_generated/types/chat_completion.py +77 -31
- huggingface_hub/inference/_generated/types/depth_estimation.py +2 -2
- huggingface_hub/inference/_generated/types/document_question_answering.py +2 -2
- huggingface_hub/inference/_generated/types/feature_extraction.py +2 -2
- huggingface_hub/inference/_generated/types/fill_mask.py +2 -2
- huggingface_hub/inference/_generated/types/image_to_image.py +8 -2
- huggingface_hub/inference/_generated/types/image_to_text.py +2 -3
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_generated/types/sentence_similarity.py +3 -3
- huggingface_hub/inference/_generated/types/summarization.py +2 -2
- huggingface_hub/inference/_generated/types/table_question_answering.py +5 -5
- huggingface_hub/inference/_generated/types/text2text_generation.py +2 -2
- huggingface_hub/inference/_generated/types/text_generation.py +11 -11
- huggingface_hub/inference/_generated/types/text_to_audio.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_speech.py +1 -2
- huggingface_hub/inference/_generated/types/text_to_video.py +2 -2
- huggingface_hub/inference/_generated/types/token_classification.py +2 -2
- huggingface_hub/inference/_generated/types/translation.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_image_classification.py +2 -2
- huggingface_hub/inference/_generated/types/zero_shot_object_detection.py +1 -3
- huggingface_hub/inference/_mcp/__init__.py +0 -0
- huggingface_hub/inference/_mcp/_cli_hacks.py +88 -0
- huggingface_hub/inference/_mcp/agent.py +100 -0
- huggingface_hub/inference/_mcp/cli.py +247 -0
- huggingface_hub/inference/_mcp/constants.py +81 -0
- huggingface_hub/inference/_mcp/mcp_client.py +395 -0
- huggingface_hub/inference/_mcp/types.py +45 -0
- huggingface_hub/inference/_mcp/utils.py +128 -0
- huggingface_hub/inference/_providers/__init__.py +149 -20
- huggingface_hub/inference/_providers/_common.py +160 -37
- huggingface_hub/inference/_providers/black_forest_labs.py +12 -9
- huggingface_hub/inference/_providers/cerebras.py +6 -0
- huggingface_hub/inference/_providers/clarifai.py +13 -0
- huggingface_hub/inference/_providers/cohere.py +32 -0
- huggingface_hub/inference/_providers/fal_ai.py +231 -22
- huggingface_hub/inference/_providers/featherless_ai.py +38 -0
- huggingface_hub/inference/_providers/fireworks_ai.py +22 -1
- huggingface_hub/inference/_providers/groq.py +9 -0
- huggingface_hub/inference/_providers/hf_inference.py +143 -33
- huggingface_hub/inference/_providers/hyperbolic.py +9 -5
- huggingface_hub/inference/_providers/nebius.py +47 -5
- huggingface_hub/inference/_providers/novita.py +48 -5
- huggingface_hub/inference/_providers/nscale.py +44 -0
- huggingface_hub/inference/_providers/openai.py +25 -0
- huggingface_hub/inference/_providers/publicai.py +6 -0
- huggingface_hub/inference/_providers/replicate.py +46 -9
- huggingface_hub/inference/_providers/sambanova.py +37 -1
- huggingface_hub/inference/_providers/scaleway.py +28 -0
- huggingface_hub/inference/_providers/together.py +34 -5
- huggingface_hub/inference/_providers/wavespeed.py +138 -0
- huggingface_hub/inference/_providers/zai_org.py +17 -0
- huggingface_hub/lfs.py +33 -100
- huggingface_hub/repocard.py +34 -38
- huggingface_hub/repocard_data.py +79 -59
- huggingface_hub/serialization/__init__.py +0 -1
- huggingface_hub/serialization/_base.py +12 -15
- huggingface_hub/serialization/_dduf.py +8 -8
- huggingface_hub/serialization/_torch.py +69 -69
- huggingface_hub/utils/__init__.py +27 -8
- huggingface_hub/utils/_auth.py +7 -7
- huggingface_hub/utils/_cache_manager.py +92 -147
- huggingface_hub/utils/_chunk_utils.py +2 -3
- huggingface_hub/utils/_deprecation.py +1 -1
- huggingface_hub/utils/_dotenv.py +55 -0
- huggingface_hub/utils/_experimental.py +7 -5
- huggingface_hub/utils/_fixes.py +0 -10
- huggingface_hub/utils/_git_credential.py +5 -5
- huggingface_hub/utils/_headers.py +8 -30
- huggingface_hub/utils/_http.py +399 -237
- huggingface_hub/utils/_pagination.py +6 -6
- huggingface_hub/utils/_parsing.py +98 -0
- huggingface_hub/utils/_paths.py +5 -5
- huggingface_hub/utils/_runtime.py +74 -22
- huggingface_hub/utils/_safetensors.py +21 -21
- huggingface_hub/utils/_subprocess.py +13 -11
- huggingface_hub/utils/_telemetry.py +4 -4
- huggingface_hub/{commands/_cli_utils.py → utils/_terminal.py} +4 -4
- huggingface_hub/utils/_typing.py +25 -5
- huggingface_hub/utils/_validators.py +55 -74
- huggingface_hub/utils/_verification.py +167 -0
- huggingface_hub/utils/_xet.py +235 -0
- huggingface_hub/utils/_xet_progress_reporting.py +162 -0
- huggingface_hub/utils/insecure_hashlib.py +3 -5
- huggingface_hub/utils/logging.py +8 -11
- huggingface_hub/utils/tqdm.py +33 -4
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/METADATA +94 -82
- huggingface_hub-1.1.3.dist-info/RECORD +155 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/WHEEL +1 -1
- huggingface_hub-1.1.3.dist-info/entry_points.txt +6 -0
- huggingface_hub/commands/delete_cache.py +0 -428
- huggingface_hub/commands/download.py +0 -200
- huggingface_hub/commands/huggingface_cli.py +0 -61
- huggingface_hub/commands/lfs.py +0 -200
- huggingface_hub/commands/repo_files.py +0 -128
- huggingface_hub/commands/scan_cache.py +0 -181
- huggingface_hub/commands/tag.py +0 -159
- huggingface_hub/commands/upload.py +0 -299
- huggingface_hub/commands/upload_large_folder.py +0 -129
- huggingface_hub/commands/user.py +0 -304
- huggingface_hub/commands/version.py +0 -37
- huggingface_hub/inference_api.py +0 -217
- huggingface_hub/keras_mixin.py +0 -500
- huggingface_hub/repository.py +0 -1477
- huggingface_hub/serialization/_tensorflow.py +0 -95
- huggingface_hub/utils/_hf_folder.py +0 -68
- huggingface_hub-0.29.0rc2.dist-info/RECORD +0 -131
- huggingface_hub-0.29.0rc2.dist-info/entry_points.txt +0 -6
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info/licenses}/LICENSE +0 -0
- {huggingface_hub-0.29.0rc2.dist-info → huggingface_hub-1.1.3.dist-info}/top_level.txt +0 -0
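The manifest above also records the removal of the legacy git-based `Repository` class, the old `InferenceApi` client, and the Keras/TensorFlow helpers, alongside the move of `huggingface_hub/commands/*` to `huggingface_hub/cli/*`. A rough sketch of the HTTP-based equivalents that remain in 1.x is shown below; the repo id, model name, and local paths are placeholders, not values taken from this diff, and the calls require a valid token and network access.

from huggingface_hub import HfApi, InferenceClient, snapshot_download

api = HfApi()

# Instead of git-based Repository(...).push_to_hub(): commit a local folder over HTTP.
api.upload_folder(repo_id="my-user/my-model", folder_path="./my-model", repo_type="model")

# Instead of cloning a repo to read it: fetch a full snapshot into the local cache.
local_dir = snapshot_download(repo_id="my-user/my-model")

# Instead of the removed InferenceApi class: task methods on InferenceClient.
client = InferenceClient(model="my-user/my-model")
print(client.text_generation("The huggingface_hub library is ", max_new_tokens=20))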
huggingface_hub/hub_mixin.py
CHANGED
@@ -3,7 +3,7 @@ import json
 import os
 from dataclasses import Field, asdict, dataclass, is_dataclass
 from pathlib import Path
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, ClassVar, Optional, Protocol, Type, TypeVar, Union
 
 import packaging.version
 
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
 
 # Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
 class DataclassInstance(Protocol):
-    __dataclass_fields__: ClassVar[Dict[str, Field]]
+    __dataclass_fields__: ClassVar[dict[str, Field]]
 
 
 # Generic variable that is either ModelHubMixin or a subclass thereof
@@ -47,7 +47,7 @@ T = TypeVar("T", bound="ModelHubMixin")
 ARGS_T = TypeVar("ARGS_T")
 ENCODER_T = Callable[[ARGS_T], Any]
 DECODER_T = Callable[[Any], ARGS_T]
-CODER_T = Tuple[ENCODER_T, DECODER_T]
+CODER_T = tuple[ENCODER_T, DECODER_T]
 
 
 DEFAULT_MODEL_CARD = """
@@ -58,7 +58,8 @@ DEFAULT_MODEL_CARD = """
 ---
 
 This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
-- Library: {{ repo_url | default("[More Information Needed]", true) }}
+- Code: {{ repo_url | default("[More Information Needed]", true) }}
+- Paper: {{ paper_url | default("[More Information Needed]", true) }}
 - Docs: {{ docs_url | default("[More Information Needed]", true) }}
 """
 
@@ -67,8 +68,9 @@ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://h
 class MixinInfo:
     model_card_template: str
     model_card_data: ModelCardData
-    repo_url: Optional[str] = None
     docs_url: Optional[str] = None
+    paper_url: Optional[str] = None
+    repo_url: Optional[str] = None
 
 
 class ModelHubMixin:
@@ -88,11 +90,13 @@ class ModelHubMixin:
     Args:
         repo_url (`str`, *optional*):
             URL of the library repository. Used to generate model card.
+        paper_url (`str`, *optional*):
+            URL of the library paper. Used to generate model card.
         docs_url (`str`, *optional*):
             URL of the library documentation. Used to generate model card.
         model_card_template (`str`, *optional*):
             Template of the model card. Used to generate model card. Defaults to a generic template.
-        language (`str` or `List[str]`, *optional*):
+        language (`str` or `list[str]`, *optional*):
             Language supported by the library. Used to generate model card.
         library_name (`str`, *optional*):
             Name of the library integrating ModelHubMixin. Used to generate model card.
@@ -109,11 +113,11 @@
             E.g: "https://coqui.ai/cpml".
         pipeline_tag (`str`, *optional*):
             Tag of the pipeline. Used to generate model card. E.g. "text-classification".
-        tags (`List[str]`, *optional*):
-            Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag"]
-        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
+        tags (`list[str]`, *optional*):
+            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"]
+        coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
             Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
-            jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.
+            jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc.
 
     Example:
 
@@ -124,8 +128,9 @@
     >>> class MyCustomModel(
     ...     ModelHubMixin,
     ...     library_name="my-library",
-    ...     tags=["x-custom-tag"],
+    ...     tags=["computer-vision"],
     ...     repo_url="https://github.com/huggingface/my-cool-library",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
     ...     docs_url="https://huggingface.co/docs/my-cool-library",
     ...     # ^ optional metadata to generate model card
     ... ):
@@ -140,12 +145,10 @@
     ...
     ...     @classmethod
     ...     def from_pretrained(
-    ...         cls: Type[T],
+    ...         cls: type[T],
     ...         pretrained_model_name_or_path: Union[str, Path],
     ...         *,
     ...         force_download: bool = False,
-    ...         resume_download: Optional[bool] = None,
-    ...         proxies: Optional[Dict] = None,
     ...         token: Optional[Union[str, bool]] = None,
     ...         cache_dir: Optional[Union[str, Path]] = None,
     ...         local_files_only: bool = False,
@@ -183,10 +186,10 @@
     _hub_mixin_info: MixinInfo
     # ^ information about the library integrating ModelHubMixin (used to generate model card)
     _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
-    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
-    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
-    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
-    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
+    _hub_mixin_init_parameters: dict[str, inspect.Parameter]  # __init__ parameters
+    _hub_mixin_jsonable_default_values: dict[str, Any]  # default values for __init__ parameters
+    _hub_mixin_jsonable_custom_types: tuple[Type, ...]  # custom types that can be encoded/decoded
+    _hub_mixin_coders: dict[Type, CODER_T]  # encoders/decoders for custom types
     # ^ internal values to handle config
 
     def __init_subclass__(
@@ -194,20 +197,21 @@
         *,
         # Generic info for model card
         repo_url: Optional[str] = None,
+        paper_url: Optional[str] = None,
         docs_url: Optional[str] = None,
         # Model card template
         model_card_template: str = DEFAULT_MODEL_CARD,
         # Model card metadata
-        language: Optional[List[str]] = None,
+        language: Optional[list[str]] = None,
         library_name: Optional[str] = None,
         license: Optional[str] = None,
         license_name: Optional[str] = None,
         license_link: Optional[str] = None,
         pipeline_tag: Optional[str] = None,
-        tags: Optional[List[str]] = None,
+        tags: Optional[list[str]] = None,
         # How to encode/decode arguments with custom type into a JSON config?
         coders: Optional[
-            Dict[Type, CODER_T]
+            dict[Type, CODER_T]
             # Key is a type.
            # Value is a tuple (encoder, decoder).
            # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
@@ -234,6 +238,7 @@
 
             # Inherit other info
             info.docs_url = cls._hub_mixin_info.docs_url
+            info.paper_url = cls._hub_mixin_info.paper_url
             info.repo_url = cls._hub_mixin_info.repo_url
         cls._hub_mixin_info = info
 
@@ -242,6 +247,8 @@
             info.model_card_template = model_card_template
         if repo_url is not None:
             info.repo_url = repo_url
+        if paper_url is not None:
+            info.paper_url = paper_url
         if docs_url is not None:
             info.docs_url = docs_url
         if language is not None:
@@ -257,12 +264,14 @@
         if pipeline_tag is not None:
             info.model_card_data.pipeline_tag = pipeline_tag
         if tags is not None:
+            normalized_tags = list(tags)
             if info.model_card_data.tags is not None:
-                info.model_card_data.tags.extend(tags)
+                info.model_card_data.tags.extend(normalized_tags)
             else:
-                info.model_card_data.tags = tags
+                info.model_card_data.tags = normalized_tags
 
-        info.model_card_data.tags
+        if info.model_card_data.tags is not None:
+            info.model_card_data.tags = sorted(set(info.model_card_data.tags))
 
         # Handle encoders/decoders for args
         cls._hub_mixin_coders = coders or {}
@@ -277,7 +286,7 @@
         }
         cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
 
-    def __new__(cls: Type[T], *args, **kwargs) -> T:
+    def __new__(cls: type[T], *args, **kwargs) -> T:
         """Create a new instance of the class and handle config.
 
         3 cases:
@@ -334,6 +343,8 @@
     @classmethod
     def _is_jsonable(cls, value: Any) -> bool:
         """Check if a value is JSON serializable."""
+        if is_dataclass(value):
+            return True
         if isinstance(value, cls._hub_mixin_jsonable_custom_types):
             return True
         return is_jsonable(value)
@@ -341,6 +352,8 @@
     @classmethod
     def _encode_arg(cls, arg: Any) -> Any:
         """Encode an argument into a JSON serializable format."""
+        if is_dataclass(arg):
+            return asdict(arg)  # type: ignore[arg-type]
         for type_, (encoder, _) in cls._hub_mixin_coders.items():
             if isinstance(arg, type_):
                 if arg is None:
@@ -349,7 +362,7 @@
                 return arg
 
     @classmethod
-    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
+    def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> Optional[ARGS_T]:
         """Decode a JSON serializable value into an argument."""
         if is_simple_optional_type(expected_type):
             if value is None:
@@ -372,7 +385,7 @@
         config: Optional[Union[dict, DataclassInstance]] = None,
         repo_id: Optional[str] = None,
         push_to_hub: bool = False,
-        model_card_kwargs: Optional[Dict[str, Any]] = None,
+        model_card_kwargs: Optional[dict[str, Any]] = None,
         **push_to_hub_kwargs,
     ) -> Optional[str]:
         """
@@ -388,7 +401,7 @@
             repo_id (`str`, *optional*):
                 ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                 not provided.
-            model_card_kwargs (`Dict[str, Any]`, *optional*):
+            model_card_kwargs (`dict[str, Any]`, *optional*):
                 Additional arguments passed to the model card template to customize the model card.
             push_to_hub_kwargs:
                 Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
@@ -447,12 +460,10 @@
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(
-        cls: Type[T],
+        cls: type[T],
         pretrained_model_name_or_path: Union[str, Path],
         *,
         force_download: bool = False,
-        resume_download: Optional[bool] = None,
-        proxies: Optional[Dict] = None,
         token: Optional[Union[str, bool]] = None,
         cache_dir: Optional[Union[str, Path]] = None,
         local_files_only: bool = False,
@@ -473,17 +484,14 @@
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                 the existing cache.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. By default, it will use the token
-                cached when running `huggingface-cli login`.
+                cached when running `hf auth login`.
             cache_dir (`str`, `Path`, *optional*):
                 Path to the folder where cached files are stored.
             local_files_only (`bool`, *optional*, defaults to `False`):
                 If `True`, avoid downloading the file and return the path to the local cached file if it exists.
-            model_kwargs (`Dict`, *optional*):
+            model_kwargs (`dict`, *optional*):
                 Additional kwargs to pass to the model during initialization.
         """
         model_id = str(pretrained_model_name_or_path)
@@ -501,8 +509,6 @@
                 revision=revision,
                 cache_dir=cache_dir,
                 force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
                 token=token,
                 local_files_only=local_files_only,
             )
@@ -542,7 +548,7 @@
                 if key not in model_kwargs and key in config:
                     model_kwargs[key] = config[key]
         elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
-            for key, value in config.items():
+            for key, value in config.items():  # type: ignore[union-attr]
                 if key not in model_kwargs:
                     model_kwargs[key] = value
 
@@ -555,8 +561,6 @@
             revision=revision,
             cache_dir=cache_dir,
             force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
             local_files_only=local_files_only,
             token=token,
             **model_kwargs,
@@ -571,14 +575,12 @@
 
     @classmethod
     def _from_pretrained(
-        cls: Type[T],
+        cls: type[T],
         *,
         model_id: str,
         revision: Optional[str],
         cache_dir: Optional[Union[str, Path]],
         force_download: bool,
-        proxies: Optional[Dict],
-        resume_download: Optional[bool],
         local_files_only: bool,
         token: Optional[Union[str, bool]],
         **model_kwargs,
@@ -601,12 +603,9 @@
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                 the existing cache.
-            proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
-                'http://hostname': 'foo.bar:4012'}`).
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. By default, it will use the token
-                cached when running `huggingface-cli login`.
+                cached when running `hf auth login`.
             cache_dir (`str`, `Path`, *optional*):
                 Path to the folder where cached files are stored.
             local_files_only (`bool`, *optional*, defaults to `False`):
@@ -627,10 +626,10 @@
         token: Optional[str] = None,
         branch: Optional[str] = None,
         create_pr: Optional[bool] = None,
-        allow_patterns: Optional[Union[List[str], str]] = None,
-        ignore_patterns: Optional[Union[List[str], str]] = None,
-        delete_patterns: Optional[Union[List[str], str]] = None,
-        model_card_kwargs: Optional[Dict[str, Any]] = None,
+        allow_patterns: Optional[Union[list[str], str]] = None,
+        ignore_patterns: Optional[Union[list[str], str]] = None,
+        delete_patterns: Optional[Union[list[str], str]] = None,
+        model_card_kwargs: Optional[dict[str, Any]] = None,
     ) -> str:
         """
         Upload model checkpoint to the Hub.
@@ -651,18 +650,18 @@
                 If `None` (default), the repo will be public unless the organization's default is private.
             token (`str`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. By default, it will use the token
-                cached when running `huggingface-cli login`.
+                cached when running `hf auth login`.
             branch (`str`, *optional*):
                 The git branch on which to push the model. This defaults to `"main"`.
             create_pr (`boolean`, *optional*):
                 Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
-            allow_patterns (`List[str]` or `str`, *optional*):
+            allow_patterns (`list[str]` or `str`, *optional*):
                 If provided, only files matching at least one pattern are pushed.
-            ignore_patterns (`List[str]` or `str`, *optional*):
+            ignore_patterns (`list[str]` or `str`, *optional*):
                 If provided, files matching any of the patterns are not pushed.
-            delete_patterns (`List[str]` or `str`, *optional*):
+            delete_patterns (`list[str]` or `str`, *optional*):
                 If provided, remote files matching any of the patterns will be deleted from the repo.
-            model_card_kwargs (`Dict[str, Any]`, *optional*):
+            model_card_kwargs (`dict[str, Any]`, *optional*):
                 Additional arguments passed to the model card template to customize the model card.
 
         Returns:
@@ -692,6 +691,7 @@
             card_data=self._hub_mixin_info.model_card_data,
             template_str=self._hub_mixin_info.model_card_template,
             repo_url=self._hub_mixin_info.repo_url,
+            paper_url=self._hub_mixin_info.paper_url,
             docs_url=self._hub_mixin_info.docs_url,
             **kwargs,
         )
@@ -718,6 +718,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     ...     PyTorchModelHubMixin,
     ...     library_name="keras-nlp",
     ...     repo_url="https://github.com/keras-team/keras-nlp",
+    ...     paper_url="https://arxiv.org/abs/2304.12244",
     ...     docs_url="https://keras.io/keras_nlp/",
     ...     # ^ optional metadata to generate model card
     ... ):
@@ -743,7 +744,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     ```
     """
 
-    def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
+    def __init_subclass__(cls, *args, tags: Optional[list[str]] = None, **kwargs) -> None:
         tags = tags or []
         tags.append("pytorch_model_hub_mixin")
         kwargs["tags"] = tags
@@ -752,7 +753,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
-        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))
+        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))  # type: ignore [arg-type]
 
     @classmethod
     def _from_pretrained(
@@ -762,8 +763,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
         revision: Optional[str],
         cache_dir: Optional[Union[str, Path]],
         force_download: bool,
-        proxies: Optional[Dict],
-        resume_download: Optional[bool],
         local_files_only: bool,
         token: Union[str, bool, None],
         map_location: str = "cpu",
@@ -784,8 +783,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
                     token=token,
                     local_files_only=local_files_only,
                 )
@@ -797,8 +794,6 @@ class PyTorchModelHubMixin(ModelHubMixin):
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
                     token=token,
                     local_files_only=local_files_only,
                 )
@@ -828,7 +823,7 @@ class PyTorchModelHubMixin(ModelHubMixin):
         return model
 
 
-def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
+def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
     """Load a dataclass instance from a dictionary.
 
     Fields not expected by the dataclass are ignored.