huggingface-hub 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. Click here for more details.

Files changed (43) hide show
  1. huggingface_hub/__init__.py +31 -5
  2. huggingface_hub/_inference_endpoints.py +348 -0
  3. huggingface_hub/_login.py +9 -7
  4. huggingface_hub/_multi_commits.py +1 -1
  5. huggingface_hub/_snapshot_download.py +6 -7
  6. huggingface_hub/_space_api.py +7 -4
  7. huggingface_hub/_tensorboard_logger.py +1 -0
  8. huggingface_hub/_webhooks_payload.py +7 -7
  9. huggingface_hub/commands/lfs.py +3 -6
  10. huggingface_hub/commands/user.py +1 -4
  11. huggingface_hub/constants.py +27 -0
  12. huggingface_hub/file_download.py +142 -134
  13. huggingface_hub/hf_api.py +1036 -501
  14. huggingface_hub/hf_file_system.py +57 -12
  15. huggingface_hub/hub_mixin.py +3 -5
  16. huggingface_hub/inference/_client.py +43 -8
  17. huggingface_hub/inference/_common.py +8 -16
  18. huggingface_hub/inference/_generated/_async_client.py +41 -8
  19. huggingface_hub/inference/_text_generation.py +43 -0
  20. huggingface_hub/inference_api.py +1 -1
  21. huggingface_hub/lfs.py +32 -14
  22. huggingface_hub/repocard_data.py +7 -0
  23. huggingface_hub/repository.py +19 -3
  24. huggingface_hub/templates/modelcard_template.md +1 -1
  25. huggingface_hub/utils/__init__.py +1 -1
  26. huggingface_hub/utils/_cache_assets.py +3 -3
  27. huggingface_hub/utils/_cache_manager.py +6 -7
  28. huggingface_hub/utils/_datetime.py +3 -1
  29. huggingface_hub/utils/_errors.py +10 -0
  30. huggingface_hub/utils/_hf_folder.py +4 -2
  31. huggingface_hub/utils/_http.py +10 -1
  32. huggingface_hub/utils/_runtime.py +4 -2
  33. huggingface_hub/utils/endpoint_helpers.py +27 -175
  34. huggingface_hub/utils/insecure_hashlib.py +34 -0
  35. huggingface_hub/utils/logging.py +4 -6
  36. huggingface_hub/utils/sha.py +2 -1
  37. {huggingface_hub-0.18.0.dist-info → huggingface_hub-0.19.0.dist-info}/METADATA +16 -15
  38. huggingface_hub-0.19.0.dist-info/RECORD +74 -0
  39. {huggingface_hub-0.18.0.dist-info → huggingface_hub-0.19.0.dist-info}/WHEEL +1 -1
  40. huggingface_hub-0.18.0.dist-info/RECORD +0 -72
  41. {huggingface_hub-0.18.0.dist-info → huggingface_hub-0.19.0.dist-info}/LICENSE +0 -0
  42. {huggingface_hub-0.18.0.dist-info → huggingface_hub-0.19.0.dist-info}/entry_points.txt +0 -0
  43. {huggingface_hub-0.18.0.dist-info → huggingface_hub-0.19.0.dist-info}/top_level.txt +0 -0
huggingface_hub/hf_api.py CHANGED
@@ -16,9 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  import inspect
18
18
  import json
19
- import pprint
20
19
  import re
21
- import textwrap
22
20
  import warnings
23
21
  from concurrent.futures import Future, ThreadPoolExecutor
24
22
  from dataclasses import dataclass, field
@@ -69,6 +67,7 @@ from ._commit_api import (
69
67
  _upload_lfs_files,
70
68
  _warn_on_overwriting_operations,
71
69
  )
70
+ from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType
72
71
  from ._multi_commits import (
73
72
  MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_BAD_REQUEST_TEMPLATE,
74
73
  MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_NO_CHANGES_TEMPLATE,
@@ -92,8 +91,10 @@ from .community import (
92
91
  deserialize_event,
93
92
  )
94
93
  from .constants import (
94
+ DEFAULT_ETAG_TIMEOUT,
95
95
  DEFAULT_REVISION,
96
96
  ENDPOINT,
97
+ INFERENCE_ENDPOINTS_ENDPOINT,
97
98
  REGEX_COMMIT_OID,
98
99
  REPO_TYPE_MODEL,
99
100
  REPO_TYPES,
@@ -105,6 +106,7 @@ from .file_download import (
105
106
  get_hf_file_metadata,
106
107
  hf_hub_url,
107
108
  )
109
+ from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
108
110
  from .utils import ( # noqa: F401 # imported for backward compatibility
109
111
  BadRequestError,
110
112
  HfFolder,
@@ -122,12 +124,9 @@ from .utils._deprecation import (
122
124
  )
123
125
  from .utils._typing import CallableT
124
126
  from .utils.endpoint_helpers import (
125
- AttributeDictionary,
126
127
  DatasetFilter,
127
- DatasetTags,
128
128
  ModelFilter,
129
- ModelTags,
130
- _filter_emissions,
129
+ _is_emission_within_treshold,
131
130
  )
132
131
 
133
132
 
@@ -145,24 +144,6 @@ _CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = (
145
144
  logger = logging.get_logger(__name__)
146
145
 
147
146
 
148
- class ReprMixin:
149
- """Mixin to create the __repr__ for a class"""
150
-
151
- def __init__(self, **kwargs) -> None:
152
- # Store all the other fields returned by the API
153
- # Hack to ensure backward compatibility with future versions of the API.
154
- # See discussion in https://github.com/huggingface/huggingface_hub/pull/951#discussion_r926460408
155
- for k, v in kwargs.items():
156
- setattr(self, k, v)
157
-
158
- def __repr__(self):
159
- formatted_value = pprint.pformat(self.__dict__, width=119, compact=True)
160
- if "\n" in formatted_value:
161
- return f"{self.__class__.__name__}: {{ \n{textwrap.indent(formatted_value, ' ')}\n}}"
162
- else:
163
- return f"{self.__class__.__name__}: {formatted_value}"
164
-
165
-
166
147
  def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]:
167
148
  """
168
149
  Returns the repo type and ID from a huggingface.co URL linking to a
@@ -254,13 +235,38 @@ class BlobLfsInfo(TypedDict, total=False):
254
235
  pointer_size: int
255
236
 
256
237
 
238
+ class BlobLastCommitInfo(TypedDict, total=False):
239
+ oid: str
240
+ title: str
241
+ date: datetime
242
+
243
+
244
+ class BlobSecurityInfo(TypedDict, total=False):
245
+ safe: bool
246
+ av_scan: Optional[Dict]
247
+ pickle_import_scan: Optional[Dict]
248
+
249
+
250
+ class TransformersInfo(TypedDict, total=False):
251
+ auto_model: str
252
+ custom_class: Optional[str]
253
+ # possible `pipeline_tag` values: https://github.com/huggingface/hub-docs/blob/f2003d2fca9d4c971629e858e314e0a5c05abf9d/js/src/lib/interfaces/Types.ts#L79
254
+ pipeline_tag: Optional[str]
255
+ processor: Optional[str]
256
+
257
+
258
+ class SafeTensorsInfo(TypedDict, total=False):
259
+ parameters: List[Dict[str, int]]
260
+ total: int
261
+
262
+
257
263
  @dataclass
258
264
  class CommitInfo:
259
265
  """Data structure containing information about a newly created commit.
260
266
 
261
267
  Returned by [`create_commit`].
262
268
 
263
- Args:
269
+ Attributes:
264
270
  commit_url (`str`):
265
271
  Url where to find the commit.
266
272
 
@@ -370,251 +376,431 @@ class RepoUrl(str):
370
376
  return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')"
371
377
 
372
378
 
373
- class RepoFile(ReprMixin):
379
+ @dataclass
380
+ class RepoSibling:
374
381
  """
375
- Data structure that represents a public file inside a repo, accessible from huggingface.co
382
+ Contains basic information about a repo file inside a repo on the Hub.
376
383
 
377
- Args:
384
+ Attributes:
378
385
  rfilename (str):
379
- file name, relative to the repo root. This is the only attribute that's guaranteed to be here, but under
380
- certain conditions there can certain other stuff.
386
+ file name, relative to the repo root.
381
387
  size (`int`, *optional*):
382
- The file's size, in bytes. This attribute is present when `files_metadata` argument of [`repo_info`] is set
388
+ The file's size, in bytes. This attribute is defined when `files_metadata` argument of [`repo_info`] is set
383
389
  to `True`. It's `None` otherwise.
384
390
  blob_id (`str`, *optional*):
385
- The file's git OID. This attribute is present when `files_metadata` argument of [`repo_info`] is set to
391
+ The file's git OID. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to
386
392
  `True`. It's `None` otherwise.
387
393
  lfs (`BlobLfsInfo`, *optional*):
388
- The file's LFS metadata. This attribute is present when`files_metadata` argument of [`repo_info`] is set to
394
+ The file's LFS metadata. This attribute is defined when `files_metadata` argument of [`repo_info`] is set to
389
395
  `True` and the file is stored with Git LFS. It's `None` otherwise.
390
396
  """
391
397
 
392
- def __init__(
393
- self,
394
- rfilename: str,
395
- size: Optional[int] = None,
396
- blobId: Optional[str] = None,
397
- lfs: Optional[BlobLfsInfo] = None,
398
- **kwargs,
399
- ):
400
- self.rfilename = rfilename # filename relative to the repo root
398
+ rfilename: str
399
+ size: Optional[int] = None
400
+ blob_id: Optional[str] = None
401
+ lfs: Optional[BlobLfsInfo] = None
402
+
403
+
404
+ @dataclass
405
+ class RepoFile:
406
+ """
407
+ Contains information about a file on the Hub.
408
+
409
+ Attributes:
410
+ path (str):
411
+ file path relative to the repo root.
412
+ size (`int`):
413
+ The file's size, in bytes.
414
+ blob_id (`str`):
415
+ The file's git OID.
416
+ lfs (`BlobLfsInfo`):
417
+ The file's LFS metadata.
418
+ last_commit (`BlobLastCommitInfo`, *optional*):
419
+ The file's last commit metadata. Only defined if [`list_files_info`] is called with `expand=True`
420
+ security (`BlobSecurityInfo`, *optional*):
421
+ The file's security scan metadata. Only defined if [`list_files_info`] is called with `expand=True`.
422
+ """
401
423
 
402
- # Optional file metadata
403
- self.size = size
404
- self.blob_id = blobId
405
- self.lfs = lfs
424
+ path: str
425
+ size: int
426
+ blob_id: str
427
+ lfs: Optional[BlobLfsInfo] = None
428
+ last_commit: Optional[BlobLastCommitInfo] = None
429
+ security: Optional[BlobSecurityInfo] = None
406
430
 
407
- # Store all the other fields returned by the API
408
- super().__init__(**kwargs)
431
+ def __post_init__(self):
432
+ # backwards compatibility
433
+ self.rfilename = self.path
434
+ self.lastCommit = self.last_commit
409
435
 
410
436
 
411
- class ModelInfo(ReprMixin):
437
+ @dataclass
438
+ class ModelInfo:
412
439
  """
413
- Info about a model accessible from huggingface.co
440
+ Contains information about a model on the Hub.
414
441
 
415
442
  Attributes:
416
- modelId (`str`, *optional*):
417
- ID of model repository.
443
+ id (`str`):
444
+ ID of model.
445
+ author (`str`, *optional*):
446
+ Author of the model.
418
447
  sha (`str`, *optional*):
419
- repo sha at this particular revision
420
- lastModified (`str`, *optional*):
421
- date of last commit to repo
422
- tags (`List[str]`, *optional*):
423
- List of tags.
448
+ Repo SHA at this particular revision.
449
+ last_modified (`datetime`, *optional*):
450
+ Date of last commit to the repo.
451
+ private (`bool`):
452
+ Is the repo private.
453
+ disabled (`bool`, *optional*):
454
+ Is the repo disabled.
455
+ gated (`bool`, *optional*):
456
+ Is the repo gated.
457
+ downloads (`int`):
458
+ Number of downloads of the model.
459
+ likes (`int`):
460
+ Number of likes of the model.
461
+ library_name (`str`, *optional*):
462
+ Library associated with the model.
463
+ tags (`List[str]`):
464
+ List of tags of the model. Compared to `card_data.tags`, contains extra tags computed by the Hub
465
+ (e.g. supported libraries, model's arXiv).
424
466
  pipeline_tag (`str`, *optional*):
425
- Pipeline tag to identify the correct widget.
426
- siblings (`List[RepoFile]`, *optional*):
427
- list of ([`huggingface_hub.hf_api.RepoFile`]) objects that constitute the model.
428
- private (`bool`, *optional*, defaults to `False`):
429
- is the repo private
430
- author (`str`, *optional*):
431
- repo author
467
+ Pipeline tag associated with the model.
468
+ mask_token (`str`, *optional*):
469
+ Mask token used by the model.
470
+ widget_data (`Any`, *optional*):
471
+ Widget data associated with the model.
472
+ model_index (`Dict`, *optional*):
473
+ Model index for evaluation.
432
474
  config (`Dict`, *optional*):
433
- Model configuration information
434
- securityStatus (`Dict`, *optional*):
435
- Security status of the model.
436
- Example: `{"containsInfected": False}`
437
- kwargs (`Dict`, *optional*):
438
- Kwargs that will be become attributes of the class.
475
+ Model configuration.
476
+ transformers_info (`TransformersInfo`, *optional*):
477
+ Transformers-specific info (auto class, processor, etc.) associated with the model.
478
+ card_data (`ModelCardData`, *optional*):
479
+ Model Card Metadata as a [`huggingface_hub.repocard_data.ModelCardData`] object.
480
+ siblings (`List[RepoSibling]`):
481
+ List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the model.
482
+ spaces (`List[str]`, *optional*):
483
+ List of spaces using the model.
484
+ safetensors (`SafeTensorsInfo`, *optional*):
485
+ Model's safetensors information.
439
486
  """
440
487
 
441
- def __init__(
442
- self,
443
- *,
444
- modelId: Optional[str] = None,
445
- sha: Optional[str] = None,
446
- lastModified: Optional[str] = None,
447
- tags: Optional[List[str]] = None,
448
- pipeline_tag: Optional[str] = None,
449
- siblings: Optional[List[Dict]] = None,
450
- private: bool = False,
451
- author: Optional[str] = None,
452
- config: Optional[Dict] = None,
453
- securityStatus: Optional[Dict] = None,
454
- **kwargs,
455
- ):
456
- self.modelId = modelId
457
- self.sha = sha
458
- self.lastModified = lastModified
459
- self.tags = tags
460
- self.pipeline_tag = pipeline_tag
461
- self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else []
462
- self.private = private
463
- self.author = author
464
- self.config = config
465
- self.securityStatus = securityStatus
466
-
467
- # Store all the other fields returned by the API
468
- super().__init__(**kwargs)
469
-
470
- def __str__(self):
471
- r = f"Model Name: {self.modelId}, Tags: {self.tags}"
472
- if self.pipeline_tag:
473
- r += f", Task: {self.pipeline_tag}"
474
- return r
475
-
476
-
477
- class DatasetInfo(ReprMixin):
488
+ id: str
489
+ author: Optional[str]
490
+ sha: Optional[str]
491
+ last_modified: Optional[datetime]
492
+ private: bool
493
+ gated: Optional[bool]
494
+ disabled: Optional[bool]
495
+ downloads: int
496
+ likes: int
497
+ library_name: Optional[str]
498
+ tags: List[str]
499
+ pipeline_tag: Optional[str]
500
+ mask_token: Optional[str]
501
+ card_data: Optional[ModelCardData]
502
+ widget_data: Optional[Any]
503
+ model_index: Optional[Dict]
504
+ config: Optional[Dict]
505
+ transformers_info: Optional[TransformersInfo]
506
+ siblings: Optional[List[RepoSibling]]
507
+ spaces: Optional[List[str]]
508
+ safetensors: Optional[SafeTensorsInfo]
509
+
510
+ def __init__(self, **kwargs):
511
+ self.id = kwargs.pop("id")
512
+ self.author = kwargs.pop("author", None)
513
+ self.sha = kwargs.pop("sha", None)
514
+ last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None)
515
+ self.last_modified = parse_datetime(last_modified) if last_modified else None
516
+ self.private = kwargs.pop("private")
517
+ self.gated = kwargs.pop("gated", None)
518
+ self.disabled = kwargs.pop("disabled", None)
519
+ self.downloads = kwargs.pop("downloads")
520
+ self.likes = kwargs.pop("likes")
521
+ self.library_name = kwargs.pop("library_name", None)
522
+ self.tags = kwargs.pop("tags")
523
+ self.pipeline_tag = kwargs.pop("pipeline_tag", None)
524
+ self.mask_token = kwargs.pop("mask_token", None)
525
+ card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
526
+ self.card_data = (
527
+ ModelCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
528
+ )
529
+
530
+ self.widget_data = kwargs.pop("widget_data", None)
531
+ self.model_index = kwargs.pop("model-index", None) or kwargs.pop("model_index", None)
532
+ self.config = kwargs.pop("config", None)
533
+ transformers_info = kwargs.pop("transformersInfo", None) or kwargs.pop("transformers_info", None)
534
+ self.transformers_info = TransformersInfo(**transformers_info) if transformers_info else None
535
+ siblings = kwargs.pop("siblings", None)
536
+ self.siblings = (
537
+ [
538
+ RepoSibling(
539
+ rfilename=sibling["rfilename"],
540
+ size=sibling.get("size"),
541
+ blob_id=sibling.get("blobId"),
542
+ lfs=(
543
+ BlobLfsInfo(
544
+ size=sibling["lfs"]["size"],
545
+ sha256=sibling["lfs"]["sha256"],
546
+ pointer_size=sibling["lfs"]["pointerSize"],
547
+ )
548
+ if sibling.get("lfs")
549
+ else None
550
+ ),
551
+ )
552
+ for sibling in siblings
553
+ ]
554
+ if siblings
555
+ else None
556
+ )
557
+ self.spaces = kwargs.pop("spaces", None)
558
+ safetensors = kwargs.pop("safetensors", None)
559
+ self.safetensors = SafeTensorsInfo(**safetensors) if safetensors else None
560
+
561
+ # backwards compatibility
562
+ self.lastModified = self.last_modified
563
+ self.cardData = self.card_data
564
+ self.transformersInfo = self.transformers_info
565
+ self.__dict__.update(**kwargs)
566
+
567
+
568
+ @dataclass
569
+ class DatasetInfo:
478
570
  """
479
- Info about a dataset accessible from huggingface.co
571
+ Contains information about a dataset on the Hub.
480
572
 
481
573
  Attributes:
482
- id (`str`, *optional*):
483
- ID of dataset repository.
484
- sha (`str`, *optional*):
485
- repo sha at this particular revision
486
- lastModified (`str`, *optional*):
487
- date of last commit to repo
488
- tags (`List[str]`, *optional*):
489
- List of tags.
490
- siblings (`List[RepoFile]`, *optional*):
491
- list of [`huggingface_hub.hf_api.RepoFile`] objects that constitute the dataset.
492
- private (`bool`, *optional*, defaults to `False`):
493
- is the repo private
494
- author (`str`, *optional*):
495
- repo author
496
- description (`str`, *optional*):
497
- Description of the dataset
498
- citation (`str`, *optional*):
499
- Dataset citation
500
- cardData (`Dict`, *optional*):
501
- Metadata of the model card as a dictionary.
502
- kwargs (`Dict`, *optional*):
503
- Kwargs that will be become attributes of the class.
574
+ id (`str`):
575
+ ID of dataset.
576
+ author (`str`):
577
+ Author of the dataset.
578
+ sha (`str`):
579
+ Repo SHA at this particular revision.
580
+ last_modified (`datetime`, *optional*):
581
+ Date of last commit to the repo.
582
+ private (`bool`):
583
+ Is the repo private.
584
+ disabled (`bool`, *optional*):
585
+ Is the repo disabled.
586
+ gated (`bool`, *optional*):
587
+ Is the repo gated.
588
+ downloads (`int`):
589
+ Number of downloads of the dataset.
590
+ likes (`int`):
591
+ Number of likes of the dataset.
592
+ tags (`List[str]`):
593
+ List of tags of the dataset.
594
+ card_data (`DatasetCardData`, *optional*):
595
+ Model Card Metadata as a [`huggingface_hub.repocard_data.DatasetCardData`] object.
596
+ siblings (`List[RepoSibling]`):
597
+ List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the dataset.
504
598
  """
505
599
 
506
- def __init__(
507
- self,
508
- *,
509
- id: Optional[str] = None,
510
- sha: Optional[str] = None,
511
- lastModified: Optional[str] = None,
512
- tags: Optional[List[str]] = None,
513
- siblings: Optional[List[Dict]] = None,
514
- private: bool = False,
515
- author: Optional[str] = None,
516
- description: Optional[str] = None,
517
- citation: Optional[str] = None,
518
- cardData: Optional[dict] = None,
519
- **kwargs,
520
- ):
521
- self.id = id
522
- self.sha = sha
523
- self.lastModified = lastModified
524
- self.tags = tags
525
- self.private = private
526
- self.author = author
527
- self.description = description
528
- self.citation = citation
529
- self.cardData = cardData
530
- self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else []
531
- # Legacy stuff, "key" is always returned with an empty string
532
- # because of old versions of the datasets lib that need this field
533
- kwargs.pop("key", None)
534
- # Store all the other fields returned by the API
535
- super().__init__(**kwargs)
536
-
537
- def __str__(self):
538
- r = f"Dataset Name: {self.id}, Tags: {self.tags}"
539
- return r
540
-
541
-
542
- class SpaceInfo(ReprMixin):
543
- """
544
- Info about a Space accessible from huggingface.co
600
+ id: str
601
+ author: Optional[str]
602
+ sha: Optional[str]
603
+ last_modified: Optional[datetime]
604
+ private: bool
605
+ gated: Optional[bool]
606
+ disabled: Optional[bool]
607
+ downloads: int
608
+ likes: int
609
+ paperswithcode_id: Optional[str]
610
+ tags: List[str]
611
+ card_data: Optional[DatasetCardData]
612
+ siblings: Optional[List[RepoSibling]]
613
+
614
+ def __init__(self, **kwargs):
615
+ self.id = kwargs.pop("id")
616
+ self.author = kwargs.pop("author", None)
617
+ self.sha = kwargs.pop("sha", None)
618
+ last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None)
619
+ self.last_modified = parse_datetime(last_modified) if last_modified else None
620
+ self.private = kwargs.pop("private")
621
+ self.gated = kwargs.pop("gated", None)
622
+ self.disabled = kwargs.pop("disabled", None)
623
+ self.downloads = kwargs.pop("downloads")
624
+ self.likes = kwargs.pop("likes")
625
+ self.paperswithcode_id = kwargs.pop("paperswithcode_id", None)
626
+ self.tags = kwargs.pop("tags")
627
+ card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
628
+ self.card_data = (
629
+ DatasetCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
630
+ )
631
+ siblings = kwargs.pop("siblings", None)
632
+ self.siblings = (
633
+ [
634
+ RepoSibling(
635
+ rfilename=sibling["rfilename"],
636
+ size=sibling.get("size"),
637
+ blob_id=sibling.get("blobId"),
638
+ lfs=(
639
+ BlobLfsInfo(
640
+ size=sibling["lfs"]["size"],
641
+ sha256=sibling["lfs"]["sha256"],
642
+ pointer_size=sibling["lfs"]["pointerSize"],
643
+ )
644
+ if sibling.get("lfs")
645
+ else None
646
+ ),
647
+ )
648
+ for sibling in siblings
649
+ ]
650
+ if siblings
651
+ else None
652
+ )
545
653
 
546
- This is a "dataclass" like container that just sets on itself any attribute
547
- passed by the server.
654
+ # backwards compatibility
655
+ self.lastModified = self.last_modified
656
+ self.cardData = self.card_data
657
+ self.__dict__.update(**kwargs)
658
+
659
+
660
+ @dataclass
661
+ class SpaceInfo:
662
+ """
663
+ Contains information about a Space on the Hub.
548
664
 
549
665
  Attributes:
550
- id (`str`, *optional*):
551
- id of space
552
- sha (`str`, *optional*):
553
- repo sha at this particular revision
554
- lastModified (`str`, *optional*):
555
- date of last commit to repo
556
- siblings (`List[RepoFile]`, *optional*):
557
- list of [`huggingface_hub.hf_api.RepoFIle`] objects that constitute the Space
558
- private (`bool`, *optional*, defaults to `False`):
559
- is the repo private
666
+ id (`str`):
667
+ ID of the Space.
560
668
  author (`str`, *optional*):
561
- repo author
562
- kwargs (`Dict`, *optional*):
563
- Kwargs that will be become attributes of the class.
669
+ Author of the Space.
670
+ sha (`str`, *optional*):
671
+ Repo SHA at this particular revision.
672
+ last_modified (`datetime`, *optional*):
673
+ Date of last commit to the repo.
674
+ private (`bool`):
675
+ Is the repo private.
676
+ gated (`bool`, *optional*):
677
+ Is the repo gated.
678
+ disabled (`bool`, *optional*):
679
+ Is the Space disabled.
680
+ host (`str`, *optional*):
681
+ Host URL of the Space.
682
+ subdomain (`str`, *optional*):
683
+ Subdomain of the Space.
684
+ likes (`int`):
685
+ Number of likes of the Space.
686
+ tags (`List[str]`):
687
+ List of tags of the Space.
688
+ siblings (`List[RepoSibling]`):
689
+ List of [`huggingface_hub.hf_api.RepoSibling`] objects that constitute the Space.
690
+ card_data (`SpaceCardData`, *optional*):
691
+ Space Card Metadata as a [`huggingface_hub.repocard_data.SpaceCardData`] object.
692
+ runtime (`SpaceRuntime`, *optional*):
693
+ Space runtime information as a [`huggingface_hub.hf_api.SpaceRuntime`] object.
694
+ sdk (`str`, *optional*):
695
+ SDK used by the Space.
696
+ models (`List[str]`, *optional*):
697
+ List of models used by the Space.
698
+ datasets (`List[str]`, *optional*):
699
+ List of datasets used by the Space.
564
700
  """
565
701
 
566
- def __init__(
567
- self,
568
- *,
569
- id: Optional[str] = None,
570
- sha: Optional[str] = None,
571
- lastModified: Optional[str] = None,
572
- siblings: Optional[List[Dict]] = None,
573
- private: bool = False,
574
- author: Optional[str] = None,
575
- **kwargs,
576
- ):
577
- self.id = id
578
- self.sha = sha
579
- self.lastModified = lastModified
580
- self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else []
581
- self.private = private
582
- self.author = author
583
- # Store all the other fields returned by the API
584
- super().__init__(**kwargs)
702
+ id: str
703
+ author: Optional[str]
704
+ sha: Optional[str]
705
+ last_modified: Optional[datetime]
706
+ private: bool
707
+ gated: Optional[bool]
708
+ disabled: Optional[bool]
709
+ host: Optional[str]
710
+ subdomain: Optional[str]
711
+ likes: int
712
+ sdk: Optional[str]
713
+ tags: List[str]
714
+ siblings: Optional[List[RepoSibling]]
715
+ card_data: Optional[SpaceCardData]
716
+ runtime: Optional[SpaceRuntime]
717
+ models: Optional[List[str]]
718
+ datasets: Optional[List[str]]
719
+
720
+ def __init__(self, **kwargs):
721
+ self.id = kwargs.pop("id")
722
+ self.author = kwargs.pop("author", None)
723
+ self.sha = kwargs.pop("sha", None)
724
+ last_modified = kwargs.pop("lastModified", None) or kwargs.pop("last_modified", None)
725
+ self.last_modified = parse_datetime(last_modified) if last_modified else None
726
+ self.private = kwargs.pop("private")
727
+ self.gated = kwargs.pop("gated", None)
728
+ self.disabled = kwargs.pop("disabled", None)
729
+ self.host = kwargs.pop("host", None)
730
+ self.subdomain = kwargs.pop("subdomain", None)
731
+ self.likes = kwargs.pop("likes")
732
+ self.sdk = kwargs.pop("sdk", None)
733
+ self.tags = kwargs.pop("tags")
734
+ card_data = kwargs.pop("cardData", None) or kwargs.pop("card_data", None)
735
+ self.card_data = (
736
+ SpaceCardData(**card_data, ignore_metadata_errors=True) if isinstance(card_data, dict) else card_data
737
+ )
738
+ siblings = kwargs.pop("siblings", None)
739
+ self.siblings = (
740
+ [
741
+ RepoSibling(
742
+ rfilename=sibling["rfilename"],
743
+ size=sibling.get("size"),
744
+ blob_id=sibling.get("blobId"),
745
+ lfs=(
746
+ BlobLfsInfo(
747
+ size=sibling["lfs"]["size"],
748
+ sha256=sibling["lfs"]["sha256"],
749
+ pointer_size=sibling["lfs"]["pointerSize"],
750
+ )
751
+ if sibling.get("lfs")
752
+ else None
753
+ ),
754
+ )
755
+ for sibling in siblings
756
+ ]
757
+ if siblings
758
+ else None
759
+ )
760
+ runtime = kwargs.pop("runtime", None)
761
+ self.runtime = SpaceRuntime(runtime) if runtime else None
762
+ self.models = kwargs.pop("models", None)
763
+ self.datasets = kwargs.pop("datasets", None)
764
+
765
+ # backwards compatibility
766
+ self.lastModified = self.last_modified
767
+ self.cardData = self.card_data
768
+ self.__dict__.update(**kwargs)
585
769
 
586
770
 
587
- class MetricInfo(ReprMixin):
771
+ @dataclass
772
+ class MetricInfo:
588
773
  """
589
- Info about a public metric accessible from huggingface.co
774
+ Contains information about a metric on the Hub.
775
+
776
+ Attributes:
777
+ id (`str`):
778
+ ID of the metric. E.g. `"accuracy"`.
779
+ space_id (`str`):
780
+ ID of the space associated with the metric. E.g. `"Accuracy"`.
781
+ description (`str`):
782
+ Description of the metric.
590
783
  """
591
784
 
592
- def __init__(
593
- self,
594
- *,
595
- id: Optional[str] = None, # id of metric
596
- description: Optional[str] = None,
597
- citation: Optional[str] = None,
598
- **kwargs,
599
- ):
600
- self.id = id
601
- self.description = description
602
- self.citation = citation
603
- # Legacy stuff, "key" is always returned with an empty string
604
- # because of old versions of the datasets lib that need this field
605
- kwargs.pop("key", None)
606
- # Store all the other fields returned by the API
607
- super().__init__(**kwargs)
785
+ id: str
786
+ space_id: str
787
+ description: Optional[str]
608
788
 
609
- def __str__(self):
610
- r = f"Metric Name: {self.id}"
611
- return r
789
+ def __init__(self, **kwargs):
790
+ self.id = kwargs.pop("id")
791
+ self.space_id = kwargs.pop("spaceId")
792
+ self.description = kwargs.pop("description", None)
793
+ # backwards compatibility
794
+ self.spaceId = self.space_id
795
+ self.__dict__.update(**kwargs)
612
796
 
613
797
 
614
- class CollectionItem(ReprMixin):
615
- """Contains information about an item of a Collection (model, dataset, Space or paper).
798
+ @dataclass
799
+ class CollectionItem:
800
+ """
801
+ Contains information about an item of a Collection (model, dataset, Space or paper).
616
802
 
617
- Args:
803
+ Attributes:
618
804
  item_object_id (`str`):
619
805
  Unique ID of the item in the collection.
620
806
  item_id (`str`):
@@ -626,11 +812,14 @@ class CollectionItem(ReprMixin):
626
812
  Position of the item in the collection.
627
813
  note (`str`, *optional*):
628
814
  Note associated with the item, as plain text.
629
- kwargs (`Dict`, *optional*):
630
- Any other attribute returned by the server. Those attributes depend on the `item_type`: "author", "private",
631
- "lastModified", "gated", "title", "likes", "upvotes", etc.
632
815
  """
633
816
 
817
+ item_object_id: str # id in database
818
+ item_id: str # repo_id or paper id
819
+ item_type: str
820
+ position: int
821
+ note: Optional[str] = None
822
+
634
823
  def __init__(
635
824
  self, _id: str, id: str, type: CollectionItemType_T, position: int, note: Optional[Dict] = None, **kwargs
636
825
  ) -> None:
@@ -640,23 +829,19 @@ class CollectionItem(ReprMixin):
640
829
  self.position: int = position
641
830
  self.note: str = note["text"] if note is not None else None
642
831
 
643
- # Store all the other fields returned by the API
644
- super().__init__(**kwargs)
645
832
 
646
-
647
- class Collection(ReprMixin):
833
+ @dataclass
834
+ class Collection:
648
835
  """
649
836
  Contains information about a Collection on the Hub.
650
837
 
651
- Args:
838
+ Attributes:
652
839
  slug (`str`):
653
840
  Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
654
841
  title (`str`):
655
842
  Title of the collection. E.g. `"Recent models"`.
656
843
  owner (`str`):
657
844
  Owner of the collection. E.g. `"TheBloke"`.
658
- description (`str`, *optional*):
659
- Description of the collection, as plain text.
660
845
  items (`List[CollectionItem]`):
661
846
  List of items in the collection.
662
847
  last_updated (`datetime`):
@@ -667,139 +852,45 @@ class Collection(ReprMixin):
667
852
  Whether the collection is private or not.
668
853
  theme (`str`):
669
854
  Theme of the collection. E.g. `"green"`.
855
+ upvotes (`int`):
856
+ Number of upvotes of the collection.
857
+ description (`str`, *optional*):
858
+ Description of the collection, as plain text.
670
859
  url (`str`):
671
- URL for the collection on the Hub.
860
+ (property) URL of the collection on the Hub.
672
861
  """
673
862
 
674
863
  slug: str
675
864
  title: str
676
865
  owner: str
677
- description: Optional[str]
678
866
  items: List[CollectionItem]
679
-
680
867
  last_updated: datetime
681
868
  position: int
682
869
  private: bool
683
870
  theme: str
871
+ upvotes: int
872
+ description: Optional[str] = None
684
873
 
685
- def __init__(self, data: Dict, endpoint: Optional[str] = None) -> None:
686
- # Collection info
687
- self.slug = data["slug"]
688
- self.title = data["title"]
689
- self.owner = data["owner"]["name"]
690
- self.description = data.get("description")
691
- self.items = [CollectionItem(**item) for item in data["items"]]
692
-
693
- # Metadata
694
- self.last_updated = parse_datetime(data["lastUpdated"])
695
- self.private = data["private"]
696
- self.position = data["position"]
697
- self.theme = data["theme"]
698
-
699
- # (internal)
874
+ def __init__(self, **kwargs) -> None:
875
+ self.slug = kwargs.pop("slug")
876
+ self.title = kwargs.pop("title")
877
+ self.owner = kwargs.pop("owner")
878
+ self.items = [CollectionItem(**item) for item in kwargs.pop("items")]
879
+ self.last_updated = parse_datetime(kwargs.pop("lastUpdated"))
880
+ self.position = kwargs.pop("position")
881
+ self.private = kwargs.pop("private")
882
+ self.theme = kwargs.pop("theme")
883
+ self.upvotes = kwargs.pop("upvotes")
884
+ self.description = kwargs.pop("description", None)
885
+ endpoint = kwargs.pop("endpoint", None)
700
886
  if endpoint is None:
701
887
  endpoint = ENDPOINT
702
- self.url = f"{ENDPOINT}/collections/{self.slug}"
703
-
704
-
705
- class ModelSearchArguments(AttributeDictionary):
706
- """
707
- A nested namespace object holding all possible values for properties of
708
- models currently hosted in the Hub with tab-completion. If a value starts
709
- with a number, it will only exist in the dictionary
710
-
711
- Example:
712
-
713
- ```python
714
- >>> args = ModelSearchArguments()
715
-
716
- >>> args.author.huggingface
717
- 'huggingface'
718
-
719
- >>> args.language.en
720
- 'en'
721
- ```
722
-
723
- <Tip warning={true}>
724
-
725
- `ModelSearchArguments` is a legacy class meant for exploratory purposes only. Its
726
- initialization requires listing all models on the Hub which makes it increasingly
727
- slower as the number of repos on the Hub increases.
728
-
729
- </Tip>
730
- """
731
-
732
- def __init__(self, api: Optional["HfApi"] = None):
733
- self._api = api if api is not None else HfApi()
734
- tags = self._api.get_model_tags()
735
- super().__init__(tags)
736
- self._process_models()
737
-
738
- def _process_models(self):
739
- def clean(s: str) -> str:
740
- return s.replace(" ", "").replace("-", "_").replace(".", "_")
741
-
742
- models = self._api.list_models()
743
- author_dict, model_name_dict = AttributeDictionary(), AttributeDictionary()
744
- for model in models:
745
- if "/" in model.modelId:
746
- author, name = model.modelId.split("/")
747
- author_dict[author] = clean(author)
748
- else:
749
- name = model.modelId
750
- model_name_dict[name] = clean(name)
751
- self["model_name"] = model_name_dict
752
- self["author"] = author_dict
753
-
754
-
755
- class DatasetSearchArguments(AttributeDictionary):
756
- """
757
- A nested namespace object holding all possible values for properties of
758
- datasets currently hosted in the Hub with tab-completion. If a value starts
759
- with a number, it will only exist in the dictionary
760
-
761
- Example:
888
+ self._url = f"{endpoint}/collections/{self.slug}"
762
889
 
763
- ```python
764
- >>> args = DatasetSearchArguments()
765
-
766
- >>> args.author.huggingface
767
- 'huggingface'
768
-
769
- >>> args.language.en
770
- 'language:en'
771
- ```
772
-
773
- <Tip warning={true}>
774
-
775
- `DatasetSearchArguments` is a legacy class meant for exploratory purposes only. Its
776
- initialization requires listing all datasets on the Hub which makes it increasingly
777
- slower as the number of repos on the Hub increases.
778
-
779
- </Tip>
780
- """
781
-
782
- def __init__(self, api: Optional["HfApi"] = None):
783
- self._api = api if api is not None else HfApi()
784
- tags = self._api.get_dataset_tags()
785
- super().__init__(tags)
786
- self._process_models()
787
-
788
- def _process_models(self):
789
- def clean(s: str):
790
- return s.replace(" ", "").replace("-", "_").replace(".", "_")
791
-
792
- datasets = self._api.list_datasets()
793
- author_dict, dataset_name_dict = AttributeDictionary(), AttributeDictionary()
794
- for dataset in datasets:
795
- if "/" in dataset.id:
796
- author, name = dataset.id.split("/")
797
- author_dict[author] = clean(author)
798
- else:
799
- name = dataset.id
800
- dataset_name_dict[name] = clean(name)
801
- self["dataset_name"] = dataset_name_dict
802
- self["author"] = author_dict
890
+ @property
891
+ def url(self) -> str:
892
+ """Returns the URL of the collection on the Hub."""
893
+ return self._url
803
894
 
804
895
 
805
896
  @dataclass
@@ -807,7 +898,7 @@ class GitRefInfo:
807
898
  """
808
899
  Contains information about a git reference for a repo on the Hub.
809
900
 
810
- Args:
901
+ Attributes:
811
902
  name (`str`):
812
903
  Name of the reference (e.g. tag name or branch name).
813
904
  ref (`str`):
@@ -820,11 +911,6 @@ class GitRefInfo:
820
911
  ref: str
821
912
  target_commit: str
822
913
 
823
- def __init__(self, data: Dict) -> None:
824
- self.name = data["name"]
825
- self.ref = data["ref"]
826
- self.target_commit = data["targetCommit"]
827
-
828
914
 
829
915
  @dataclass
830
916
  class GitRefs:
@@ -833,7 +919,7 @@ class GitRefs:
833
919
 
834
920
  Object is returned by [`list_repo_refs`].
835
921
 
836
- Args:
922
+ Attributes:
837
923
  branches (`List[GitRefInfo]`):
838
924
  A list of [`GitRefInfo`] containing information about branches on the repo.
839
925
  converts (`List[GitRefInfo]`):
@@ -853,7 +939,7 @@ class GitCommitInfo:
853
939
  """
854
940
  Contains information about a git commit for a repo on the Hub. Check out [`list_repo_commits`] for more details.
855
941
 
856
- Args:
942
+ Attributes:
857
943
  commit_id (`str`):
858
944
  OID of the commit (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`)
859
945
  authors (`List[str]`):
@@ -880,23 +966,13 @@ class GitCommitInfo:
880
966
  formatted_title: Optional[str]
881
967
  formatted_message: Optional[str]
882
968
 
883
- def __init__(self, data: Dict) -> None:
884
- self.commit_id = data["id"]
885
- self.authors = [author["user"] for author in data["authors"]]
886
- self.created_at = parse_datetime(data["date"])
887
- self.title = data["title"]
888
- self.message = data["message"]
889
-
890
- self.formatted_title = data.get("formatted", {}).get("title")
891
- self.formatted_message = data.get("formatted", {}).get("message")
892
-
893
969
 
894
970
  @dataclass
895
971
  class UserLikes:
896
972
  """
897
973
  Contains information about a user likes on the Hub.
898
974
 
899
- Args:
975
+ Attributes:
900
976
  user (`str`):
901
977
  Name of the user for which we fetched the likes.
902
978
  total (`int`):
@@ -924,7 +1000,7 @@ class User:
924
1000
  """
925
1001
  Contains information about a user on the Hub.
926
1002
 
927
- Args:
1003
+ Attributes:
928
1004
  avatar_url (`str`):
929
1005
  URL of the user's avatar.
930
1006
  username (`str`):
@@ -989,9 +1065,6 @@ class HfApi:
989
1065
  directly at the root of `huggingface_hub`.
990
1066
 
991
1067
  Args:
992
- endpoint (`str`, *optional*):
993
- Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise,
994
- one can set the `HF_ENDPOINT` environment variable.
995
1068
  token (`str`, *optional*):
996
1069
  Hugging Face token. Will default to the locally saved token if
997
1070
  not provided.
@@ -1101,25 +1174,23 @@ class HfApi:
1101
1174
  except (LocalTokenNotFoundError, HTTPError):
1102
1175
  return None
1103
1176
 
1104
- def get_model_tags(self) -> ModelTags:
1177
+ def get_model_tags(self) -> Dict:
1105
1178
  """
1106
1179
  List all valid model tags as a nested namespace object
1107
1180
  """
1108
1181
  path = f"{self.endpoint}/api/models-tags-by-type"
1109
1182
  r = get_session().get(path)
1110
1183
  hf_raise_for_status(r)
1111
- d = r.json()
1112
- return ModelTags(d)
1184
+ return r.json()
1113
1185
 
1114
- def get_dataset_tags(self) -> DatasetTags:
1186
+ def get_dataset_tags(self) -> Dict:
1115
1187
  """
1116
1188
  List all valid dataset tags as a nested namespace object.
1117
1189
  """
1118
1190
  path = f"{self.endpoint}/api/datasets-tags-by-type"
1119
1191
  r = get_session().get(path)
1120
1192
  hf_raise_for_status(r)
1121
- d = r.json()
1122
- return DatasetTags(d)
1193
+ return r.json()
1123
1194
 
1124
1195
  @validate_hf_hub_args
1125
1196
  def list_models(
@@ -1129,7 +1200,7 @@ class HfApi:
1129
1200
  author: Optional[str] = None,
1130
1201
  search: Optional[str] = None,
1131
1202
  emissions_thresholds: Optional[Tuple[float, float]] = None,
1132
- sort: Union[Literal["lastModified"], str, None] = None,
1203
+ sort: Union[Literal["last_modified"], str, None] = None,
1133
1204
  direction: Optional[Literal[-1]] = None,
1134
1205
  limit: Optional[int] = None,
1135
1206
  full: Optional[bool] = None,
@@ -1152,7 +1223,7 @@ class HfApi:
1152
1223
  emissions_thresholds (`Tuple`, *optional*):
1153
1224
  A tuple of two ints or floats representing a minimum and maximum
1154
1225
  carbon footprint to filter the resulting models with in grams.
1155
- sort (`Literal["lastModified"]` or `str`, *optional*):
1226
+ sort (`Literal["last_modified"]` or `str`, *optional*):
1156
1227
  The key with which to sort the resulting models. Possible values
1157
1228
  are the properties of the [`huggingface_hub.hf_api.ModelInfo`] class.
1158
1229
  direction (`Literal[-1]` or `int`, *optional*):
@@ -1162,7 +1233,7 @@ class HfApi:
1162
1233
  The limit on the number of models fetched. Leaving this option
1163
1234
  to `None` fetches all models.
1164
1235
  full (`bool`, *optional*):
1165
- Whether to fetch all model data, including the `lastModified`,
1236
+ Whether to fetch all model data, including the `last_modified`,
1166
1237
  the `sha`, the files and the `tags`. This is set to `True` by
1167
1238
  default when using a filter.
1168
1239
  cardData (`bool`, *optional*):
@@ -1191,29 +1262,15 @@ class HfApi:
1191
1262
  >>> # List all models
1192
1263
  >>> api.list_models()
1193
1264
 
1194
- >>> # Get all valid search arguments
1195
- >>> args = ModelSearchArguments()
1196
-
1197
1265
  >>> # List only the text classification models
1198
1266
  >>> api.list_models(filter="text-classification")
1199
1267
  >>> # Using the `ModelFilter`
1200
1268
  >>> filt = ModelFilter(task="text-classification")
1201
- >>> # With `ModelSearchArguments`
1202
- >>> filt = ModelFilter(task=args.pipeline_tags.TextClassification)
1203
- >>> api.list_models(filter=filt)
1204
-
1205
- >>> # Using `ModelFilter` and `ModelSearchArguments` to find text classification in both PyTorch and TensorFlow
1206
- >>> filt = ModelFilter(
1207
- ... task=args.pipeline_tags.TextClassification,
1208
- ... library=[args.library.PyTorch, args.library.TensorFlow],
1209
- ... )
1210
- >>> api.list_models(filter=filt)
1269
+
1211
1270
 
1212
1271
  >>> # List only models from the AllenNLP library
1213
1272
  >>> api.list_models(filter="allennlp")
1214
- >>> # Using `ModelFilter` and `ModelSearchArguments`
1215
- >>> filt = ModelFilter(library=args.library.allennlp)
1216
- ```
1273
+
1217
1274
 
1218
1275
  Example usage with the `search` argument:
1219
1276
 
@@ -1246,7 +1303,7 @@ class HfApi:
1246
1303
  if search is not None:
1247
1304
  params.update({"search": search})
1248
1305
  if sort is not None:
1249
- params.update({"sort": sort})
1306
+ params.update({"sort": "lastModified" if sort == "last_modified" else sort})
1250
1307
  if direction is not None:
1251
1308
  params.update({"direction": direction})
1252
1309
  if limit is not None:
@@ -1265,10 +1322,12 @@ class HfApi:
1265
1322
  items = paginate(path, params=params, headers=headers)
1266
1323
  if limit is not None:
1267
1324
  items = islice(items, limit) # Do not iterate over all pages
1268
- if emissions_thresholds is not None:
1269
- items = _filter_emissions(items, *emissions_thresholds)
1270
1325
  for item in items:
1271
- yield ModelInfo(**item)
1326
+ if "siblings" not in item:
1327
+ item["siblings"] = None
1328
+ model_info = ModelInfo(**item)
1329
+ if emissions_thresholds is None or _is_emission_within_treshold(model_info, *emissions_thresholds):
1330
+ yield model_info
1272
1331
 
1273
1332
  def _unpack_model_filter(self, model_filter: ModelFilter):
1274
1333
  """
@@ -1326,7 +1385,7 @@ class HfApi:
1326
1385
  filter: Union[DatasetFilter, str, Iterable[str], None] = None,
1327
1386
  author: Optional[str] = None,
1328
1387
  search: Optional[str] = None,
1329
- sort: Union[Literal["lastModified"], str, None] = None,
1388
+ sort: Union[Literal["last_modified"], str, None] = None,
1330
1389
  direction: Optional[Literal[-1]] = None,
1331
1390
  limit: Optional[int] = None,
1332
1391
  full: Optional[bool] = None,
@@ -1343,7 +1402,7 @@ class HfApi:
1343
1402
  A string which identify the author of the returned datasets.
1344
1403
  search (`str`, *optional*):
1345
1404
  A string that will be contained in the returned datasets.
1346
- sort (`Literal["lastModified"]` or `str`, *optional*):
1405
+ sort (`Literal["last_modified"]` or `str`, *optional*):
1347
1406
  The key with which to sort the resulting datasets. Possible
1348
1407
  values are the properties of the [`huggingface_hub.hf_api.DatasetInfo`] class.
1349
1408
  direction (`Literal[-1]` or `int`, *optional*):
@@ -1353,8 +1412,8 @@ class HfApi:
1353
1412
  The limit on the number of datasets fetched. Leaving this option
1354
1413
  to `None` fetches all datasets.
1355
1414
  full (`bool`, *optional*):
1356
- Whether to fetch all dataset data, including the `lastModified`
1357
- and the `cardData`. Can contain useful information such as the
1415
+ Whether to fetch all dataset data, including the `last_modified`,
1416
+ the `card_data` and the files. Can contain useful information such as the
1358
1417
  PapersWithCode ID.
1359
1418
  token (`bool` or `str`, *optional*):
1360
1419
  A valid authentication token (see https://huggingface.co/settings/token).
@@ -1375,16 +1434,12 @@ class HfApi:
1375
1434
  >>> # List all datasets
1376
1435
  >>> api.list_datasets()
1377
1436
 
1378
- >>> # Get all valid search arguments
1379
- >>> args = DatasetSearchArguments()
1380
1437
 
1381
1438
  >>> # List only the text classification datasets
1382
1439
  >>> api.list_datasets(filter="task_categories:text-classification")
1383
1440
  >>> # Using the `DatasetFilter`
1384
1441
  >>> filt = DatasetFilter(task_categories="text-classification")
1385
- >>> # With `DatasetSearchArguments`
1386
- >>> filt = DatasetFilter(task=args.task_categories.text_classification)
1387
- >>> api.list_models(filter=filt)
1442
+
1388
1443
 
1389
1444
  >>> # List only the datasets in russian for language modeling
1390
1445
  >>> api.list_datasets(
@@ -1392,11 +1447,7 @@ class HfApi:
1392
1447
  ... )
1393
1448
  >>> # Using the `DatasetFilter`
1394
1449
  >>> filt = DatasetFilter(language="ru", task_ids="language-modeling")
1395
- >>> # With `DatasetSearchArguments`
1396
- >>> filt = DatasetFilter(
1397
- ... language=args.language.ru,
1398
- ... task_ids=args.task_ids.language_modeling,
1399
- ... )
1450
+
1400
1451
  >>> api.list_datasets(filter=filt)
1401
1452
  ```
1402
1453
 
@@ -1427,7 +1478,7 @@ class HfApi:
1427
1478
  if search is not None:
1428
1479
  params.update({"search": search})
1429
1480
  if sort is not None:
1430
- params.update({"sort": sort})
1481
+ params.update({"sort": "lastModified" if sort == "last_modified" else sort})
1431
1482
  if direction is not None:
1432
1483
  params.update({"direction": direction})
1433
1484
  if limit is not None:
@@ -1439,6 +1490,8 @@ class HfApi:
1439
1490
  if limit is not None:
1440
1491
  items = islice(items, limit) # Do not iterate over all pages
1441
1492
  for item in items:
1493
+ if "siblings" not in item:
1494
+ item["siblings"] = None
1442
1495
  yield DatasetInfo(**item)
1443
1496
 
1444
1497
  def _unpack_dataset_filter(self, dataset_filter: DatasetFilter):
@@ -1502,7 +1555,7 @@ class HfApi:
1502
1555
  filter: Union[str, Iterable[str], None] = None,
1503
1556
  author: Optional[str] = None,
1504
1557
  search: Optional[str] = None,
1505
- sort: Union[Literal["lastModified"], str, None] = None,
1558
+ sort: Union[Literal["last_modified"], str, None] = None,
1506
1559
  direction: Optional[Literal[-1]] = None,
1507
1560
  limit: Optional[int] = None,
1508
1561
  datasets: Union[str, Iterable[str], None] = None,
@@ -1521,7 +1574,7 @@ class HfApi:
1521
1574
  A string which identify the author of the returned Spaces.
1522
1575
  search (`str`, *optional*):
1523
1576
  A string that will be contained in the returned Spaces.
1524
- sort (`Literal["lastModified"]` or `str`, *optional*):
1577
+ sort (`Literal["last_modified"]` or `str`, *optional*):
1525
1578
  The key with which to sort the resulting Spaces. Possible
1526
1579
  values are the properties of the [`huggingface_hub.hf_api.SpaceInfo`]` class.
1527
1580
  direction (`Literal[-1]` or `int`, *optional*):
@@ -1539,8 +1592,8 @@ class HfApi:
1539
1592
  linked (`bool`, *optional*):
1540
1593
  Whether to return Spaces that make use of either a model or a dataset.
1541
1594
  full (`bool`, *optional*):
1542
- Whether to fetch all Spaces data, including the `lastModified`
1543
- and the `cardData`.
1595
+ Whether to fetch all Spaces data, including the `last_modified`, `siblings`
1596
+ and `card_data` fields.
1544
1597
  token (`bool` or `str`, *optional*):
1545
1598
  A valid authentication token (see https://huggingface.co/settings/token).
1546
1599
  If `None` or `True` and machine is logged in (through `huggingface-cli login`
@@ -1560,7 +1613,7 @@ class HfApi:
1560
1613
  if search is not None:
1561
1614
  params.update({"search": search})
1562
1615
  if sort is not None:
1563
- params.update({"sort": sort})
1616
+ params.update({"sort": "lastModified" if sort == "last_modified" else sort})
1564
1617
  if direction is not None:
1565
1618
  params.update({"direction": direction})
1566
1619
  if limit is not None:
@@ -1578,6 +1631,8 @@ class HfApi:
1578
1631
  if limit is not None:
1579
1632
  items = islice(items, limit) # Do not iterate over all pages
1580
1633
  for item in items:
1634
+ if "siblings" not in item:
1635
+ item["siblings"] = None
1581
1636
  yield SpaceInfo(**item)
1582
1637
 
1583
1638
  @validate_hf_hub_args
@@ -1865,8 +1920,8 @@ class HfApi:
1865
1920
  params["blobs"] = True
1866
1921
  r = get_session().get(path, headers=headers, timeout=timeout, params=params)
1867
1922
  hf_raise_for_status(r)
1868
- d = r.json()
1869
- return ModelInfo(**d)
1923
+ data = r.json()
1924
+ return ModelInfo(**data)
1870
1925
 
1871
1926
  @validate_hf_hub_args
1872
1927
  def dataset_info(
@@ -1928,8 +1983,8 @@ class HfApi:
1928
1983
 
1929
1984
  r = get_session().get(path, headers=headers, timeout=timeout, params=params)
1930
1985
  hf_raise_for_status(r)
1931
- d = r.json()
1932
- return DatasetInfo(**d)
1986
+ data = r.json()
1987
+ return DatasetInfo(**data)
1933
1988
 
1934
1989
  @validate_hf_hub_args
1935
1990
  def space_info(
@@ -1991,8 +2046,8 @@ class HfApi:
1991
2046
 
1992
2047
  r = get_session().get(path, headers=headers, timeout=timeout, params=params)
1993
2048
  hf_raise_for_status(r)
1994
- d = r.json()
1995
- return SpaceInfo(**d)
2049
+ data = r.json()
2050
+ return SpaceInfo(**data)
1996
2051
 
1997
2052
  @validate_hf_hub_args
1998
2053
  def repo_info(
@@ -2231,8 +2286,8 @@ class HfApi:
2231
2286
  <generator object HfApi.list_files_info at 0x7f93b848e730>
2232
2287
  >>> list(files_info)
2233
2288
  [
2234
- RepoFile: {"blob_id": "43bd404b159de6fba7c2f4d3264347668d43af25", "lfs": None, "rfilename": "README.md", "size": 391},
2235
- RepoFile: {"blob_id": "2f9618c3a19b9a61add74f70bfb121335aeef666", "lfs": None, "rfilename": "config.json", "size": 554},
2289
+ RepoFile(path='README.md', size=391, blob_id='43bd404b159de6fba7c2f4d3264347668d43af25', lfs=None, last_commit=None, security=None),
2290
+ RepoFile(path='config.json', size=554, blob_id='2f9618c3a19b9a61add74f70bfb121335aeef666', lfs=None, last_commit=None, security=None)
2236
2291
  ]
2237
2292
  ```
2238
2293
 
@@ -2242,44 +2297,56 @@ class HfApi:
2242
2297
  >>> files_info = list_files_info("prompthero/openjourney-v4", expand=True)
2243
2298
  >>> list(files_info)
2244
2299
  [
2245
- RepoFile: {
2246
- {'blob_id': '815004af1a321eaed1d93f850b2e94b0c0678e42',
2247
- 'lastCommit': {'date': '2023-03-21T09:05:27.000Z',
2248
- 'id': '47b62b20b20e06b9de610e840282b7e6c3d51190',
2249
- 'title': 'Upload diffusers weights (#48)'},
2250
- 'lfs': None,
2251
- 'rfilename': 'model_index.json',
2252
- 'security': {'avScan': {'virusFound': False, 'virusNames': None},
2253
- 'blobId': '815004af1a321eaed1d93f850b2e94b0c0678e42',
2254
- 'name': 'model_index.json',
2255
- 'pickleImportScan': None,
2256
- 'repositoryId': 'models/prompthero/openjourney-v4',
2257
- 'safe': True},
2258
- 'size': 584}
2259
- },
2260
- RepoFile: {
2261
- {'blob_id': 'd2343d78b33ac03dade1d525538b02b130d0a3a0',
2262
- 'lastCommit': {'date': '2023-03-21T09:05:27.000Z',
2263
- 'id': '47b62b20b20e06b9de610e840282b7e6c3d51190',
2264
- 'title': 'Upload diffusers weights (#48)'},
2265
- 'lfs': {'pointer_size': 134,
2266
- 'sha256': 'dcf4507d99b88db73f3916e2a20169fe74ada6b5582e9af56cfa80f5f3141765',
2267
- 'size': 334711857},
2268
- 'rfilename': 'vae/diffusion_pytorch_model.bin',
2269
- 'security': {'avScan': {'virusFound': False, 'virusNames': None},
2270
- 'blobId': 'd2343d78b33ac03dade1d525538b02b130d0a3a0',
2271
- 'name': 'vae/diffusion_pytorch_model.bin',
2272
- 'pickleImportScan': {'highestSafetyLevel': 'innocuous',
2273
- 'imports': [{'module': 'torch._utils',
2274
- 'name': '_rebuild_tensor_v2',
2275
- 'safety': 'innocuous'},
2276
- {'module': 'collections', 'name': 'OrderedDict', 'safety': 'innocuous'},
2277
- {'module': 'torch', 'name': 'FloatStorage', 'safety': 'innocuous'}]},
2278
- 'repositoryId': 'models/prompthero/openjourney-v4',
2279
- 'safe': True},
2280
- 'size': 334711857}
2281
- },
2282
- (...)
2300
+ RepoFile(
2301
+ path='safety_checker/pytorch_model.bin',
2302
+ size=1216064769,
2303
+ blob_id='c8835557a0d3af583cb06c7c154b7e54a069c41d',
2304
+ lfs={
2305
+ 'size': 1216064769,
2306
+ 'sha256': '16d28f2b37109f222cdc33620fdd262102ac32112be0352a7f77e9614b35a394',
2307
+ 'pointer_size': 135
2308
+ },
2309
+ last_commit={
2310
+ 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190',
2311
+ 'title': 'Upload diffusers weights (#48)',
2312
+ 'date': datetime.datetime(2023, 3, 21, 10, 5, 27, tzinfo=datetime.timezone.utc)
2313
+ },
2314
+ security={
2315
+ 'safe': True,
2316
+ 'av_scan': {
2317
+ 'virusFound': False,
2318
+ 'virusNames': None
2319
+ },
2320
+ 'pickle_import_scan': {
2321
+ 'highestSafetyLevel': 'innocuous',
2322
+ 'imports': [
2323
+ {'module': 'torch', 'name': 'FloatStorage', 'safety': 'innocuous'},
2324
+ {'module': 'collections', 'name': 'OrderedDict', 'safety': 'innocuous'},
2325
+ {'module': 'torch', 'name': 'LongStorage', 'safety': 'innocuous'},
2326
+ {'module': 'torch._utils', 'name': '_rebuild_tensor_v2', 'safety': 'innocuous'}
2327
+ ]
2328
+ }
2329
+ }
2330
+ ),
2331
+ RepoFile(
2332
+ path='scheduler/scheduler_config.json',
2333
+ size=465,
2334
+ blob_id='70d55e3e9679f41cbc66222831b38d5abce683dd',
2335
+ lfs=None,
2336
+ last_commit={
2337
+ 'oid': '47b62b20b20e06b9de610e840282b7e6c3d51190',
2338
+ 'title': 'Upload diffusers weights (#48)',
2339
+ 'date': datetime.datetime(2023, 3, 21, 10, 5, 27, tzinfo=datetime.timezone.utc)},
2340
+ security={
2341
+ 'safe': True,
2342
+ 'av_scan': {
2343
+ 'virusFound': False,
2344
+ 'virusNames': None
2345
+ },
2346
+ 'pickle_import_scan': None
2347
+ }
2348
+ ),
2349
+ ...
2283
2350
  ]
2284
2351
  ```
2285
2352
 
@@ -2287,14 +2354,14 @@ class HfApi:
2287
2354
 
2288
2355
  ```py
2289
2356
  >>> from huggingface_hub import list_files_info
2290
- >>> [info.rfilename for info in list_files_info("stabilityai/stable-diffusion-2", "vae") if info.lfs is not None]
2357
+ >>> [info.path for info in list_files_info("stabilityai/stable-diffusion-2", "vae") if info.lfs is not None]
2291
2358
  ['vae/diffusion_pytorch_model.bin', 'vae/diffusion_pytorch_model.safetensors']
2292
2359
  ```
2293
2360
 
2294
2361
  List all files on a repo.
2295
2362
  ```py
2296
2363
  >>> from huggingface_hub import list_files_info
2297
- >>> [info.rfilename for info in list_files_info("glue", repo_type="dataset")]
2364
+ >>> [info.path for info in list_files_info("glue", repo_type="dataset")]
2298
2365
  ['.gitattributes', 'README.md', 'dataset_infos.json', 'glue.py']
2299
2366
  ```
2300
2367
  """
@@ -2305,14 +2372,24 @@ class HfApi:
2305
2372
  def _format_as_repo_file(info: Dict) -> RepoFile:
2306
2373
  # Quick alias very specific to the server return type of /paths-info and /tree endpoints. Let's keep this
2307
2374
  # logic here.
2308
- rfilename = info.pop("path")
2375
+ path = info.pop("path")
2309
2376
  size = info.pop("size")
2310
- blobId = info.pop("oid")
2377
+ blob_id = info.pop("oid")
2311
2378
  lfs = info.pop("lfs", None)
2379
+ last_commit = info.pop("lastCommit", None)
2380
+ security = info.pop("security", None)
2312
2381
  info.pop("type", None) # "file" or "folder" -> not needed in practice since we know it's a file
2382
+ if last_commit is not None:
2383
+ last_commit = BlobLastCommitInfo(
2384
+ oid=last_commit["id"], title=last_commit["title"], date=parse_datetime(last_commit["date"])
2385
+ )
2386
+ if security is not None:
2387
+ security = BlobSecurityInfo(
2388
+ safe=security["safe"], av_scan=security["avScan"], pickle_import_scan=security["pickleImportScan"]
2389
+ )
2313
2390
  if lfs is not None:
2314
2391
  lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"])
2315
- return RepoFile(rfilename=rfilename, size=size, blobId=blobId, lfs=lfs, **info)
2392
+ return RepoFile(path=path, size=size, blob_id=blob_id, lfs=lfs, last_commit=last_commit, security=security)
2316
2393
 
2317
2394
  folder_paths = []
2318
2395
  if paths is None:
@@ -2327,7 +2404,7 @@ class HfApi:
2327
2404
  f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}",
2328
2405
  data={
2329
2406
  "paths": paths if isinstance(paths, list) else [paths],
2330
- "expand": True,
2407
+ "expand": expand,
2331
2408
  },
2332
2409
  headers=headers,
2333
2410
  )
@@ -2440,10 +2517,14 @@ class HfApi:
2440
2517
  )
2441
2518
  hf_raise_for_status(response)
2442
2519
  data = response.json()
2520
+
2521
+ def _format_as_git_ref_info(item: Dict) -> GitRefInfo:
2522
+ return GitRefInfo(name=item["name"], ref=item["ref"], target_commit=item["targetCommit"])
2523
+
2443
2524
  return GitRefs(
2444
- branches=[GitRefInfo(item) for item in data["branches"]],
2445
- converts=[GitRefInfo(item) for item in data["converts"]],
2446
- tags=[GitRefInfo(item) for item in data["tags"]],
2525
+ branches=[_format_as_git_ref_info(item) for item in data["branches"]],
2526
+ converts=[_format_as_git_ref_info(item) for item in data["converts"]],
2527
+ tags=[_format_as_git_ref_info(item) for item in data["tags"]],
2447
2528
  )
2448
2529
 
2449
2530
  @validate_hf_hub_args
@@ -2516,7 +2597,15 @@ class HfApi:
2516
2597
 
2517
2598
  # Paginate over results and return the list of commits.
2518
2599
  return [
2519
- GitCommitInfo(item)
2600
+ GitCommitInfo(
2601
+ commit_id=item["id"],
2602
+ authors=[author["user"] for author in item["authors"]],
2603
+ created_at=parse_datetime(item["date"]),
2604
+ title=item["title"],
2605
+ message=item["message"],
2606
+ formatted_title=item.get("formatted", {}).get("title"),
2607
+ formatted_message=item.get("formatted", {}).get("message"),
2608
+ )
2520
2609
  for item in paginate(
2521
2610
  f"{self.endpoint}/api/{repo_type}s/{repo_id}/commits/{revision}",
2522
2611
  headers=self._build_hf_headers(token=token),
@@ -4235,7 +4324,7 @@ class HfApi:
4235
4324
  force_download: bool = False,
4236
4325
  force_filename: Optional[str] = None,
4237
4326
  proxies: Optional[Dict] = None,
4238
- etag_timeout: float = 10,
4327
+ etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
4239
4328
  resume_download: bool = False,
4240
4329
  token: Optional[Union[str, bool]] = None,
4241
4330
  local_files_only: bool = False,
@@ -4299,9 +4388,6 @@ class HfApi:
4299
4388
  revision (`str`, *optional*):
4300
4389
  An optional Git revision id which can be a branch name, a tag, or a
4301
4390
  commit hash.
4302
- endpoint (`str`, *optional*):
4303
- Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
4304
- environment variable.
4305
4391
  cache_dir (`str`, `Path`, *optional*):
4306
4392
  Path to the folder where cached files are stored.
4307
4393
  local_dir (`str` or `Path`, *optional*):
@@ -4402,7 +4488,7 @@ class HfApi:
4402
4488
  local_dir: Union[str, Path, None] = None,
4403
4489
  local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
4404
4490
  proxies: Optional[Dict] = None,
4405
- etag_timeout: float = 10,
4491
+ etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
4406
4492
  resume_download: bool = False,
4407
4493
  force_download: bool = False,
4408
4494
  token: Optional[Union[str, bool]] = None,
@@ -6022,6 +6108,446 @@ class HfApi:
6022
6108
  hf_raise_for_status(r)
6023
6109
  return SpaceRuntime(r.json())
6024
6110
 
6111
+ #######################
6112
+ # Inference Endpoints #
6113
+ #######################
6114
+
6115
+ def list_inference_endpoints(
6116
+ self, namespace: Optional[str] = None, *, token: Optional[str] = None
6117
+ ) -> List[InferenceEndpoint]:
6118
+ """Lists all inference endpoints for the given namespace.
6119
+
6120
+ Args:
6121
+ namespace (`str`, *optional*):
6122
+ The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints
6123
+ from all namespaces (i.e. personal namespace and all orgs the user belongs to).
6124
+ token (`str`, *optional*):
6125
+ An authentication token (See https://huggingface.co/settings/token).
6126
+
6127
+ Returns:
6128
+ List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace.
6129
+
6130
+ Example:
6131
+ ```python
6132
+ >>> from huggingface_hub import HfApi
6133
+ >>> api = HfApi()
6134
+ >>> api.list_inference_endpoints()
6135
+ [InferenceEndpoint(name='my-endpoint', ...), ...]
6136
+ ```
6137
+ """
6138
+ # Special case: list all endpoints for all namespaces the user has access to
6139
+ if namespace == "*":
6140
+ user = self.whoami(token=token)
6141
+
6142
+ # List personal endpoints first
6143
+ endpoints: List[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token))
6144
+
6145
+ # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access)
6146
+ for org in user.get("orgs", []):
6147
+ try:
6148
+ endpoints += list_inference_endpoints(namespace=org["name"], token=token)
6149
+ except HfHubHTTPError as error:
6150
+ if error.response.status_code == 401: # Either no billing or user don't have access)
6151
+ logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error)
6152
+ pass
6153
+
6154
+ return endpoints
6155
+
6156
+ # Normal case: list endpoints for a specific namespace
6157
+ namespace = namespace or self._get_namespace(token=token)
6158
+
6159
+ response = get_session().get(
6160
+ f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}",
6161
+ headers=self._build_hf_headers(token=token),
6162
+ )
6163
+ hf_raise_for_status(response)
6164
+
6165
+ return [
6166
+ InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token)
6167
+ for endpoint in response.json()["items"]
6168
+ ]
6169
+
6170
    def create_inference_endpoint(
        self,
        name: str,
        *,
        repository: str,
        framework: str,
        accelerator: str,
        instance_size: str,
        instance_type: str,
        region: str,
        vendor: str,
        account_id: Optional[str] = None,
        min_replica: int = 0,
        max_replica: int = 1,
        revision: Optional[str] = None,
        task: Optional[str] = None,
        type: InferenceEndpointType = InferenceEndpointType.PROTECTED,
        namespace: Optional[str] = None,
        token: Optional[str] = None,
    ) -> InferenceEndpoint:
        """Create a new Inference Endpoint.

        Args:
            name (`str`):
                The unique name for the new Inference Endpoint.
            repository (`str`):
                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
            framework (`str`):
                The machine learning framework used for the model (e.g. `"custom"`).
            accelerator (`str`):
                The hardware accelerator to be used for inference (e.g. `"cpu"`).
            instance_size (`str`):
                The size or type of the instance to be used for hosting the model (e.g. `"large"`).
            instance_type (`str`):
                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`).
            region (`str`):
                The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`).
            vendor (`str`):
                The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`).
            account_id (`str`, *optional*):
                The account ID used to link a VPC to a private Inference Endpoint (if applicable).
            min_replica (`int`, *optional*):
                The minimum number of replicas (instances) to keep running for the Inference Endpoint. Defaults to 0.
            max_replica (`int`, *optional*):
                The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
            revision (`str`, *optional*):
                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
            task (`str`, *optional*):
                The task on which to deploy the model (e.g. `"text-classification"`).
            type ([`InferenceEndpointType`], *optional*):
                The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`.
            namespace (`str`, *optional*):
                The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace.
            token (`str`, *optional*):
                An authentication token (See https://huggingface.co/settings/token).

        Returns:
            [`InferenceEndpoint`]: information about the newly created Inference Endpoint.

        Example:
        ```python
        >>> from huggingface_hub import HfApi
        >>> api = HfApi()
        >>> endpoint = api.create_inference_endpoint(
        ...     "my-endpoint-name",
        ...     repository="gpt2",
        ...     framework="pytorch",
        ...     task="text-generation",
        ...     accelerator="cpu",
        ...     vendor="aws",
        ...     region="us-east-1",
        ...     type="protected",
        ...     instance_size="medium",
        ...     instance_type="c6i"
        ... )
        >>> endpoint
        InferenceEndpoint(name='my-endpoint-name', status="pending",...)

        # Run inference on the endpoint
        >>> endpoint.client.text_generation(...)
        "..."
        ```
        """
        namespace = namespace or self._get_namespace(token=token)

        # Build the server-side payload. Optional values (e.g. `account_id`, `revision`, `task`)
        # are sent as `null` when not provided — presumably treated as unset by the server (TODO confirm).
        payload: Dict = {
            "accountId": account_id,
            "compute": {
                "accelerator": accelerator,
                "instanceSize": instance_size,
                "instanceType": instance_type,
                "scaling": {
                    "maxReplica": max_replica,
                    "minReplica": min_replica,
                },
            },
            "model": {
                "framework": framework,
                "repository": repository,
                "revision": revision,
                "task": task,
                # Default Hugging Face container image (no custom image options exposed here).
                "image": {"huggingface": {}},
            },
            "name": name,
            "provider": {
                "region": region,
                "vendor": vendor,
            },
            "type": type,
        }

        response = get_session().post(
            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}",
            headers=self._build_hf_headers(token=token),
            json=payload,
        )
        hf_raise_for_status(response)

        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
6289
+
6290
+ def get_inference_endpoint(
6291
+ self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
6292
+ ) -> InferenceEndpoint:
6293
+ """Get information about an Inference Endpoint.
6294
+
6295
+ Args:
6296
+ name (`str`):
6297
+ The name of the Inference Endpoint to retrieve information about.
6298
+ namespace (`str`, *optional*):
6299
+ The namespace in which the Inference Endpoint is located. Defaults to the current user.
6300
+ token (`str`, *optional*):
6301
+ An authentication token (See https://huggingface.co/settings/token).
6302
+
6303
+ Returns:
6304
+ [`InferenceEndpoint`]: information about the requested Inference Endpoint.
6305
+
6306
+ Example:
6307
+ ```python
6308
+ >>> from huggingface_hub import HfApi
6309
+ >>> api = HfApi()
6310
+ >>> endpoint = api.get_inference_endpoint("my-text-to-image")
6311
+ >>> endpoint
6312
+ InferenceEndpoint(name='my-text-to-image', ...)
6313
+
6314
+ # Get status
6315
+ >>> endpoint.status
6316
+ 'running'
6317
+ >>> endpoint.url
6318
+ 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'
6319
+
6320
+ # Run inference
6321
+ >>> endpoint.client.text_to_image(...)
6322
+ ```
6323
+ """
6324
+ namespace = namespace or self._get_namespace(token=token)
6325
+
6326
+ response = get_session().get(
6327
+ f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
6328
+ headers=self._build_hf_headers(token=token),
6329
+ )
6330
+ hf_raise_for_status(response)
6331
+
6332
+ return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
6333
+
6334
+ def update_inference_endpoint(
6335
+ self,
6336
+ name: str,
6337
+ *,
6338
+ # Compute update
6339
+ accelerator: Optional[str] = None,
6340
+ instance_size: Optional[str] = None,
6341
+ instance_type: Optional[str] = None,
6342
+ min_replica: Optional[int] = None,
6343
+ max_replica: Optional[int] = None,
6344
+ # Model update
6345
+ repository: Optional[str] = None,
6346
+ framework: Optional[str] = None,
6347
+ revision: Optional[str] = None,
6348
+ task: Optional[str] = None,
6349
+ # Other
6350
+ namespace: Optional[str] = None,
6351
+ token: Optional[str] = None,
6352
+ ) -> InferenceEndpoint:
6353
+ """Update an Inference Endpoint.
6354
+
6355
+ This method allows the update of either the compute configuration, the deployed model, or both. All arguments are
6356
+ optional but at least one must be provided.
6357
+
6358
+ For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`].
6359
+
6360
+ Args:
6361
+ name (`str`):
6362
+ The name of the Inference Endpoint to update.
6363
+
6364
+ accelerator (`str`, *optional*):
6365
+ The hardware accelerator to be used for inference (e.g. `"cpu"`).
6366
+ instance_size (`str`, *optional*):
6367
+ The size or type of the instance to be used for hosting the model (e.g. `"large"`).
6368
+ instance_type (`str`, *optional*):
6369
+ The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`).
6370
+ min_replica (`int`, *optional*):
6371
+ The minimum number of replicas (instances) to keep running for the Inference Endpoint.
6372
+ max_replica (`int`, *optional*):
6373
+ The maximum number of replicas (instances) to scale to for the Inference Endpoint.
6374
+
6375
+ repository (`str`, *optional*):
6376
+ The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
6377
+ framework (`str`, *optional*):
6378
+ The machine learning framework used for the model (e.g. `"custom"`).
6379
+ revision (`str`, *optional*):
6380
+ The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
6381
+ task (`str`, *optional*):
6382
+ The task on which to deploy the model (e.g. `"text-classification"`).
6383
+
6384
+ namespace (`str`, *optional*):
6385
+ The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace.
6386
+ token (`str`, *optional*):
6387
+ An authentication token (See https://huggingface.co/settings/token).
6388
+
6389
+ Returns:
6390
+ [`InferenceEndpoint`]: information about the updated Inference Endpoint.
6391
+ """
6392
+ namespace = namespace or self._get_namespace(token=token)
6393
+
6394
+ payload: Dict = {}
6395
+ if any(value is not None for value in (accelerator, instance_size, instance_type, min_replica, max_replica)):
6396
+ payload["compute"] = {
6397
+ "accelerator": accelerator,
6398
+ "instanceSize": instance_size,
6399
+ "instanceType": instance_type,
6400
+ "scaling": {
6401
+ "maxReplica": max_replica,
6402
+ "minReplica": min_replica,
6403
+ },
6404
+ }
6405
+ if any(value is not None for value in (repository, framework, revision, task)):
6406
+ payload["model"] = {
6407
+ "framework": framework,
6408
+ "repository": repository,
6409
+ "revision": revision,
6410
+ "task": task,
6411
+ "image": {"huggingface": {}},
6412
+ }
6413
+
6414
+ response = get_session().put(
6415
+ f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
6416
+ headers=self._build_hf_headers(token=token),
6417
+ json=payload,
6418
+ )
6419
+ hf_raise_for_status(response)
6420
+
6421
+ return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
6422
+
6423
+ def delete_inference_endpoint(
6424
+ self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
6425
+ ) -> None:
6426
+ """Delete an Inference Endpoint.
6427
+
6428
+ This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
6429
+ to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`].
6430
+
6431
+ For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`].
6432
+
6433
+ Args:
6434
+ name (`str`):
6435
+ The name of the Inference Endpoint to delete.
6436
+ namespace (`str`, *optional*):
6437
+ The namespace in which the Inference Endpoint is located. Defaults to the current user.
6438
+ token (`str`, *optional*):
6439
+ An authentication token (See https://huggingface.co/settings/token).
6440
+ """
6441
+ namespace = namespace or self._get_namespace(token=token)
6442
+ response = get_session().delete(
6443
+ f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
6444
+ headers=self._build_hf_headers(token=token),
6445
+ )
6446
+ hf_raise_for_status(response)
6447
+
6448
    def pause_inference_endpoint(
        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
    ) -> InferenceEndpoint:
        """Pause an Inference Endpoint.

        A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`].
        This is different than scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`], which
        would be automatically restarted when a request is made to it.

        For convenience, you can also pause an Inference Endpoint using [`InferenceEndpoint.pause`].

        Args:
            name (`str`):
                The name of the Inference Endpoint to pause.
            namespace (`str`, *optional*):
                The namespace in which the Inference Endpoint is located. Defaults to the current user.
            token (`str`, *optional*):
                An authentication token (See https://huggingface.co/settings/token).

        Returns:
            [`InferenceEndpoint`]: information about the paused Inference Endpoint.
        """
        namespace = namespace or self._get_namespace(token=token)

        response = get_session().post(
            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause",
            headers=self._build_hf_headers(token=token),
        )
        hf_raise_for_status(response)

        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
6479
+
6480
+ def resume_inference_endpoint(
6481
+ self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
6482
+ ) -> InferenceEndpoint:
6483
+ """Resume an Inference Endpoint.
6484
+
6485
+ For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`].
6486
+
6487
+ Args:
6488
+ name (`str`):
6489
+ The name of the Inference Endpoint to resume.
6490
+ namespace (`str`, *optional*):
6491
+ The namespace in which the Inference Endpoint is located. Defaults to the current user.
6492
+ token (`str`, *optional*):
6493
+ An authentication token (See https://huggingface.co/settings/token).
6494
+
6495
+ Returns:
6496
+ [`InferenceEndpoint`]: information about the resumed Inference Endpoint.
6497
+ """
6498
+ namespace = namespace or self._get_namespace(token=token)
6499
+
6500
+ response = get_session().post(
6501
+ f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume",
6502
+ headers=self._build_hf_headers(token=token),
6503
+ )
6504
+ hf_raise_for_status(response)
6505
+
6506
+ return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
6507
+
6508
    def scale_to_zero_inference_endpoint(
        self, name: str, *, namespace: Optional[str] = None, token: Optional[str] = None
    ) -> InferenceEndpoint:
        """Scale Inference Endpoint to zero.

        An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a
        cold start delay. This is different than pausing the Inference Endpoint with [`pause_inference_endpoint`], which
        would require a manual resume with [`resume_inference_endpoint`].

        For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`].

        Args:
            name (`str`):
                The name of the Inference Endpoint to scale to zero.
            namespace (`str`, *optional*):
                The namespace in which the Inference Endpoint is located. Defaults to the current user.
            token (`str`, *optional*):
                An authentication token (See https://huggingface.co/settings/token).

        Returns:
            [`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint.
        """
        namespace = namespace or self._get_namespace(token=token)

        response = get_session().post(
            f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero",
            headers=self._build_hf_headers(token=token),
        )
        hf_raise_for_status(response)

        return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token)
6539
+
6540
+ def _get_namespace(self, token: Optional[str] = None) -> str:
6541
+ """Get the default namespace for the current user."""
6542
+ me = self.whoami(token=token)
6543
+ if me["type"] == "user":
6544
+ return me["name"]
6545
+ else:
6546
+ raise ValueError(
6547
+ "Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a"
6548
+ " user."
6549
+ )
6550
+
6025
6551
  ########################
6026
6552
  # Collection Endpoints #
6027
6553
  ########################
@@ -6047,21 +6573,20 @@ class HfApi:
6047
6573
  >>> len(collection.items)
6048
6574
  37
6049
6575
  >>> collection.items[0]
6050
- CollectionItem: {
6051
- {'item_object_id': '6507f6d5423b46492ee1413e',
6052
- 'item_id': 'TheBloke/TigerBot-70B-Chat-GPTQ',
6053
- 'author': 'TheBloke',
6054
- 'item_type': 'model',
6055
- 'lastModified': '2023-09-19T12:55:21.000Z',
6056
- (...)
6057
- }}
6576
+ CollectionItem(
6577
+ item_object_id='651446103cd773a050bf64c2',
6578
+ item_id='TheBloke/U-Amethyst-20B-AWQ',
6579
+ item_type='model',
6580
+ position=88,
6581
+ note=None
6582
+ )
6058
6583
  ```
6059
6584
  """
6060
6585
  r = get_session().get(
6061
6586
  f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token)
6062
6587
  )
6063
6588
  hf_raise_for_status(r)
6064
- return Collection(r.json(), endpoint=self.endpoint)
6589
+ return Collection(**{**r.json(), "endpoint": self.endpoint})
6065
6590
 
6066
6591
  def create_collection(
6067
6592
  self,
@@ -6126,7 +6651,7 @@ class HfApi:
6126
6651
  return self.get_collection(slug, token=token)
6127
6652
  else:
6128
6653
  raise
6129
- return Collection(r.json(), endpoint=self.endpoint)
6654
+ return Collection(**{**r.json(), "endpoint": self.endpoint})
6130
6655
 
6131
6656
  def update_collection_metadata(
6132
6657
  self,
@@ -6191,7 +6716,7 @@ class HfApi:
6191
6716
  json={key: value for key, value in payload.items() if value is not None},
6192
6717
  )
6193
6718
  hf_raise_for_status(r)
6194
- return Collection(r.json()["data"], endpoint=self.endpoint)
6719
+ return Collection(**{**r.json()["data"], "endpoint": self.endpoint})
6195
6720
 
6196
6721
  def delete_collection(
6197
6722
  self, collection_slug: str, *, missing_ok: bool = False, token: Optional[str] = None
@@ -6299,7 +6824,7 @@ class HfApi:
6299
6824
  return self.get_collection(collection_slug, token=token)
6300
6825
  else:
6301
6826
  raise
6302
- return Collection(r.json(), endpoint=self.endpoint)
6827
+ return Collection(**{**r.json(), "endpoint": self.endpoint})
6303
6828
 
6304
6829
  def update_collection_item(
6305
6830
  self,
@@ -6317,7 +6842,7 @@ class HfApi:
6317
6842
  Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
6318
6843
  item_object_id (`str`):
6319
6844
  ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id).
6320
- It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0]._id`.
6845
+ It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`.
6321
6846
  note (`str`, *optional*):
6322
6847
  A note to attach to the item in the collection. The maximum size for a note is 500 characters.
6323
6848
  position (`int`, *optional*):
@@ -6595,6 +7120,16 @@ duplicate_space = api.duplicate_space
6595
7120
  request_space_storage = api.request_space_storage
6596
7121
  delete_space_storage = api.delete_space_storage
6597
7122
 
7123
+ # Inference Endpoint API
7124
+ list_inference_endpoints = api.list_inference_endpoints
7125
+ create_inference_endpoint = api.create_inference_endpoint
7126
+ get_inference_endpoint = api.get_inference_endpoint
7127
+ update_inference_endpoint = api.update_inference_endpoint
7128
+ delete_inference_endpoint = api.delete_inference_endpoint
7129
+ pause_inference_endpoint = api.pause_inference_endpoint
7130
+ resume_inference_endpoint = api.resume_inference_endpoint
7131
+ scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint
7132
+
6598
7133
  # Collections API
6599
7134
  get_collection = api.get_collection
6600
7135
  create_collection = api.create_collection