huggingface-hub 0.33.4__py3-none-any.whl → 0.34.0rc0__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Potentially problematic release. This version of huggingface-hub might be problematic.
- huggingface_hub/__init__.py +47 -1
- huggingface_hub/_commit_api.py +21 -28
- huggingface_hub/_jobs_api.py +145 -0
- huggingface_hub/_local_folder.py +7 -1
- huggingface_hub/_login.py +5 -5
- huggingface_hub/_oauth.py +1 -1
- huggingface_hub/_snapshot_download.py +11 -6
- huggingface_hub/_upload_large_folder.py +46 -23
- huggingface_hub/cli/__init__.py +27 -0
- huggingface_hub/cli/_cli_utils.py +69 -0
- huggingface_hub/cli/auth.py +210 -0
- huggingface_hub/cli/cache.py +405 -0
- huggingface_hub/cli/download.py +181 -0
- huggingface_hub/cli/hf.py +66 -0
- huggingface_hub/cli/jobs.py +522 -0
- huggingface_hub/cli/lfs.py +198 -0
- huggingface_hub/cli/repo.py +243 -0
- huggingface_hub/cli/repo_files.py +128 -0
- huggingface_hub/cli/system.py +52 -0
- huggingface_hub/cli/upload.py +316 -0
- huggingface_hub/cli/upload_large_folder.py +132 -0
- huggingface_hub/commands/_cli_utils.py +5 -0
- huggingface_hub/commands/delete_cache.py +3 -1
- huggingface_hub/commands/download.py +4 -0
- huggingface_hub/commands/env.py +3 -0
- huggingface_hub/commands/huggingface_cli.py +2 -0
- huggingface_hub/commands/repo.py +4 -0
- huggingface_hub/commands/repo_files.py +4 -0
- huggingface_hub/commands/scan_cache.py +3 -1
- huggingface_hub/commands/tag.py +3 -1
- huggingface_hub/commands/upload.py +4 -0
- huggingface_hub/commands/upload_large_folder.py +3 -1
- huggingface_hub/commands/user.py +11 -1
- huggingface_hub/commands/version.py +3 -0
- huggingface_hub/constants.py +1 -0
- huggingface_hub/file_download.py +16 -5
- huggingface_hub/hf_api.py +519 -7
- huggingface_hub/hf_file_system.py +8 -16
- huggingface_hub/hub_mixin.py +3 -3
- huggingface_hub/inference/_client.py +38 -39
- huggingface_hub/inference/_common.py +44 -14
- huggingface_hub/inference/_generated/_async_client.py +50 -51
- huggingface_hub/inference/_generated/types/__init__.py +1 -0
- huggingface_hub/inference/_generated/types/image_to_video.py +60 -0
- huggingface_hub/inference/_mcp/cli.py +36 -18
- huggingface_hub/inference/_mcp/constants.py +8 -0
- huggingface_hub/inference/_mcp/types.py +3 -0
- huggingface_hub/inference/_providers/__init__.py +4 -1
- huggingface_hub/inference/_providers/_common.py +3 -6
- huggingface_hub/inference/_providers/fal_ai.py +85 -42
- huggingface_hub/inference/_providers/hf_inference.py +17 -9
- huggingface_hub/inference/_providers/replicate.py +19 -1
- huggingface_hub/keras_mixin.py +2 -2
- huggingface_hub/repocard.py +1 -1
- huggingface_hub/repository.py +2 -2
- huggingface_hub/utils/_auth.py +1 -1
- huggingface_hub/utils/_cache_manager.py +2 -2
- huggingface_hub/utils/_dotenv.py +51 -0
- huggingface_hub/utils/_headers.py +1 -1
- huggingface_hub/utils/_runtime.py +1 -1
- huggingface_hub/utils/_xet.py +6 -2
- huggingface_hub/utils/_xet_progress_reporting.py +141 -0
- {huggingface_hub-0.33.4.dist-info → huggingface_hub-0.34.0rc0.dist-info}/METADATA +7 -8
- {huggingface_hub-0.33.4.dist-info → huggingface_hub-0.34.0rc0.dist-info}/RECORD +68 -51
- {huggingface_hub-0.33.4.dist-info → huggingface_hub-0.34.0rc0.dist-info}/entry_points.txt +1 -0
- {huggingface_hub-0.33.4.dist-info → huggingface_hub-0.34.0rc0.dist-info}/LICENSE +0 -0
- {huggingface_hub-0.33.4.dist-info → huggingface_hub-0.34.0rc0.dist-info}/WHEEL +0 -0
- {huggingface_hub-0.33.4.dist-info → huggingface_hub-0.34.0rc0.dist-info}/top_level.txt +0 -0
huggingface_hub/_upload_large_folder.py

@@ -33,6 +33,7 @@ from ._local_folder import LocalUploadFileMetadata, LocalUploadFilePaths, get_lo
 from .constants import DEFAULT_REVISION, REPO_TYPES
 from .utils import DEFAULT_IGNORE_PATTERNS, filter_repo_objects, tqdm
 from .utils._cache_manager import _format_size
+from .utils._runtime import is_xet_available
 from .utils.sha import sha_fileobj


@@ -45,6 +46,9 @@ WAITING_TIME_IF_NO_TASKS = 10  # seconds
 MAX_NB_FILES_FETCH_UPLOAD_MODE = 100
 COMMIT_SIZE_SCALE: List[int] = [20, 50, 75, 100, 125, 200, 250, 400, 600, 1000]

+UPLOAD_BATCH_SIZE_XET = 256  # Max 256 files per upload batch for XET-enabled repos
+UPLOAD_BATCH_SIZE_LFS = 1  # Otherwise, batches of 1 for regular LFS upload
+

 def upload_large_folder_internal(
     api: "HfApi",
@@ -93,6 +97,17 @@ def upload_large_folder_internal(
     repo_url = api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True)
     logger.info(f"Repo created: {repo_url}")
     repo_id = repo_url.repo_id
+    # 2.1 Check if xet is enabled to set batch file upload size
+    is_xet_enabled = (
+        is_xet_available()
+        and api.repo_info(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            revision=revision,
+            expand="xetEnabled",
+        ).xet_enabled
+    )
+    upload_batch_size = UPLOAD_BATCH_SIZE_XET if is_xet_enabled else UPLOAD_BATCH_SIZE_LFS

     # 3. List files to upload
     filtered_paths_list = filter_repo_objects(
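Taken together with the constants added above, this hunk picks the pre-upload batch size: 256 files per batch when both the local machine and the target repo support Xet, otherwise one file at a time for plain LFS. A standalone sketch of the same check, for illustration only (the repo id is a placeholder, and `is_xet_available` is the internal helper imported earlier in this diff):

    from huggingface_hub import HfApi
    from huggingface_hub.utils._runtime import is_xet_available  # internal helper, imported the same way above

    UPLOAD_BATCH_SIZE_XET = 256
    UPLOAD_BATCH_SIZE_LFS = 1

    api = HfApi()
    # "username/my-model" is a placeholder repo id; expand="xetEnabled" mirrors the call in the hunk above.
    info = api.repo_info(repo_id="username/my-model", repo_type="model", expand="xetEnabled")
    upload_batch_size = UPLOAD_BATCH_SIZE_XET if (is_xet_available() and info.xet_enabled) else UPLOAD_BATCH_SIZE_LFS
    print(upload_batch_size)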
@@ -110,7 +125,7 @@
     ]

     # 4. Start workers
-    status = LargeUploadStatus(items)
+    status = LargeUploadStatus(items, upload_batch_size)
     threads = [
         threading.Thread(
             target=_worker_job,
@@ -168,7 +183,7 @@ JOB_ITEM_T = Tuple[LocalUploadFilePaths, LocalUploadFileMetadata]
 class LargeUploadStatus:
     """Contains information, queues and tasks for a large upload process."""

-    def __init__(self, items: List[JOB_ITEM_T]):
+    def __init__(self, items: List[JOB_ITEM_T], upload_batch_size: int = 1):
         self.items = items
         self.queue_sha256: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
         self.queue_get_upload_mode: "queue.Queue[JOB_ITEM_T]" = queue.Queue()
@@ -179,6 +194,7 @@ class LargeUploadStatus:
         self.nb_workers_sha256: int = 0
         self.nb_workers_get_upload_mode: int = 0
         self.nb_workers_preupload_lfs: int = 0
+        self.upload_batch_size: int = upload_batch_size
         self.nb_workers_commit: int = 0
         self.nb_workers_waiting: int = 0
         self.last_commit_attempt: Optional[float] = None
@@ -353,16 +369,17 @@ def _worker_job(
                 status.nb_workers_get_upload_mode -= 1

         elif job == WorkerJob.PREUPLOAD_LFS:
-            item = items[0]  # single item
             try:
-                _preupload_lfs(item, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
-                status.queue_commit.put(item)
+                _preupload_lfs(items, api=api, repo_id=repo_id, repo_type=repo_type, revision=revision)
+                for item in items:
+                    status.queue_commit.put(item)
             except KeyboardInterrupt:
                 raise
             except Exception as e:
                 logger.error(f"Failed to preupload LFS: {e}")
                 traceback.format_exc()
-                status.queue_preupload_lfs.put(item)
+                for item in items:
+                    status.queue_preupload_lfs.put(item)

             with status.lock:
                 status.nb_workers_preupload_lfs -= 1
@@ -417,11 +434,11 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob,
             logger.debug(f"Job: get upload mode (>{MAX_NB_FILES_FETCH_UPLOAD_MODE} files ready)")
             return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE))

-        # 4. Preupload LFS file if at least 1 file and no worker is preuploading LFS
-        elif status.queue_preupload_lfs.qsize() > 0 and status.nb_workers_preupload_lfs == 0:
+        # 4. Preupload LFS file if at least `status.upload_batch_size` files and no worker is preuploading LFS
+        elif status.queue_preupload_lfs.qsize() >= status.upload_batch_size and status.nb_workers_preupload_lfs == 0:
             status.nb_workers_preupload_lfs += 1
             logger.debug("Job: preupload LFS (no other worker preuploading LFS)")
-            return (WorkerJob.PREUPLOAD_LFS,
+            return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size))

         # 5. Compute sha256 if at least 1 file and no worker is computing sha256
         elif status.queue_sha256.qsize() > 0 and status.nb_workers_sha256 == 0:
@@ -435,14 +452,14 @@ def _determine_next_job(status: LargeUploadStatus) -> Optional[Tuple[WorkerJob,
             logger.debug("Job: get upload mode (no other worker getting upload mode)")
             return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE))

-        # 7. Preupload LFS file if at least 1 file
+        # 7. Preupload LFS file if at least `status.upload_batch_size` files
         # Skip if hf_transfer is enabled and there is already a worker preuploading LFS
-        elif status.queue_preupload_lfs.qsize() > 0 and (
+        elif status.queue_preupload_lfs.qsize() >= status.upload_batch_size and (
             status.nb_workers_preupload_lfs == 0 or not constants.HF_HUB_ENABLE_HF_TRANSFER
         ):
             status.nb_workers_preupload_lfs += 1
             logger.debug("Job: preupload LFS")
-            return (WorkerJob.PREUPLOAD_LFS,
+            return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size))

         # 8. Compute sha256 if at least 1 file
         elif status.queue_sha256.qsize() > 0:
@@ -456,7 +473,13 @@
             logger.debug("Job: get upload mode")
             return (WorkerJob.GET_UPLOAD_MODE, _get_n(status.queue_get_upload_mode, MAX_NB_FILES_FETCH_UPLOAD_MODE))

-        # 10. Commit if at least 1 file and 1 min since last commit attempt
+        # 10. Preupload LFS file if at least 1 file
+        elif status.queue_preupload_lfs.qsize() > 0:
+            status.nb_workers_preupload_lfs += 1
+            logger.debug("Job: preupload LFS")
+            return (WorkerJob.PREUPLOAD_LFS, _get_n(status.queue_preupload_lfs, status.upload_batch_size))
+
+        # 11. Commit if at least 1 file and 1 min since last commit attempt
         elif (
             status.nb_workers_commit == 0
             and status.queue_commit.qsize() > 0
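Each of the preupload branches above hands a batch to the worker through the pre-existing `_get_n` helper, whose body is outside this diff. A plausible sketch, assuming it simply drains up to `n` already-queued items without blocking (`get_n` here is a hypothetical stand-in, not the library's own code):

    import queue
    from typing import List, TypeVar

    T = TypeVar("T")

    def get_n(q: "queue.Queue[T]", n: int) -> List[T]:
        # Hypothetical stand-in for the private _get_n(queue, n) helper used above:
        # pop at most n items that are already in the queue, never block on an empty one.
        return [q.get() for _ in range(min(q.qsize(), n))]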
@@ -467,7 +490,7 @@
             logger.debug("Job: commit (1 min since last commit attempt)")
             return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk()))

-        # 11. Commit if at least 1 file all other queues are empty and all workers are waiting
+        # 12. Commit if at least 1 file all other queues are empty and all workers are waiting
         # e.g. when it's the last commit
         elif (
             status.nb_workers_commit == 0
@@ -483,12 +506,12 @@
             logger.debug("Job: commit")
             return (WorkerJob.COMMIT, _get_n(status.queue_commit, status.target_chunk()))

-        # 12. If all queues are empty, exit
+        # 13. If all queues are empty, exit
         elif all(metadata.is_committed or metadata.should_ignore for _, metadata in status.items):
             logger.info("All files have been processed! Exiting worker.")
             return None

-        # 13. If no task is available, wait
+        # 14. If no task is available, wait
         else:
             status.nb_workers_waiting += 1
             logger.debug(f"No task available, waiting... ({WAITING_TIME_IF_NO_TASKS}s)")
@@ -531,19 +554,19 @@ def _get_upload_mode(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_t
         metadata.save(paths)


-def _preupload_lfs(item: JOB_ITEM_T, api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
-    """Preupload LFS file and update metadata."""
-    paths, metadata = item
-    addition = _build_hacky_operation(item)
+def _preupload_lfs(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
+    """Preupload LFS files and update metadata."""
+    additions = [_build_hacky_operation(item) for item in items]
     api.preupload_lfs_files(
         repo_id=repo_id,
         repo_type=repo_type,
         revision=revision,
-        additions=[addition],
+        additions=additions,
     )

-    metadata.is_uploaded = True
-    metadata.save(paths)
+    for paths, metadata in items:
+        metadata.is_uploaded = True
+        metadata.save(paths)


 def _commit(items: List[JOB_ITEM_T], api: "HfApi", repo_id: str, repo_type: str, revision: str) -> None:
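The `_build_hacky_operation` helper and the metadata bookkeeping are internal, but the batched pattern itself is available through the public API: pre-upload the LFS payloads first, then create the commit that references them. A short sketch with a placeholder repo id and made-up file names:

    from huggingface_hub import CommitOperationAdd, HfApi

    api = HfApi()
    # Placeholder repo id and local paths.
    additions = [
        CommitOperationAdd(path_in_repo="data/shard-0.bin", path_or_fileobj="shard-0.bin"),
        CommitOperationAdd(path_in_repo="data/shard-1.bin", path_or_fileobj="shard-1.bin"),
    ]
    # Upload the binary payloads as one batch...
    api.preupload_lfs_files(repo_id="username/my-dataset", repo_type="dataset", additions=additions)
    # ...then create the commit that references them.
    api.create_commit(
        repo_id="username/my-dataset",
        repo_type="dataset",
        operations=additions,
        commit_message="Add shards",
    )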
huggingface_hub/cli/__init__.py (new file)

@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import _SubParsersAction
+
+
+class BaseHuggingfaceCLICommand(ABC):
+    @staticmethod
+    @abstractmethod
+    def register_subcommand(parser: _SubParsersAction):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def run(self):
+        raise NotImplementedError()
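Every CLI command in this release implements this small interface. A minimal sketch of how a subclass plugs into an argparse parser, modeled on the concrete commands later in this diff (the `hello` subcommand and the `demo` parser are hypothetical; huggingface_hub 0.34 is assumed to be installed):

    from argparse import ArgumentParser, _SubParsersAction

    from huggingface_hub.cli import BaseHuggingfaceCLICommand


    class HelloCommand(BaseHuggingfaceCLICommand):
        @staticmethod
        def register_subcommand(parser: _SubParsersAction):
            # Register a "hello" subparser and remember how to build the command object.
            hello_parser = parser.add_parser("hello", help="Print a greeting.")
            hello_parser.add_argument("--name", type=str, default="world")
            hello_parser.set_defaults(func=lambda args: HelloCommand(args))

        def __init__(self, args):
            self.args = args

        def run(self):
            print(f"Hello {self.args.name}!")


    if __name__ == "__main__":
        # Run as: python demo.py hello --name HF
        main_parser = ArgumentParser("demo")
        subparsers = main_parser.add_subparsers()
        HelloCommand.register_subcommand(subparsers)
        args = main_parser.parse_args()
        # The dispatch pattern used by the real CLI: build the command, then run it.
        args.func(args).run()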
huggingface_hub/cli/_cli_utils.py (new file)

@@ -0,0 +1,69 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains a utility for good-looking prints."""
+
+import os
+from typing import List, Union
+
+
+class ANSI:
+    """
+    Helper for en.wikipedia.org/wiki/ANSI_escape_code
+    """
+
+    _bold = "\u001b[1m"
+    _gray = "\u001b[90m"
+    _red = "\u001b[31m"
+    _reset = "\u001b[0m"
+    _yellow = "\u001b[33m"
+
+    @classmethod
+    def bold(cls, s: str) -> str:
+        return cls._format(s, cls._bold)
+
+    @classmethod
+    def gray(cls, s: str) -> str:
+        return cls._format(s, cls._gray)
+
+    @classmethod
+    def red(cls, s: str) -> str:
+        return cls._format(s, cls._bold + cls._red)
+
+    @classmethod
+    def yellow(cls, s: str) -> str:
+        return cls._format(s, cls._yellow)
+
+    @classmethod
+    def _format(cls, s: str, code: str) -> str:
+        if os.environ.get("NO_COLOR"):
+            # See https://no-color.org/
+            return s
+        return f"{code}{s}{cls._reset}"
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)
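A quick usage sketch for these helpers; the table rows are made-up values, and colors are suppressed when `NO_COLOR` is set, as `_format` shows above:

    from huggingface_hub.cli._cli_utils import ANSI, tabulate

    print(ANSI.bold("Stored tokens"))
    print(
        tabulate(
            rows=[["default", "yes"], ["ci-bot", "no"]],
            headers=["name", "active"],
        )
    )
    print(ANSI.red("Something went wrong"))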
huggingface_hub/cli/auth.py (new file)

@@ -0,0 +1,210 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains commands to authenticate to the Hugging Face Hub and interact with your repositories.
+
+Usage:
+    # login and save token locally.
+    hf auth login --token=hf_*** --add-to-git-credential
+
+    # switch between tokens
+    hf auth switch
+
+    # list all tokens
+    hf auth list
+
+    # logout from all tokens
+    hf auth logout
+
+    # check which account you are logged in as
+    hf auth whoami
+"""
+
+from argparse import _SubParsersAction
+from typing import List, Optional
+
+from requests.exceptions import HTTPError
+
+from huggingface_hub.commands import BaseHuggingfaceCLICommand
+from huggingface_hub.constants import ENDPOINT
+from huggingface_hub.hf_api import HfApi
+
+from .._login import auth_list, auth_switch, login, logout
+from ..utils import get_stored_tokens, get_token, logging
+from ._cli_utils import ANSI
+
+
+logger = logging.get_logger(__name__)
+
+try:
+    from InquirerPy import inquirer
+    from InquirerPy.base.control import Choice
+
+    _inquirer_py_available = True
+except ImportError:
+    _inquirer_py_available = False
+
+
+class AuthCommands(BaseHuggingfaceCLICommand):
+    @staticmethod
+    def register_subcommand(parser: _SubParsersAction):
+        # Create the main 'auth' command
+        auth_parser = parser.add_parser("auth", help="Manage authentication (login, logout, etc.).")
+        auth_subparsers = auth_parser.add_subparsers(help="Authentication subcommands")
+
+        # Add 'login' as a subcommand of 'auth'
+        login_parser = auth_subparsers.add_parser(
+            "login", help="Log in using a token from huggingface.co/settings/tokens"
+        )
+        login_parser.add_argument(
+            "--token",
+            type=str,
+            help="Token generated from https://huggingface.co/settings/tokens",
+        )
+        login_parser.add_argument(
+            "--add-to-git-credential",
+            action="store_true",
+            help="Optional: Save token to git credential helper.",
+        )
+        login_parser.set_defaults(func=lambda args: AuthLogin(args))
+
+        # Add 'logout' as a subcommand of 'auth'
+        logout_parser = auth_subparsers.add_parser("logout", help="Log out")
+        logout_parser.add_argument(
+            "--token-name",
+            type=str,
+            help="Optional: Name of the access token to log out from.",
+        )
+        logout_parser.set_defaults(func=lambda args: AuthLogout(args))
+
+        # Add 'whoami' as a subcommand of 'auth'
+        whoami_parser = auth_subparsers.add_parser(
+            "whoami", help="Find out which huggingface.co account you are logged in as."
+        )
+        whoami_parser.set_defaults(func=lambda args: AuthWhoami(args))
+
+        # Existing subcommands
+        auth_switch_parser = auth_subparsers.add_parser("switch", help="Switch between access tokens")
+        auth_switch_parser.add_argument(
+            "--token-name",
+            type=str,
+            help="Optional: Name of the access token to switch to.",
+        )
+        auth_switch_parser.add_argument(
+            "--add-to-git-credential",
+            action="store_true",
+            help="Optional: Save token to git credential helper.",
+        )
+        auth_switch_parser.set_defaults(func=lambda args: AuthSwitch(args))
+
+        auth_list_parser = auth_subparsers.add_parser("list", help="List all stored access tokens")
+        auth_list_parser.set_defaults(func=lambda args: AuthList(args))
+
+
+class BaseAuthCommand:
+    def __init__(self, args):
+        self.args = args
+        self._api = HfApi()
+
+
+class AuthLogin(BaseAuthCommand):
+    def run(self):
+        logging.set_verbosity_info()
+        login(
+            token=self.args.token,
+            add_to_git_credential=self.args.add_to_git_credential,
+        )
+
+
+class AuthLogout(BaseAuthCommand):
+    def run(self):
+        logging.set_verbosity_info()
+        logout(token_name=self.args.token_name)
+
+
+class AuthSwitch(BaseAuthCommand):
+    def run(self):
+        logging.set_verbosity_info()
+        token_name = self.args.token_name
+        if token_name is None:
+            token_name = self._select_token_name()
+
+        if token_name is None:
+            print("No token name provided. Aborting.")
+            exit()
+        auth_switch(token_name, add_to_git_credential=self.args.add_to_git_credential)
+
+    def _select_token_name(self) -> Optional[str]:
+        token_names = list(get_stored_tokens().keys())
+
+        if not token_names:
+            logger.error("No stored tokens found. Please login first.")
+            return None
+
+        if _inquirer_py_available:
+            return self._select_token_name_tui(token_names)
+        # if inquirer is not available, use a simpler terminal UI
+        print("Available stored tokens:")
+        for i, token_name in enumerate(token_names, 1):
+            print(f"{i}. {token_name}")
+        while True:
+            try:
+                choice = input("Enter the number of the token to switch to (or 'q' to quit): ")
+                if choice.lower() == "q":
+                    return None
+                index = int(choice) - 1
+                if 0 <= index < len(token_names):
+                    return token_names[index]
+                else:
+                    print("Invalid selection. Please try again.")
+            except ValueError:
+                print("Invalid input. Please enter a number or 'q' to quit.")
+
+    def _select_token_name_tui(self, token_names: List[str]) -> Optional[str]:
+        choices = [Choice(token_name, name=token_name) for token_name in token_names]
+        try:
+            return inquirer.select(
+                message="Select a token to switch to:",
+                choices=choices,
+                default=None,
+            ).execute()
+        except KeyboardInterrupt:
+            logger.info("Token selection cancelled.")
+            return None
+
+
+class AuthList(BaseAuthCommand):
+    def run(self):
+        logging.set_verbosity_info()
+        auth_list()
+
+
+class AuthWhoami(BaseAuthCommand):
+    def run(self):
+        token = get_token()
+        if token is None:
+            print("Not logged in")
+            exit()
+        try:
+            info = self._api.whoami(token)
+            print(info["name"])
+            orgs = [org["name"] for org in info["orgs"]]
+            if orgs:
+                print(ANSI.bold("orgs: "), ",".join(orgs))
+
+            if ENDPOINT != "https://huggingface.co":
+                print(f"Authenticated through private endpoint: {ENDPOINT}")
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)