huggingface-hub 0.18.0rc0__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- huggingface_hub/__init__.py +31 -5
- huggingface_hub/_commit_api.py +7 -11
- huggingface_hub/_inference_endpoints.py +348 -0
- huggingface_hub/_login.py +9 -7
- huggingface_hub/_multi_commits.py +1 -1
- huggingface_hub/_snapshot_download.py +6 -7
- huggingface_hub/_space_api.py +7 -4
- huggingface_hub/_tensorboard_logger.py +1 -0
- huggingface_hub/_webhooks_payload.py +7 -7
- huggingface_hub/commands/lfs.py +3 -6
- huggingface_hub/commands/user.py +1 -4
- huggingface_hub/constants.py +27 -0
- huggingface_hub/file_download.py +142 -134
- huggingface_hub/hf_api.py +1058 -503
- huggingface_hub/hf_file_system.py +57 -12
- huggingface_hub/hub_mixin.py +3 -5
- huggingface_hub/inference/_client.py +43 -8
- huggingface_hub/inference/_common.py +8 -16
- huggingface_hub/inference/_generated/_async_client.py +41 -8
- huggingface_hub/inference/_text_generation.py +43 -0
- huggingface_hub/inference_api.py +1 -1
- huggingface_hub/lfs.py +32 -14
- huggingface_hub/repocard_data.py +7 -0
- huggingface_hub/repository.py +19 -3
- huggingface_hub/templates/datasetcard_template.md +83 -43
- huggingface_hub/templates/modelcard_template.md +4 -3
- huggingface_hub/utils/__init__.py +1 -1
- huggingface_hub/utils/_cache_assets.py +3 -3
- huggingface_hub/utils/_cache_manager.py +6 -7
- huggingface_hub/utils/_datetime.py +3 -1
- huggingface_hub/utils/_errors.py +10 -0
- huggingface_hub/utils/_hf_folder.py +4 -2
- huggingface_hub/utils/_http.py +10 -1
- huggingface_hub/utils/_runtime.py +4 -2
- huggingface_hub/utils/endpoint_helpers.py +27 -175
- huggingface_hub/utils/insecure_hashlib.py +34 -0
- huggingface_hub/utils/logging.py +4 -6
- huggingface_hub/utils/sha.py +2 -1
- {huggingface_hub-0.18.0rc0.dist-info → huggingface_hub-0.19.0.dist-info}/METADATA +16 -15
- huggingface_hub-0.19.0.dist-info/RECORD +74 -0
- {huggingface_hub-0.18.0rc0.dist-info → huggingface_hub-0.19.0.dist-info}/WHEEL +1 -1
- huggingface_hub-0.18.0rc0.dist-info/RECORD +0 -72
- {huggingface_hub-0.18.0rc0.dist-info → huggingface_hub-0.19.0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.18.0rc0.dist-info → huggingface_hub-0.19.0.dist-info}/entry_points.txt +0 -0
- {huggingface_hub-0.18.0rc0.dist-info → huggingface_hub-0.19.0.dist-info}/top_level.txt +0 -0
huggingface_hub/hf_api.py
CHANGED
````diff
@@ -16,9 +16,7 @@ from __future__ import annotations
 
 import inspect
 import json
-import pprint
 import re
-import textwrap
 import warnings
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass, field
@@ -69,6 +67,7 @@ from ._commit_api import (
     _upload_lfs_files,
     _warn_on_overwriting_operations,
 )
+from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType
 from ._multi_commits import (
     MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_BAD_REQUEST_TEMPLATE,
     MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_NO_CHANGES_TEMPLATE,
@@ -92,8 +91,10 @@ from .community import (
     deserialize_event,
 )
 from .constants import (
+    DEFAULT_ETAG_TIMEOUT,
     DEFAULT_REVISION,
     ENDPOINT,
+    INFERENCE_ENDPOINTS_ENDPOINT,
     REGEX_COMMIT_OID,
     REPO_TYPE_MODEL,
     REPO_TYPES,
@@ -105,6 +106,7 @@ from .file_download import (
     get_hf_file_metadata,
     hf_hub_url,
 )
+from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
 from .utils import (  # noqa: F401 # imported for backward compatibility
     BadRequestError,
     HfFolder,
@@ -122,12 +124,9 @@ from .utils._deprecation import (
 )
 from .utils._typing import CallableT
 from .utils.endpoint_helpers import (
-    AttributeDictionary,
     DatasetFilter,
-    DatasetTags,
     ModelFilter,
-    ModelTags,
-    _filter_emissions,
+    _is_emission_within_treshold,
 )
 
 
````
````diff
@@ -145,24 +144,6 @@ _CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = (
 logger = logging.get_logger(__name__)
 
 
-class ReprMixin:
-    """Mixin to create the __repr__ for a class"""
-
-    def __init__(self, **kwargs) -> None:
-        # Store all the other fields returned by the API
-        # Hack to ensure backward compatibility with future versions of the API.
-        # See discussion in https://github.com/huggingface/huggingface_hub/pull/951#discussion_r926460408
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    def __repr__(self):
-        formatted_value = pprint.pformat(self.__dict__, width=119, compact=True)
-        if "\n" in formatted_value:
-            return f"{self.__class__.__name__}: {{ \n{textwrap.indent(formatted_value, '  ')}\n}}"
-        else:
-            return f"{self.__class__.__name__}: {formatted_value}"
-
-
 def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]:
     """
     Returns the repo type and ID from a huggingface.co URL linking to a
````
````diff
@@ -254,13 +235,38 @@ class BlobLfsInfo(TypedDict, total=False):
     pointer_size: int
 
 
+class BlobLastCommitInfo(TypedDict, total=False):
+    oid: str
+    title: str
+    date: datetime
+
+
+class BlobSecurityInfo(TypedDict, total=False):
+    safe: bool
+    av_scan: Optional[Dict]
+    pickle_import_scan: Optional[Dict]
+
+
+class TransformersInfo(TypedDict, total=False):
+    auto_model: str
+    custom_class: Optional[str]
+    # possible `pipeline_tag` values: https://github.com/huggingface/hub-docs/blob/f2003d2fca9d4c971629e858e314e0a5c05abf9d/js/src/lib/interfaces/Types.ts#L79
+    pipeline_tag: Optional[str]
+    processor: Optional[str]
+
+
+class SafeTensorsInfo(TypedDict, total=False):
+    parameters: List[Dict[str, int]]
+    total: int
+
+
 @dataclass
 class CommitInfo:
     """Data structure containing information about a newly created commit.
 
     Returned by [`create_commit`].
 
-…
+    Attributes:
         commit_url (`str`):
             Url where to find the commit.
 
````
````diff
@@ -370,251 +376,431 @@ class RepoUrl(str):
         return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')"
 
 
-…
+@dataclass
+class RepoSibling:
     """
-…
+    Contains basic information about a repo file inside a repo on the Hub.
 
-…
+    Attributes:
         rfilename (str):
-            file name, relative to the repo root.
-            certain conditions there can certain other stuff.
+            file name, relative to the repo root.
         size (`int`, *optional*):
-            The file's size, in bytes. This attribute is
+            The file's size, in bytes. This attribute is defined when `files_metadata` argument of [`repo_info`] is set
             to `True`. It's `None` otherwise.
         blob_id (`str`, *optional*):
-            The file's git OID. This attribute is
+            The file's git OID. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to
             `True`. It's `None` otherwise.
         lfs (`BlobLfsInfo`, *optional*):
-            The file's LFS metadata. This attribute is
+            The file's LFS metadata. This attribute is defined when`files_metadata` argument of [`repo_info`] is set to
             `True` and the file is stored with Git LFS. It's `None` otherwise.
     """
 
-…
-        blobId: Optional[str] = None,
-        lfs: Optional[BlobLfsInfo] = None,
-        **kwargs,
-    ):
-        self.rfilename = rfilename  # filename relative to the repo root
+    rfilename: str
+    size: Optional[int] = None
+    blob_id: Optional[str] = None
+    lfs: Optional[BlobLfsInfo] = None
 
-        # Optional file metadata
-        self.size = size
-        self.blob_id = blobId
-        self.lfs = lfs
 
-…
+@dataclass
+class RepoFile:
+    """
+    Contains information about a model on the Hub.
+
+    Attributes:
+        path (str):
+            file path relative to the repo root.
+        size (`int`):
+            The file's size, in bytes.
+        blob_id (`str`):
+            The file's git OID.
+        lfs (`BlobLfsInfo`):
+            The file's LFS metadata.
+        last_commit (`BlobLastCommitInfo`, *optional*):
+            The file's last commit metadata. Only defined if [`list_files_info`] is called with `expand=True`
+        security (`BlobSecurityInfo`, *optional*):
+            The file's security scan metadata. Only defined if [`list_files_info`] is called with `expand=True`.
+    """
+
+    path: str
+    size: int
+    blob_id: str
+    lfs: Optional[BlobLfsInfo] = None
+    last_commit: Optional[BlobLastCommitInfo] = None
+    security: Optional[BlobSecurityInfo] = None
+
+    def __post_init__(self):
+        # backwards compatibility
+        self.rfilename = self.path
+        self.lastCommit = self.last_commit
 
 
-…
+@dataclass
+class ModelInfo:
     """
-…
+    Contains information about a model on the Hub.
 
     Attributes:
-…
-            ID of
+        id (`str`):
+            ID of dataset.
+        author (`str`, *optional*):
+            Author of the dataset.
         sha (`str`, *optional*):
-…
+            Repo SHA at this particular revision.
+        last_modified (`datetime`, *optional*):
+            Date of last commit to the repo.
+        private (`bool`):
+            Is the repo private.
+        disabled (`bool`, *optional*):
+            Is the repo disabled.
+        gated (`bool`, *optional*):
+            Is the repo gated.
+        downloads (`int`):
+            Number of downloads of the dataset.
+        likes (`int`):
+            Number of likes of the dataset.
+        library_name (`str`, *optional*):
+            Library associated with the model.
+        tags (`List[str]`):
+            List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub
+            (e.g. supported libraries, model's arXiv).
         pipeline_tag (`str`, *optional*):
-            Pipeline tag
-…
+            Pipeline tag associated with the model.
+        mask_token (`str`, *optional*):
+            Mask token used by the model.
+        widget_data (`Any`, *optional*):
+            Widget data associated with the model.
+        model_index (`Dict`, *optional*):
+            Model index for evaluation.
         config (`Dict`, *optional*):
-            Model configuration
-…
+            Model configuration.
+        transformers_info (`TransformersInfo`, *optional*):
+            Transformers-specific info (auto class, processor, etc.) associated with the model.
+        card_data (`ModelCardData`, *optional*):
+            Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object.
+        siblings (`List[RepoSibling]`):
+            List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model.
+        spaces (`List[str]`, *optional*):
+            List of spaces using the model.
+        safetensors (`SafeTensorsInfo`, *optional*):
+            Model's safetensors information.
     """
 
-…
-        self.
-        self.
-…
+    id: str
+    author: Optional[str]
+    sha: Optional[str]
+    last_modified: Optional[datetime]
+    private: bool
+    gated: Optional[bool]
+    disabled: Optional[bool]
+    downloads: int
+    likes: int
+    library_name: Optional[str]
+    tags: List[str]
+    pipeline_tag: Optional[str]
+    mask_token: Optional[str]
+    card_data: Optional[ModelCardData]
+    widget_data: Optional[Any]
+    model_index: Optional[Dict]
+    config: Optional[Dict]
+    transformers_info: Optional[TransformersInfo]
+    siblings: Optional[List[RepoSibling]]
+    spaces: Optional[List[str]]
+    safetensors: Optional[SafeTensorsInfo]
+
+    def __init__(self, **kwargs):
+        self.id = kwargs.pop("id")
+        self.author = kwargs.pop("author", None)
+        self.sha = kwargs.pop("sha", None)
+        last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None)
+        self.last_modified = parse_datetime(last_modified) if last_modified else None
+        self.private = kwargs.pop("private")
+        self.gated = kwargs.pop("gated", None)
+        self.disabled = kwargs.pop("disabled", None)
+        self.downloads = kwargs.pop("downloads")
+        self.likes = kwargs.pop("likes")
+        self.library_name = kwargs.pop("library_name", None)
+        self.tags = kwargs.pop("tags")
+        self.pipeline_tag = kwargs.pop("pipeline_tag", None)
+        self.mask_token = kwargs.pop("mask_token", None)
+        card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
+        self.card_data = (
+            ModelCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
+        )
+
+        self.widget_data = kwargs.pop("widget_data", None)
+        self.model_index = kwargs.pop("model-index", None) or kwargs.pop("model_index", None)
+        self.config = kwargs.pop("config", None)
+        transformers_info = kwargs.pop("transformersInfo", None) or kwargs.pop("transformers_info", None)
+        self.transformers_info = TransformersInfo(**transformers_info) if transformers_info else None
+        siblings = kwargs.pop("siblings", None)
+        self.siblings = (
+            [
+                RepoSibling(
+                    rfilename=sibling["rfilename"],
+                    size=sibling.get("size"),
+                    blob_id=sibling.get("blobId"),
+                    lfs=(
+                        BlobLfsInfo(
+                            size=sibling["lfs"]["size"],
+                            sha256=sibling["lfs"]["sha256"],
+                            pointer_size=sibling["lfs"]["pointerSize"],
+                        )
+                        if sibling.get("lfs")
+                        else None
+                    ),
+                )
+                for sibling in siblings
+            ]
+            if siblings
+            else None
+        )
+        self.spaces = kwargs.pop("spaces", None)
+        safetensors = kwargs.pop("safetensors", None)
+        self.safetensors = SafeTensorsInfo(**safetensors) if safetensors else None
+
+        # backwards compatibility
+        self.lastModified = self.last_modified
+        self.cardData = self.card_data
+        self.transformersInfo = self.transformers_info
+        self.__dict__.update(**kwargs)
+
+
+@dataclass
+class DatasetInfo:
     """
-…
+    Contains information about a dataset on the Hub.
 
     Attributes:
-        id (`str
-            ID of dataset
-…
-        repo
-…
+        id (`str`):
+            ID of dataset.
+        author (`str`):
+            Author of the dataset.
+        sha (`str`):
+            Repo SHA at this particular revision.
+        last_modified (`datetime`, *optional*):
+            Date of last commit to the repo.
+        private (`bool`):
+            Is the repo private.
+        disabled (`bool`, *optional*):
+            Is the repo disabled.
+        gated (`bool`, *optional*):
+            Is the repo gated.
+        downloads (`int`):
+            Number of downloads of the dataset.
+        likes (`int`):
+            Number of likes of the dataset.
+        tags (`List[str]`):
+            List of tags of the dataset.
+        card_data (`DatasetCardData`, *optional*):
+            Model Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object.
+        siblings (`List[RepoSibling]`):
+            List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset.
     """
 
-…
-    ):
-        self.id = id
-        self.
-        self.
-…
-        kwargs.pop("
-…
+    id: str
+    author: Optional[str]
+    sha: Optional[str]
+    last_modified: Optional[datetime]
+    private: bool
+    gated: Optional[bool]
+    disabled: Optional[bool]
+    downloads: int
+    likes: int
+    paperswithcode_id: Optional[str]
+    tags: List[str]
+    card_data: Optional[DatasetCardData]
+    siblings: Optional[List[RepoSibling]]
+
+    def __init__(self, **kwargs):
+        self.id = kwargs.pop("id")
+        self.author = kwargs.pop("author", None)
+        self.sha = kwargs.pop("sha", None)
+        last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None)
+        self.last_modified = parse_datetime(last_modified) if last_modified else None
+        self.private = kwargs.pop("private")
+        self.gated = kwargs.pop("gated", None)
+        self.disabled = kwargs.pop("disabled", None)
+        self.downloads = kwargs.pop("downloads")
+        self.likes = kwargs.pop("likes")
+        self.paperswithcode_id = kwargs.pop("paperswithcode_id", None)
+        self.tags = kwargs.pop("tags")
+        card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
+        self.card_data = (
+            DatasetCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
+        )
+        siblings = kwargs.pop("siblings", None)
+        self.siblings = (
+            [
+                RepoSibling(
+                    rfilename=sibling["rfilename"],
+                    size=sibling.get("size"),
+                    blob_id=sibling.get("blobId"),
+                    lfs=(
+                        BlobLfsInfo(
+                            size=sibling["lfs"]["size"],
+                            sha256=sibling["lfs"]["sha256"],
+                            pointer_size=sibling["lfs"]["pointerSize"],
+                        )
+                        if sibling.get("lfs")
+                        else None
+                    ),
+                )
+                for sibling in siblings
+            ]
+            if siblings
+            else None
+        )
+
+        # backwards compatibility
+        self.lastModified = self.last_modified
+        self.cardData = self.card_data
+        self.__dict__.update(**kwargs)
 
-…
+
+@dataclass
+class SpaceInfo:
+    """
+    Contains information about a Space on the Hub.
 
     Attributes:
-        id (`str
-…
-        sha (`str`, *optional*):
-            repo sha at this particular revision
-        lastModified (`str`, *optional*):
-            date of last commit to repo
-        siblings (`List[RepoFile]`, *optional*):
-            list of [`huggingface_hub.hf_api.RepoFIle`] objects that constitute the Space
-        private (`bool`, *optional*, defaults to `False`):
-            is the repo private
+        id (`str`):
+            ID of the Space.
         author (`str`, *optional*):
-…
+            Author of the Space.
+        sha (`str`, *optional*):
+            Repo SHA at this particular revision.
+        last_modified (`datetime`, *optional*):
+            Date of last commit to the repo.
+        private (`bool`):
+            Is the repo private.
+        gated (`bool`, *optional*):
+            Is the repo gated.
+        disabled (`bool`, *optional*):
+            Is the Space disabled.
+        host (`str`, *optional*):
+            Host URL of the Space.
+        subdomain (`str`, *optional*):
+            Subdomain of the Space.
+        likes (`int`):
+            Number of likes of the Space.
+        tags (`List[str]`):
+            List of tags of the Space.
+        siblings (`List[RepoSibling]`):
+            List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space.
+        card_data (`SpaceCardData`, *optional*):
+            Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object.
+        runtime (`SpaceRuntime`, *optional*):
+            Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object.
+        sdk (`str`, *optional*):
+            SDK used by the Space.
+        models (`List[str]`, *optional*):
+            List of models used by the Space.
+        datasets (`List[str]`, *optional*):
+            List of datasets used by the Space.
     """
 
-…
+    id: str
+    author: Optional[str]
+    sha: Optional[str]
+    last_modified: Optional[datetime]
+    private: bool
+    gated: Optional[bool]
+    disabled: Optional[bool]
+    host: Optional[str]
+    subdomain: Optional[str]
+    likes: int
+    sdk: Optional[str]
+    tags: List[str]
+    siblings: Optional[List[RepoSibling]]
+    card_data: Optional[SpaceCardData]
+    runtime: Optional[SpaceRuntime]
+    models: Optional[List[str]]
+    datasets: Optional[List[str]]
+
+    def __init__(self, **kwargs):
+        self.id = kwargs.pop("id")
+        self.author = kwargs.pop("author", None)
+        self.sha = kwargs.pop("sha", None)
+        last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None)
+        self.last_modified = parse_datetime(last_modified) if last_modified else None
+        self.private = kwargs.pop("private")
+        self.gated = kwargs.pop("gated", None)
+        self.disabled = kwargs.pop("disabled", None)
+        self.host = kwargs.pop("host", None)
+        self.subdomain = kwargs.pop("subdomain", None)
+        self.likes = kwargs.pop("likes")
+        self.sdk = kwargs.pop("sdk", None)
+        self.tags = kwargs.pop("tags")
+        card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
+        self.card_data = (
+            SpaceCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
+        )
+        siblings = kwargs.pop("siblings", None)
+        self.siblings = (
+            [
+                RepoSibling(
+                    rfilename=sibling["rfilename"],
+                    size=sibling.get("size"),
+                    blob_id=sibling.get("blobId"),
+                    lfs=(
+                        BlobLfsInfo(
+                            size=sibling["lfs"]["size"],
+                            sha256=sibling["lfs"]["sha256"],
+                            pointer_size=sibling["lfs"]["pointerSize"],
+                        )
+                        if sibling.get("lfs")
+                        else None
+                    ),
+                )
+                for sibling in siblings
+            ]
+            if siblings
+            else None
+        )
+        runtime = kwargs.pop("runtime", None)
+        self.runtime = SpaceRuntime(runtime) if runtime else None
+        self.models = kwargs.pop("models", None)
+        self.datasets = kwargs.pop("datasets", None)
 
+        # backwards compatibility
+        self.lastModified = self.last_modified
+        self.cardData = self.card_data
+        self.__dict__.update(**kwargs)
 
-…
+
+@dataclass
+class MetricInfo:
     """
-…
+    Contains information about a metric on the Hub.
+
+    Attributes:
+        id (`str`):
+            ID of the metric. E.g. `"accuracy"`.
+        space_id (`str`):
+            ID of the space associated with the metric. E.g. `"Accuracy"`.
+        description (`str`):
+            Description of the metric.
     """
 
-…
-        id: Optional[str] = None,  # id of metric
-        description: Optional[str] = None,
-        citation: Optional[str] = None,
-        **kwargs,
-    ):
-        self.id = id
-        self.description = description
-        self.citation = citation
-        # Legacy stuff, "key" is always returned with an empty string
-        # because of old versions of the datasets lib that need this field
-        kwargs.pop("key", None)
-        # Store all the other fields returned by the API
-        super().__init__(**kwargs)
+    id: str
+    space_id: str
+    description: Optional[str]
 
-    def
-…
+    def __init__(self, **kwargs):
+        self.id = kwargs.pop("id")
+        self.space_id = kwargs.pop("spaceId")
+        self.description = kwargs.pop("description", None)
+        # backwards compatibility
+        self.spaceId = self.space_id
+        self.__dict__.update(**kwargs)
 
 
-…
+@dataclass
+class CollectionItem:
+    """
+    Contains information about an item of a Collection (model, dataset, Space or paper).
 
-…
+    Attributes:
         item_object_id (`str`):
             Unique ID of the item in the collection.
         item_id (`str`):
````
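After this dataclass rewrite, `ModelInfo`, `DatasetInfo` and `SpaceInfo` expose snake_case attributes as the canonical names, keep the old camelCase names as aliases, and parse nested payloads (card data, siblings) into typed objects. A minimal sketch of reading the new fields on 0.19 — the repo id and printed fields are illustrative, not part of this diff:

```python
from huggingface_hub import HfApi

api = HfApi()
info = api.model_info("gpt2")  # now returns the dataclass-style ModelInfo

# snake_case attributes are the canonical names in 0.19...
print(info.last_modified, info.pipeline_tag)
# ...while the pre-0.19 camelCase aliases still resolve for existing callers.
assert info.lastModified == info.last_modified

# card_data is parsed into a ModelCardData object when the server returns it.
if info.card_data is not None:
    print(info.card_data.license)
```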
````diff
@@ -626,11 +812,14 @@ class CollectionItem(ReprMixin):
             Position of the item in the collection.
         note (`str`, *optional*):
             Note associated with the item, as plain text.
-        kwargs (`Dict`, *optional*):
-            Any other attribute returned by the server. Those attributes depend on the `item_type`: "author", "private",
-            "lastModified", "gated", "title", "likes", "upvotes", etc.
     """
 
+    item_object_id: str  # id in database
+    item_id: str  # repo_id or paper id
+    item_type: str
+    position: int
+    note: Optional[str] = None
+
     def __init__(
         self, _id: str, id: str, type: CollectionItemType_T, position: int, note: Optional[Dict] = None, **kwargs
     ) -> None:
@@ -640,23 +829,19 @@ class CollectionItem(ReprMixin):
         self.position: int = position
         self.note: str = note["text"] if note is not None else None
 
-        # Store all the other fields returned by the API
-        super().__init__(**kwargs)
-
 
-…
+@dataclass
+class Collection:
     """
     Contains information about a Collection on the Hub.
 
-…
+    Attributes:
         slug (`str`):
             Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
         title (`str`):
             Title of the collection. E.g. `"Recent models"`.
         owner (`str`):
             Owner of the collection. E.g. `"TheBloke"`.
-        description (`str`, *optional*):
-            Description of the collection, as plain text.
         items (`List[CollectionItem]`):
             List of items in the collection.
         last_updated (`datetime`):
@@ -667,139 +852,45 @@ class Collection(ReprMixin):
             Whether the collection is private or not.
         theme (`str`):
             Theme of the collection. E.g. `"green"`.
+        upvotes (`int`):
+            Number of upvotes of the collection.
+        description (`str`, *optional*):
+            Description of the collection, as plain text.
         url (`str`):
-            URL
+            (property) URL of the collection on the Hub.
     """
 
     slug: str
     title: str
     owner: str
-    description: Optional[str]
     items: List[CollectionItem]
-
     last_updated: datetime
     position: int
    private: bool
     theme: str
+    upvotes: int
+    description: Optional[str] = None
 
-    def __init__(self,
-…
-        self.theme = data["theme"]
-
-        # (internal)
+    def __init__(self, **kwargs) -> None:
+        self.slug = kwargs.pop("slug")
+        self.title = kwargs.pop("title")
+        self.owner = kwargs.pop("owner")
+        self.items = [CollectionItem(**item) for item in kwargs.pop("items")]
+        self.last_updated = parse_datetime(kwargs.pop("lastUpdated"))
+        self.position = kwargs.pop("position")
+        self.private = kwargs.pop("private")
+        self.theme = kwargs.pop("theme")
+        self.upvotes = kwargs.pop("upvotes")
+        self.description = kwargs.pop("description", None)
+        endpoint = kwargs.pop("endpoint", None)
         if endpoint is None:
             endpoint = ENDPOINT
-        self.
-…
-
-class ModelSearchArguments(AttributeDictionary):
-    """
-    A nested namespace object holding all possible values for properties of
-    models currently hosted in the Hub with tab-completion. If a value starts
-    with a number, it will only exist in the dictionary
-
-    Example:
-
-    ```python
-    >>> args = ModelSearchArguments()
-
-    >>> args.author.huggingface
-    'huggingface'
-
-    >>> args.language.en
-    'en'
-    ```
-
-    <Tip warning={true}>
-
-    `ModelSearchArguments` is a legacy class meant for exploratory purposes only. Its
-    initialization requires listing all models on the Hub which makes it increasingly
-    slower as the number of repos on the Hub increases.
-
-    </Tip>
-    """
-
-    def __init__(self, api: Optional["HfApi"] = None):
-        self._api = api if api is not None else HfApi()
-        tags = self._api.get_model_tags()
-        super().__init__(tags)
-        self._process_models()
-
-    def _process_models(self):
-        def clean(s: str) -> str:
-            return s.replace(" ", "").replace("-", "_").replace(".", "_")
-
-        models = self._api.list_models()
-        author_dict, model_name_dict = AttributeDictionary(), AttributeDictionary()
-        for model in models:
-            if "/" in model.modelId:
-                author, name = model.modelId.split("/")
-                author_dict[author] = clean(author)
-            else:
-                name = model.modelId
-            model_name_dict[name] = clean(name)
-        self["model_name"] = model_name_dict
-        self["author"] = author_dict
-
-
-class DatasetSearchArguments(AttributeDictionary):
-    """
-    A nested namespace object holding all possible values for properties of
-    datasets currently hosted in the Hub with tab-completion. If a value starts
-    with a number, it will only exist in the dictionary
-
-    Example:
-
-    ```python
-    >>> args = DatasetSearchArguments()
-
-    >>> args.author.huggingface
-    'huggingface'
-
-    >>> args.language.en
-    'language:en'
-    ```
-
-    <Tip warning={true}>
-
-    `DatasetSearchArguments` is a legacy class meant for exploratory purposes only. Its
-    initialization requires listing all datasets on the Hub which makes it increasingly
-    slower as the number of repos on the Hub increases.
-
-    </Tip>
-    """
+        self._url = f"{endpoint}/collections/{self.slug}"
 
-…
-        self._process_models()
-
-    def _process_models(self):
-        def clean(s: str):
-            return s.replace(" ", "").replace("-", "_").replace(".", "_")
-
-        datasets = self._api.list_datasets()
-        author_dict, dataset_name_dict = AttributeDictionary(), AttributeDictionary()
-        for dataset in datasets:
-            if "/" in dataset.id:
-                author, name = dataset.id.split("/")
-                author_dict[author] = clean(author)
-            else:
-                name = dataset.id
-            dataset_name_dict[name] = clean(name)
-        self["dataset_name"] = dataset_name_dict
-        self["author"] = author_dict
+    @property
+    def url(self) -> str:
+        """Returns the URL of the collection on the Hub."""
+        return self._url
 
 
 @dataclass
````
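Since `Collection.url` is now a property computed from the endpoint and slug instead of a stored field, reading it goes through the same attribute as before. A sketch using `get_collection` — the slug is the illustrative one from the docstring above:

```python
from huggingface_hub import HfApi

api = HfApi()
collection = api.get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026")
print(collection.url)      # built as f"{endpoint}/collections/{slug}"
print(collection.upvotes)  # upvote count, newly documented in 0.19
```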
````diff
@@ -807,7 +898,7 @@ class GitRefInfo:
     """
     Contains information about a git reference for a repo on the Hub.
 
-…
+    Attributes:
         name (`str`):
             Name of the reference (e.g. tag name or branch name).
         ref (`str`):
@@ -820,11 +911,6 @@ class GitRefInfo:
     ref: str
     target_commit: str
 
-    def __init__(self, data: Dict) -> None:
-        self.name = data["name"]
-        self.ref = data["ref"]
-        self.target_commit = data["targetCommit"]
-
 
 @dataclass
 class GitRefs:
@@ -833,7 +919,7 @@ class GitRefs:
 
     Object is returned by [`list_repo_refs`].
 
-…
+    Attributes:
         branches (`List[GitRefInfo]`):
             A list of [`GitRefInfo`] containing information about branches on the repo.
         converts (`List[GitRefInfo]`):
@@ -853,7 +939,7 @@ class GitCommitInfo:
     """
     Contains information about a git commit for a repo on the Hub. Check out [`list_repo_commits`] for more details.
 
-…
+    Attributes:
         commit_id (`str`):
             OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`)
         authors (`List[str]`):
@@ -880,23 +966,13 @@ class GitCommitInfo:
     formatted_title: Optional[str]
     formatted_message: Optional[str]
 
-    def __init__(self, data: Dict) -> None:
-        self.commit_id = data["id"]
-        self.authors = [author["user"] for author in data["authors"]]
-        self.created_at = parse_datetime(data["date"])
-        self.title = data["title"]
-        self.message = data["message"]
-
-        self.formatted_title = data.get("formatted", {}).get("title")
-        self.formatted_message = data.get("formatted", {}).get("message")
-
 
 @dataclass
 class UserLikes:
     """
     Contains information about a user likes on the Hub.
 
-…
+    Attributes:
         user (`str`):
             Name of the user for which we fetched the likes.
         total (`int`):
@@ -924,7 +1000,7 @@ class User:
     """
     Contains information about a user on the Hub.
 
-…
+    Attributes:
         avatar_url (`str`):
             URL of the user's avatar.
         username (`str`):
@@ -989,9 +1065,6 @@ class HfApi:
         directly at the root of `huggingface_hub`.
 
     Args:
-        endpoint (`str`, *optional*):
-            Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise,
-            one can set the `HF_ENDPOINT` environment variable.
         token (`str`, *optional*):
             Hugging Face token. Will default to the locally saved token if
             not provided.
@@ -1101,25 +1174,23 @@ class HfApi:
         except (LocalTokenNotFoundError, HTTPError):
             return None
 
-    def get_model_tags(self) ->
+    def get_model_tags(self) -> Dict:
         """
         List all valid model tags as a nested namespace object
         """
         path = f"{self.endpoint}/api/models-tags-by-type"
         r = get_session().get(path)
         hf_raise_for_status(r)
-        d = r.json()
-        return ModelTags(d)
+        return r.json()
 
-    def get_dataset_tags(self) ->
+    def get_dataset_tags(self) -> Dict:
         """
         List all valid dataset tags as a nested namespace object.
         """
         path = f"{self.endpoint}/api/datasets-tags-by-type"
         r = get_session().get(path)
         hf_raise_for_status(r)
-        d = r.json()
-        return DatasetTags(d)
+        return r.json()
 
     @validate_hf_hub_args
     def list_models(
````
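With the `ModelTags`/`DatasetTags` wrappers removed, both getters now return the parsed JSON payload directly. A small sketch of the 0.19 behaviour — the exact tag-group keys come from the server response and are not guaranteed by this diff:

```python
from huggingface_hub import HfApi

api = HfApi()
tags = api.get_model_tags()  # plain dict in 0.19, no attribute-style wrapper
print(sorted(tags.keys()))   # tag groups as returned by /api/models-tags-by-type
```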
````diff
@@ -1129,7 +1200,7 @@ class HfApi:
         author: Optional[str] = None,
         search: Optional[str] = None,
         emissions_thresholds: Optional[Tuple[float, float]] = None,
-        sort: Union[Literal["
+        sort: Union[Literal["last_modified"], str, None] = None,
         direction: Optional[Literal[-1]] = None,
         limit: Optional[int] = None,
         full: Optional[bool] = None,
@@ -1152,7 +1223,7 @@ class HfApi:
             emissions_thresholds (`Tuple`, *optional*):
                 A tuple of two ints or floats representing a minimum and maximum
                 carbon footprint to filter the resulting models with in grams.
-            sort (`Literal["
+            sort (`Literal["last_modified"]` or `str`, *optional*):
                 The key with which to sort the resulting models. Possible values
                 are the properties of the [`huggingface_hub.hf_api.ModelInfo`] class.
             direction (`Literal[-1]` or `int`, *optional*):
@@ -1162,7 +1233,7 @@ class HfApi:
                 The limit on the number of models fetched. Leaving this option
                 to `None` fetches all models.
             full (`bool`, *optional*):
-                Whether to fetch all model data, including the `
+                Whether to fetch all model data, including the `last_modified`,
                 the `sha`, the files and the `tags`. This is set to `True` by
                 default when using a filter.
             cardData (`bool`, *optional*):
@@ -1191,29 +1262,15 @@ class HfApi:
         >>> # List all models
         >>> api.list_models()
 
-        >>> # Get all valid search arguments
-        >>> args = ModelSearchArguments()
-
         >>> # List only the text classification models
         >>> api.list_models(filter="text-classification")
         >>> # Using the `ModelFilter`
         >>> filt = ModelFilter(task="text-classification")
-
-        >>> filt = ModelFilter(task=args.pipeline_tags.TextClassification)
-        >>> api.list_models(filter=filt)
-
-        >>> # Using `ModelFilter` and `ModelSearchArguments` to find text classification in both PyTorch and TensorFlow
-        >>> filt = ModelFilter(
-        ...     task=args.pipeline_tags.TextClassification,
-        ...     library=[args.library.PyTorch, args.library.TensorFlow],
-        ... )
-        >>> api.list_models(filter=filt)
+
 
         >>> # List only models from the AllenNLP library
         >>> api.list_models(filter="allennlp")
-
-        >>> filt = ModelFilter(library=args.library.allennlp)
-        ```
+
 
         Example usage with the `search` argument:
 
@@ -1246,7 +1303,7 @@ class HfApi:
         if search is not None:
             params.update({"search": search})
         if sort is not None:
-            params.update({"sort": sort})
+            params.update({"sort": "lastModified" if sort == "last_modified" else sort})
         if direction is not None:
             params.update({"direction": direction})
         if limit is not None:
@@ -1265,10 +1322,12 @@ class HfApi:
         items = paginate(path, params=params, headers=headers)
         if limit is not None:
             items = islice(items, limit)  # Do not iterate over all pages
-        if emissions_thresholds is not None:
-            items = _filter_emissions(items, *emissions_thresholds)
         for item in items:
-            yield ModelInfo(**item)
+            if "siblings" not in item:
+                item["siblings"] = None
+            model_info = ModelInfo(**item)
+            if emissions_thresholds is None or _is_emission_within_treshold(model_info, *emissions_thresholds):
+                yield model_info
 
     def _unpack_model_filter(self, model_filter: ModelFilter):
         """
````
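Note how the snake_case `sort="last_modified"` value is mapped to the Hub's `lastModified` query parameter at request time; the same mapping appears in `list_datasets` and `list_spaces` below. A quick sketch of a call that exercises it:

```python
from huggingface_hub import HfApi

api = HfApi()
# "last_modified" is translated to the server-side "lastModified" sort key.
for model in api.list_models(sort="last_modified", direction=-1, limit=5):
    print(model.id, model.last_modified)
```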
````diff
@@ -1326,7 +1385,7 @@ class HfApi:
         filter: Union[DatasetFilter, str, Iterable[str], None] = None,
         author: Optional[str] = None,
         search: Optional[str] = None,
-        sort: Union[Literal["
+        sort: Union[Literal["last_modified"], str, None] = None,
         direction: Optional[Literal[-1]] = None,
         limit: Optional[int] = None,
         full: Optional[bool] = None,
@@ -1343,7 +1402,7 @@ class HfApi:
                 A string which identify the author of the returned datasets.
             search (`str`, *optional*):
                 A string that will be contained in the returned datasets.
-            sort (`Literal["
+            sort (`Literal["last_modified"]` or `str`, *optional*):
                 The key with which to sort the resulting datasets. Possible
                 values are the properties of the [`huggingface_hub.hf_api.DatasetInfo`] class.
             direction (`Literal[-1]` or `int`, *optional*):
@@ -1353,8 +1412,8 @@ class HfApi:
                 The limit on the number of datasets fetched. Leaving this option
                 to `None` fetches all datasets.
             full (`bool`, *optional*):
-                Whether to fetch all dataset data, including the `
-
+                Whether to fetch all dataset data, including the `last_modified`,
+                the `card_data` and the files. Can contain useful information such as the
                 PapersWithCode ID.
             token (`bool` or `str`, *optional*):
                 A valid authentication token (see https://huggingface.co/settings/token).
@@ -1375,16 +1434,12 @@ class HfApi:
         >>> # List all datasets
         >>> api.list_datasets()
 
-        >>> # Get all valid search arguments
-        >>> args = DatasetSearchArguments()
 
         >>> # List only the text classification datasets
         >>> api.list_datasets(filter="task_categories:text-classification")
         >>> # Using the `DatasetFilter`
         >>> filt = DatasetFilter(task_categories="text-classification")
-
-        >>> filt = DatasetFilter(task=args.task_categories.text_classification)
-        >>> api.list_models(filter=filt)
+
 
         >>> # List only the datasets in russian for language modeling
         >>> api.list_datasets(
@@ -1392,11 +1447,7 @@ class HfApi:
         ... )
         >>> # Using the `DatasetFilter`
         >>> filt = DatasetFilter(language="ru", task_ids="language-modeling")
-
-        >>> filt = DatasetFilter(
-        ...     language=args.language.ru,
-        ...     task_ids=args.task_ids.language_modeling,
-        ... )
+
         >>> api.list_datasets(filter=filt)
         ```
 
@@ -1427,7 +1478,7 @@ class HfApi:
         if search is not None:
             params.update({"search": search})
         if sort is not None:
-            params.update({"sort": sort})
+            params.update({"sort": "lastModified" if sort == "last_modified" else sort})
         if direction is not None:
             params.update({"direction": direction})
         if limit is not None:
@@ -1439,6 +1490,8 @@ class HfApi:
         if limit is not None:
             items = islice(items, limit)  # Do not iterate over all pages
         for item in items:
+            if "siblings" not in item:
+                item["siblings"] = None
             yield DatasetInfo(**item)
 
     def _unpack_dataset_filter(self, dataset_filter: DatasetFilter):
@@ -1502,7 +1555,7 @@ class HfApi:
         filter: Union[str, Iterable[str], None] = None,
         author: Optional[str] = None,
         search: Optional[str] = None,
-        sort: Union[Literal["
+        sort: Union[Literal["last_modified"], str, None] = None,
         direction: Optional[Literal[-1]] = None,
         limit: Optional[int] = None,
         datasets: Union[str, Iterable[str], None] = None,
@@ -1521,7 +1574,7 @@ class HfApi:
                 A string which identify the author of the returned Spaces.
             search (`str`, *optional*):
                 A string that will be contained in the returned Spaces.
-            sort (`Literal["
+            sort (`Literal["last_modified"]` or `str`, *optional*):
                 The key with which to sort the resulting Spaces. Possible
                 values are the properties of the [`huggingface_hub.hf_api.SpaceInfo`]` class.
             direction (`Literal[-1]` or `int`, *optional*):
@@ -1539,8 +1592,8 @@ class HfApi:
             linked (`bool`, *optional*):
                 Whether to return Spaces that make use of either a model or a dataset.
             full (`bool`, *optional*):
-                Whether to fetch all Spaces data, including the `
-                and
+                Whether to fetch all Spaces data, including the `last_modified`, `siblings`
+                and `card_data` fields.
             token (`bool` or `str`, *optional*):
                 A valid authentication token (see https://huggingface.co/settings/token).
                 If `None` or `True` and machine is logged in (through `huggingface-cli login`
@@ -1560,7 +1613,7 @@ class HfApi:
         if search is not None:
             params.update({"search": search})
         if sort is not None:
-            params.update({"sort": sort})
+            params.update({"sort": "lastModified" if sort == "last_modified" else sort})
         if direction is not None:
             params.update({"direction": direction})
         if limit is not None:
@@ -1578,6 +1631,8 @@ class HfApi:
         if limit is not None:
             items = islice(items, limit)  # Do not iterate over all pages
         for item in items:
+            if "siblings" not in item:
+                item["siblings"] = None
             yield SpaceInfo(**item)
 
     @validate_hf_hub_args
@@ -1865,8 +1920,8 @@ class HfApi:
             params["blobs"] = True
         r = get_session().get(path, headers=headers, timeout=timeout, params=params)
         hf_raise_for_status(r)
-        d = r.json()
-        return ModelInfo(**d)
+        data = r.json()
+        return ModelInfo(**data)
 
     @validate_hf_hub_args
     def dataset_info(
@@ -1928,8 +1983,8 @@ class HfApi:
 
         r = get_session().get(path, headers=headers, timeout=timeout, params=params)
         hf_raise_for_status(r)
-        d = r.json()
-        return DatasetInfo(**d)
+        data = r.json()
+        return DatasetInfo(**data)
 
     @validate_hf_hub_args
     def space_info(
@@ -1991,8 +2046,8 @@ class HfApi:
 
         r = get_session().get(path, headers=headers, timeout=timeout, params=params)
         hf_raise_for_status(r)
-        d = r.json()
-        return SpaceInfo(**d)
+        data = r.json()
+        return SpaceInfo(**data)
 
     @validate_hf_hub_args
     def repo_info(
@@ -2231,8 +2286,8 @@ class HfApi:
         <generator object HfApi.list_files_info at 0x7f93b848e730>
         >>> list(files_info)
         [
-            RepoFile
-            RepoFile
+            RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None),
+            RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None)
         ]
         ```
 
@@ -2242,44 +2297,56 @@ class HfApi:
         >>> files_info = list_files_info("prompthero/openjourney-v4", expand=True)
         >>> list(files_info)
         [
-            RepoFile
-…
+            RepoFile(
+                path='safety_checker/pytorch_model.bin',
+                size=1216064769,
+                blob_id='c8835557a0d3af583cb06c7c154b7e54a069c41d',
+                lfs={
+                    'size': 1216064769,
+                    'sha256': '16d28f2b37109f222cdc33620fdd262102ac32112be0352a7f77e9614b35a394',
+                    'pointer_size': 135
+                },
+                last_commit={
+                    'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190',
+                    'title': 'Upload diffusers weights (#48)',
+                    'date': datetime.datetime(2023, 3, 21, 10, 5, 27, tzinfo=datetime.timezone.utc)
+                },
+                security={
+                    'safe': True,
+                    'av_scan': {
+                        'virusFound': False,
+                        'virusNames': None
+                    },
+                    'pickle_import_scan': {
+                        'highestSafetyLevel': 'innocuous',
+                        'imports': [
+                            {'module': 'torch', 'name': 'FloatStorage', 'safety': 'innocuous'},
+                            {'module': 'collections', 'name': 'OrderedDict', 'safety': 'innocuous'},
+                            {'module': 'torch', 'name': 'LongStorage', 'safety': 'innocuous'},
+                            {'module': 'torch._utils', 'name': '_rebuild_tensor_v2', 'safety': 'innocuous'}
+                        ]
+                    }
+                }
+            ),
+            RepoFile(
+                path='scheduler/scheduler_config.json',
+                size=465,
+                blob_id='70d55e3e9679f41cbc66222831b38d5abce683dd',
+                lfs=None,
+                last_commit={
+                    'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190',
+                    'title': 'Upload diffusers weights (#48)',
+                    'date': datetime.datetime(2023, 3, 21, 10, 5, 27, tzinfo=datetime.timezone.utc)},
+                security={
+                    'safe': True,
+                    'av_scan': {
+                        'virusFound': False,
+                        'virusNames': None
+                    },
+                    'pickle_import_scan': None
+                }
+            ),
+            ...
         ]
         ```
 
@@ -2287,14 +2354,14 @@ class HfApi:
 
         ```py
         >>> from huggingface_hub import list_files_info
-        >>> [info.
+        >>> [info.path for info in list_files_info("stabilityai/stable-diffusion-2", "vae") if info.lfs is not None]
         ['vae/diffusion_pytorch_model.bin', 'vae/diffusion_pytorch_model.safetensors']
         ```
 
         List all files on a repo.
         ```py
         >>> from huggingface_hub import list_files_info
-        >>> [info.
+        >>> [info.path for info in list_files_info("glue", repo_type="dataset")]
         ['.gitattributes', 'README.md', 'dataset_infos.json', 'glue.py']
         ```
         """
````
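Beyond the docstring examples above, `expand=True` now surfaces the extra server data as the typed `last_commit` and `security` fields on each `RepoFile`. A sketch of inspecting them — the repo id is the one from the docstring, and both fields stay `None` without `expand=True`:

```python
from huggingface_hub import list_files_info

for info in list_files_info("prompthero/openjourney-v4", expand=True):
    # last_commit/security are TypedDicts populated only when expand=True
    if info.last_commit is not None:
        print(info.path, info.last_commit["oid"], info.last_commit["date"])
    if info.security is not None and not info.security["safe"]:
        print(f"flagged by security scan: {info.path}")
```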
@@ -2305,14 +2372,24 @@ class HfApi:
         def _format_as_repo_file(info: Dict) -> RepoFile:
             # Quick alias very specific to the server return type of /paths-info and /tree endpoints. Let's keep this
             # logic here.
-
+            path = info.pop("path")
             size = info.pop("size")
-
+            blob_id = info.pop("oid")
             lfs = info.pop("lfs", None)
+            last_commit = info.pop("lastCommit", None)
+            security = info.pop("security", None)
             info.pop("type", None)  # "file" or "folder" -> not needed in practice since we know it's a file
+            if last_commit is not None:
+                last_commit = BlobLastCommitInfo(
+                    oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"])
+                )
+            if security is not None:
+                security = BlobSecurityInfo(
+                    safe=security["safe"], av_scan=security["avScan"], pickle_import_scan=security["pickleImportScan"]
+                )
             if lfs is not None:
                 lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"])
-            return RepoFile(
+            return RepoFile(path=path, size=size, blob_id=blob_id, lfs=lfs, last_commit=last_commit, security=security)

         folder_paths = []
         if paths is None:
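The helper above renames the server's camelCase keys into the snake_case attributes of `RepoFile`. A standalone sketch of that mapping, using a hypothetical `/paths-info` entry shaped like the hunk above (illustrative values only, not part of the library):

```python
# Hypothetical /paths-info entry; keys mirror the ones popped in _format_as_repo_file.
raw = {
    "path": "scheduler/scheduler_config.json",
    "size": 465,
    "oid": "70d55e3e9679f41cbc66222831b38d5abce683dd",
    "type": "file",
    "lfs": None,
    "lastCommit": {
        "id": "47b62b20b20e06b9de610e840282b7e6c3d51190",
        "title": "Upload diffusers weights (#48)",
        "date": "2023-03-21T10:05:27.000Z",
    },
    "security": {"safe": True, "avScan": {"virusFound": False, "virusNames": None}, "pickleImportScan": None},
}

# camelCase server fields map onto snake_case RepoFile attributes:
#   "path" -> path, "oid" -> blob_id, "lastCommit" -> last_commit, "security" -> security
path, blob_id = raw.pop("path"), raw.pop("oid")
last_commit = raw.pop("lastCommit", None)
security = raw.pop("security", None)
print(path, blob_id, last_commit["title"], security["safe"])
```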
@@ -2327,7 +2404,7 @@ class HfApi:
             f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}",
             data={
                 "paths": paths if isinstance(paths, list) else [paths],
-                "expand":
+                "expand": expand,
             },
             headers=headers,
         )
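Since the `expand` argument is now forwarded to the server, callers can opt into the richer metadata shown in the docstring above. A hedged usage sketch (assumes network access and a public repo; output shapes are illustrative):

```python
from huggingface_hub import HfApi

api = HfApi()
# With expand=True, the returned RepoFile objects also carry last_commit and security info.
files = list(api.list_files_info("stabilityai/stable-diffusion-2", expand=True))
first = files[0]
print(first.path, first.size, first.blob_id)
print(first.last_commit, first.security)  # populated only when expand=True
```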
@@ -2440,10 +2517,14 @@ class HfApi:
         )
         hf_raise_for_status(response)
         data = response.json()
+
+        def _format_as_git_ref_info(item: Dict) -> GitRefInfo:
+            return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"])
+
         return GitRefs(
-            branches=[
-            converts=[
-            tags=[
+            branches=[_format_as_git_ref_info(item) for item in data["branches"]],
+            converts=[_format_as_git_ref_info(item) for item in data["converts"]],
+            tags=[_format_as_git_ref_info(item) for item in data["tags"]],
         )

     @validate_hf_hub_args
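The extracted `_format_as_git_ref_info` helper does not change the public behavior of `list_repo_refs`; for reference, a typical call looks like this:

```python
from huggingface_hub import HfApi

api = HfApi()
refs = api.list_repo_refs("gpt2")
for branch in refs.branches:
    # Each GitRefInfo exposes the ref name, full ref path, and target commit hash.
    print(branch.name, branch.ref, branch.target_commit)
```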
@@ -2516,7 +2597,15 @@ class HfApi:

         # Paginate over results and return the list of commits.
         return [
-            GitCommitInfo(
+            GitCommitInfo(
+                commit_id=item["id"],
+                authors=[author["user"] for author in item["authors"]],
+                created_at=parse_datetime(item["date"]),
+                title=item["title"],
+                message=item["message"],
+                formatted_title=item.get("formatted", {}).get("title"),
+                formatted_message=item.get("formatted", {}).get("message"),
+            )
             for item in paginate(
                 f"{self.endpoint}/api/{repo_type}s/{repo_id}/commits/{revision}",
                 headers=self._build_hf_headers(token=token),
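The rebuilt `GitCommitInfo` keeps the same public fields; `formatted_title` and `formatted_message` stay `None` unless the server returns a `formatted` payload. A short usage sketch:

```python
from huggingface_hub import HfApi

api = HfApi()
commits = api.list_repo_commits("gpt2")
for commit in commits[:3]:
    # commit_id is the full hash; created_at is a timezone-aware datetime
    print(commit.commit_id[:7], commit.created_at.date(), commit.title)
```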
@@ -4235,8 +4324,9 @@ class HfApi:
         force_download: bool = False,
         force_filename: Optional[str] = None,
         proxies: Optional[Dict] = None,
-        etag_timeout: float =
+        etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
         resume_download: bool = False,
+        token: Optional[Union[str, bool]] = None,
         local_files_only: bool = False,
         legacy_cache_layout: bool = False,
     ) -> str:
@@ -4298,9 +4388,6 @@ class HfApi:
             revision (`str`, *optional*):
                 An optional Git revision id which can be a branch name, a tag, or a
                 commit hash.
-            endpoint (`str`, *optional*):
-                Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
-                environment variable.
             cache_dir (`str`, `Path`, *optional*):
                 Path to the folder where cached files are stored.
             local_dir (`str` or `Path`, *optional*):
@@ -4322,6 +4409,11 @@ class HfApi:
                 data before giving up which is passed to `requests.request`.
             resume_download (`bool`, *optional*, defaults to `False`):
                 If `True`, resume a previously interrupted download.
+            token (`bool` or `str`, *optional*):
+                A valid authentication token (see https://huggingface.co/settings/token).
+                If `None` or `True` and machine is logged in (through `huggingface-cli login`
+                or [`~huggingface_hub.login`]), token will be retrieved from the cache.
+                If `False`, token is not sent in the request header.
             local_files_only (`bool`, *optional*, defaults to `False`):
                 If `True`, avoid downloading the file and return the path to the
                 local cached file if it exists.
@@ -4358,6 +4450,10 @@ class HfApi:
         """
         from .file_download import hf_hub_download

+        if token is None:
+            # Cannot do `token = token or self.token` as token can be `False`.
+            token = self.token
+
         return hf_hub_download(
             repo_id=repo_id,
             filename=filename,
@@ -4376,7 +4472,7 @@ class HfApi:
             proxies=proxies,
             etag_timeout=etag_timeout,
             resume_download=resume_download,
-            token=
+            token=token,
             local_files_only=local_files_only,
             legacy_cache_layout=legacy_cache_layout,
         )
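With this change, a token configured on the `HfApi` client is reused by `hf_hub_download` unless explicitly overridden (or disabled with `token=False`). A minimal sketch (the token value is a placeholder):

```python
from huggingface_hub import HfApi

api = HfApi(token="hf_***")  # placeholder token
# No token argument here: the client-level token above is sent with the request.
config_path = api.hf_hub_download(repo_id="gpt2", filename="config.json")
print(config_path)
```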
@@ -4392,9 +4488,10 @@ class HfApi:
         local_dir: Union[str, Path, None] = None,
         local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
         proxies: Optional[Dict] = None,
-        etag_timeout: float =
+        etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
         resume_download: bool = False,
         force_download: bool = False,
+        token: Optional[Union[str, bool]] = None,
         local_files_only: bool = False,
         allow_patterns: Optional[Union[List[str], str]] = None,
         ignore_patterns: Optional[Union[List[str], str]] = None,
@@ -4455,6 +4552,11 @@ class HfApi:
                 If `True`, resume a previously interrupted download.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether the file should be downloaded even if it already exists in the local cache.
+            token (`bool` or `str`, *optional*):
+                A valid authentication token (see https://huggingface.co/settings/token).
+                If `None` or `True` and machine is logged in (through `huggingface-cli login`
+                or [`~huggingface_hub.login`]), token will be retrieved from the cache.
+                If `False`, token is not sent in the request header.
             local_files_only (`bool`, *optional*, defaults to `False`):
                 If `True`, avoid downloading the file and return the path to the
                 local cached file if it exists.
@@ -4490,6 +4592,10 @@ class HfApi:
         """
         from ._snapshot_download import snapshot_download

+        if token is None:
+            # Cannot do `token = token or self.token` as token can be `False`.
+            token = self.token
+
         return snapshot_download(
             repo_id=repo_id,
             repo_type=repo_type,
@@ -4505,7 +4611,7 @@ class HfApi:
             etag_timeout=etag_timeout,
             resume_download=resume_download,
             force_download=force_download,
-            token=
+            token=token,
             local_files_only=local_files_only,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
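The same token-resolution rule now applies to `snapshot_download` on the client. A minimal sketch (placeholder token; the pattern filter is illustrative):

```python
from huggingface_hub import HfApi

api = HfApi(token="hf_***")  # placeholder token
# The client-level token is used since no token argument is passed.
local_folder = api.snapshot_download(repo_id="gpt2", allow_patterns="*.json")
print(local_folder)
```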
@@ -6002,6 +6108,446 @@ class HfApi:
         hf_raise_for_status(r)
         return SpaceRuntime(r.json())

+    #######################
+    # Inference Endpoints #
+    #######################
+
+    def list_inference_endpoints(
+        self, namespace: Optional[str] = None, *, token: Optional[str] = None
+    ) -> List[InferenceEndpoint]:
+        """Lists all inference endpoints for the given namespace.
+
+        Args:
+            namespace (`str`, *optional*):
+                The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints
+                from all namespaces (i.e. personal namespace and all orgs the user belongs to).
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace.
+
+        Example:
+        ```python
+        >>> from huggingface_hub import HfApi
+        >>> api = HfApi()
+        >>> api.list_inference_endpoints()
+        [InferenceEndpoint(name='my-endpoint', ...), ...]
+        ```
+        """
+        # Special case: list all endpoints for all namespaces the user has access to
+        if namespace == "*":
+            user = self.whoami(token=token)
+
+            # List personal endpoints first
+            endpoints: List[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token))
+
+            # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access)
+            for org in user.get("orgs", []):
+                try:
+                    endpoints += list_inference_endpoints(namespace=org["name"], token=token)
+                except HfHubHTTPError as error:
+                    if error.response.status_code == 401:  # Either no billing or the user doesn't have access
+                        logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error)
+                        pass
+
+            return endpoints
+
+        # Normal case: list endpoints for a specific namespace
+        namespace = namespace or self._get_namespace(token=token)
+
+        response = get_session().get(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+        return [
+            InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token)
+            for endpoint in response.json()["items"]
+        ]
+
+    def create_inference_endpoint(
+        self,
+        name: str,
+        *,
+        repository: str,
+        framework: str,
+        accelerator: str,
+        instance_size: str,
+        instance_type: str,
+        region: str,
+        vendor: str,
+        account_id: Optional[str] = None,
+        min_replica: int = 0,
+        max_replica: int = 1,
+        revision: Optional[str] = None,
+        task: Optional[str] = None,
+        type: InferenceEndpointType = InferenceEndpointType.PROTECTED,
+        namespace: Optional[str] = None,
+        token: Optional[str] = None,
+    ) -> InferenceEndpoint:
+        """Create a new Inference Endpoint.
+
+        Args:
+            name (`str`):
+                The unique name for the new Inference Endpoint.
+            repository (`str`):
+                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
+            framework (`str`):
+                The machine learning framework used for the model (e.g. `"custom"`).
+            accelerator (`str`):
+                The hardware accelerator to be used for inference (e.g. `"cpu"`).
+            instance_size (`str`):
+                The size or type of the instance to be used for hosting the model (e.g. `"large"`).
+            instance_type (`str`):
+                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`).
+            region (`str`):
+                The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`).
+            vendor (`str`):
+                The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`).
+            account_id (`str`, *optional*):
+                The account ID used to link a VPC to a private Inference Endpoint (if applicable).
+            min_replica (`int`, *optional*):
+                The minimum number of replicas (instances) to keep running for the Inference Endpoint. Defaults to 0.
+            max_replica (`int`, *optional*):
+                The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
+            revision (`str`, *optional*):
+                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
+            task (`str`, *optional*):
+                The task on which to deploy the model (e.g. `"text-classification"`).
+            type ([`InferenceEndpointType`], *optional*):
+                The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`.
+            namespace (`str`, *optional*):
+                The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            [`InferenceEndpoint`]: information about the new Inference Endpoint.
+
+        Example:
+        ```python
+        >>> from huggingface_hub import HfApi
+        >>> api = HfApi()
+        >>> endpoint = api.create_inference_endpoint(
+        ...     "my-endpoint-name",
+        ...     repository="gpt2",
+        ...     framework="pytorch",
+        ...     task="text-generation",
+        ...     accelerator="cpu",
+        ...     vendor="aws",
+        ...     region="us-east-1",
+        ...     type="protected",
+        ...     instance_size="medium",
+        ...     instance_type="c6i"
+        ... )
+        >>> endpoint
+        InferenceEndpoint(name='my-endpoint-name', status="pending",...)
+
+        # Run inference on the endpoint
+        >>> endpoint.client.text_generation(...)
+        "..."
+        ```
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        payload: Dict = {
+            "accountId": account_id,
+            "compute": {
+                "accelerator": accelerator,
+                "instanceSize": instance_size,
+                "instanceType": instance_type,
+                "scaling": {
+                    "maxReplica": max_replica,
+                    "minReplica": min_replica,
+                },
+            },
+            "model": {
+                "framework": framework,
+                "repository": repository,
+                "revision": revision,
+                "task": task,
+                "image": {"huggingface": {}},
+            },
+            "name": name,
+            "provider": {
+                "region": region,
+                "vendor": vendor,
+            },
+            "type": type,
+        }
+
+        response = get_session().post(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    def get_inference_endpoint(
+        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
+    ) -> InferenceEndpoint:
+        """Get information about an Inference Endpoint.
+
+        Args:
+            name (`str`):
+                The name of the Inference Endpoint to retrieve information about.
+            namespace (`str`, *optional*):
+                The namespace in which the Inference Endpoint is located. Defaults to the current user.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            [`InferenceEndpoint`]: information about the requested Inference Endpoint.
+
+        Example:
+        ```python
+        >>> from huggingface_hub import HfApi
+        >>> api = HfApi()
+        >>> endpoint = api.get_inference_endpoint("my-text-to-image")
+        >>> endpoint
+        InferenceEndpoint(name='my-text-to-image', ...)
+
+        # Get status
+        >>> endpoint.status
+        'running'
+        >>> endpoint.url
+        'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'
+
+        # Run inference
+        >>> endpoint.client.text_to_image(...)
+        ```
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        response = get_session().get(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    def update_inference_endpoint(
+        self,
+        name: str,
+        *,
+        # Compute update
+        accelerator: Optional[str] = None,
+        instance_size: Optional[str] = None,
+        instance_type: Optional[str] = None,
+        min_replica: Optional[int] = None,
+        max_replica: Optional[int] = None,
+        # Model update
+        repository: Optional[str] = None,
+        framework: Optional[str] = None,
+        revision: Optional[str] = None,
+        task: Optional[str] = None,
+        # Other
+        namespace: Optional[str] = None,
+        token: Optional[str] = None,
+    ) -> InferenceEndpoint:
+        """Update an Inference Endpoint.
+
+        This method allows the update of either the compute configuration, the deployed model, or both. All arguments are
+        optional but at least one must be provided.
+
+        For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`].
+
+        Args:
+            name (`str`):
+                The name of the Inference Endpoint to update.
+
+            accelerator (`str`, *optional*):
+                The hardware accelerator to be used for inference (e.g. `"cpu"`).
+            instance_size (`str`, *optional*):
+                The size or type of the instance to be used for hosting the model (e.g. `"large"`).
+            instance_type (`str`, *optional*):
+                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`).
+            min_replica (`int`, *optional*):
+                The minimum number of replicas (instances) to keep running for the Inference Endpoint.
+            max_replica (`int`, *optional*):
+                The maximum number of replicas (instances) to scale to for the Inference Endpoint.
+
+            repository (`str`, *optional*):
+                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
+            framework (`str`, *optional*):
+                The machine learning framework used for the model (e.g. `"custom"`).
+            revision (`str`, *optional*):
+                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
+            task (`str`, *optional*):
+                The task on which to deploy the model (e.g. `"text-classification"`).
+
+            namespace (`str`, *optional*):
+                The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            [`InferenceEndpoint`]: information about the updated Inference Endpoint.
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        payload: Dict = {}
+        if any(value is not None for value in (accelerator, instance_size, instance_type, min_replica, max_replica)):
+            payload["compute"] = {
+                "accelerator": accelerator,
+                "instanceSize": instance_size,
+                "instanceType": instance_type,
+                "scaling": {
+                    "maxReplica": max_replica,
+                    "minReplica": min_replica,
+                },
+            }
+        if any(value is not None for value in (repository, framework, revision, task)):
+            payload["model"] = {
+                "framework": framework,
+                "repository": repository,
+                "revision": revision,
+                "task": task,
+                "image": {"huggingface": {}},
+            }
+
+        response = get_session().put(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
+            headers=self._build_hf_headers(token=token),
+            json=payload,
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    def delete_inference_endpoint(
+        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
+    ) -> None:
+        """Delete an Inference Endpoint.
+
+        This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
+        to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`].
+
+        For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`].
+
+        Args:
+            name (`str`):
+                The name of the Inference Endpoint to delete.
+            namespace (`str`, *optional*):
+                The namespace in which the Inference Endpoint is located. Defaults to the current user.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+        """
+        namespace = namespace or self._get_namespace(token=token)
+        response = get_session().delete(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+    def pause_inference_endpoint(
+        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
+    ) -> InferenceEndpoint:
+        """Pause an Inference Endpoint.
+
+        A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`].
+        This differs from scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`]: a
+        scaled-to-zero Inference Endpoint is automatically restarted when a request is made to it.
+
+        For convenience, you can also pause an Inference Endpoint using [`InferenceEndpoint.pause`].
+
+        Args:
+            name (`str`):
+                The name of the Inference Endpoint to pause.
+            namespace (`str`, *optional*):
+                The namespace in which the Inference Endpoint is located. Defaults to the current user.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            [`InferenceEndpoint`]: information about the paused Inference Endpoint.
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        response = get_session().post(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    def resume_inference_endpoint(
+        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
+    ) -> InferenceEndpoint:
+        """Resume an Inference Endpoint.
+
+        For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`].
+
+        Args:
+            name (`str`):
+                The name of the Inference Endpoint to resume.
+            namespace (`str`, *optional*):
+                The namespace in which the Inference Endpoint is located. Defaults to the current user.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            [`InferenceEndpoint`]: information about the resumed Inference Endpoint.
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        response = get_session().post(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    def scale_to_zero_inference_endpoint(
+        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
+    ) -> InferenceEndpoint:
+        """Scale Inference Endpoint to zero.
+
+        An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a
+        cold start delay. This differs from pausing the Inference Endpoint with [`pause_inference_endpoint`], which
+        would require a manual resume with [`resume_inference_endpoint`].
+
+        For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`].
+
+        Args:
+            name (`str`):
+                The name of the Inference Endpoint to scale to zero.
+            namespace (`str`, *optional*):
+                The namespace in which the Inference Endpoint is located. Defaults to the current user.
+            token (`str`, *optional*):
+                An authentication token (See https://huggingface.co/settings/token).
+
+        Returns:
+            [`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint.
+        """
+        namespace = namespace or self._get_namespace(token=token)
+
+        response = get_session().post(
+            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero",
+            headers=self._build_hf_headers(token=token),
+        )
+        hf_raise_for_status(response)
+
+        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
+
+    def _get_namespace(self, token: Optional[str] = None) -> str:
+        """Get the default namespace for the current user."""
+        me = self.whoami(token=token)
+        if me["type"] == "user":
+            return me["name"]
+        else:
+            raise ValueError(
+                "Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a"
+                " user."
+            )
+
6551
|
########################
|
|
6006
6552
|
# Collection Endpoints #
|
|
6007
6553
|
########################
|
|
@@ -6027,21 +6573,20 @@ class HfApi:
         >>> len(collection.items)
         37
         >>> collection.items[0]
-        CollectionItem
-
-
-        '
-
-
-
-        }}
+        CollectionItem(
+            item_object_id='651446103cd773a050bf64c2',
+            item_id='TheBloke/U-Amethyst-20B-AWQ',
+            item_type='model',
+            position=88,
+            note=None
+        )
         ```
         """
         r = get_session().get(
             f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token)
         )
         hf_raise_for_status(r)
-        return Collection(r.json(), endpoint
+        return Collection(**{**r.json(), "endpoint": self.endpoint})

     def create_collection(
         self,
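`Collection` is now built via keyword arguments with the client's endpoint injected, so attribute access is unchanged for callers. A short sketch reusing the slug format from the surrounding docstrings:

```python
from huggingface_hub import get_collection

collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026")
print(collection.title)
item = collection.items[0]
# item_object_id is the internal id of the item inside the collection
print(item.item_id, item.item_type, item.item_object_id)
```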
@@ -6106,7 +6651,7 @@ class HfApi:
                 return self.get_collection(slug, token=token)
             else:
                 raise
-        return Collection(r.json(), endpoint
+        return Collection(**{**r.json(), "endpoint": self.endpoint})

     def update_collection_metadata(
         self,
@@ -6171,7 +6716,7 @@ class HfApi:
             json={key: value for key, value in payload.items() if value is not None},
         )
         hf_raise_for_status(r)
-        return Collection(r.json()["data"], endpoint
+        return Collection(**{**r.json()["data"], "endpoint": self.endpoint})

     def delete_collection(
         self, collection_slug: str, *, missing_ok: bool = False, token: Optional[str] = None
@@ -6279,7 +6824,7 @@ class HfApi:
                 return self.get_collection(collection_slug, token=token)
             else:
                 raise
-        return Collection(r.json(), endpoint
+        return Collection(**{**r.json(), "endpoint": self.endpoint})

     def update_collection_item(
         self,
@@ -6297,7 +6842,7 @@ class HfApi:
                 Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
             item_object_id (`str`):
                 ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id).
-                It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].
+                It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`.
             note (`str`, *optional*):
                 A note to attach to the item in the collection. The maximum size for a note is 500 characters.
             position (`int`, *optional*):
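As the corrected docstring notes, `item_object_id` comes from a fetched [`CollectionItem`], not from the repo id. A hedged sketch (slug reused from the docstring above; note text illustrative):

```python
from huggingface_hub import get_collection, update_collection_item

slug = "TheBloke/recent-models-64f9a55bb3115b4f513ec026"
collection = get_collection(slug)
update_collection_item(
    collection_slug=slug,
    item_object_id=collection.items[0].item_object_id,  # internal item id, not the repo_id
    note="Featured model",  # illustrative note
)
```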
@@ -6575,6 +7120,16 @@ duplicate_space = api.duplicate_space
 request_space_storage = api.request_space_storage
 delete_space_storage = api.delete_space_storage

+# Inference Endpoint API
+list_inference_endpoints = api.list_inference_endpoints
+create_inference_endpoint = api.create_inference_endpoint
+get_inference_endpoint = api.get_inference_endpoint
+update_inference_endpoint = api.update_inference_endpoint
+delete_inference_endpoint = api.delete_inference_endpoint
+pause_inference_endpoint = api.pause_inference_endpoint
+resume_inference_endpoint = api.resume_inference_endpoint
+scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint
+
 # Collections API
 get_collection = api.get_collection
 create_collection = api.create_collection