huggingface-hub 0.13.3__py3-none-any.whl → 0.14.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of huggingface-hub might be problematic. Click here for more details.

Files changed (40)
  1. huggingface_hub/__init__.py +59 -5
  2. huggingface_hub/_commit_api.py +26 -71
  3. huggingface_hub/_login.py +17 -16
  4. huggingface_hub/_multi_commits.py +305 -0
  5. huggingface_hub/_snapshot_download.py +4 -0
  6. huggingface_hub/_space_api.py +6 -0
  7. huggingface_hub/_webhooks_payload.py +124 -0
  8. huggingface_hub/_webhooks_server.py +362 -0
  9. huggingface_hub/commands/lfs.py +3 -5
  10. huggingface_hub/commands/user.py +0 -3
  11. huggingface_hub/community.py +21 -0
  12. huggingface_hub/constants.py +3 -0
  13. huggingface_hub/file_download.py +54 -13
  14. huggingface_hub/hf_api.py +666 -139
  15. huggingface_hub/hf_file_system.py +441 -0
  16. huggingface_hub/hub_mixin.py +1 -1
  17. huggingface_hub/inference_api.py +2 -4
  18. huggingface_hub/keras_mixin.py +1 -1
  19. huggingface_hub/lfs.py +196 -176
  20. huggingface_hub/repocard.py +2 -2
  21. huggingface_hub/repository.py +1 -1
  22. huggingface_hub/templates/modelcard_template.md +1 -1
  23. huggingface_hub/utils/__init__.py +8 -11
  24. huggingface_hub/utils/_errors.py +4 -4
  25. huggingface_hub/utils/_experimental.py +65 -0
  26. huggingface_hub/utils/_git_credential.py +1 -80
  27. huggingface_hub/utils/_http.py +85 -2
  28. huggingface_hub/utils/_pagination.py +4 -3
  29. huggingface_hub/utils/_paths.py +2 -0
  30. huggingface_hub/utils/_runtime.py +12 -0
  31. huggingface_hub/utils/_subprocess.py +22 -0
  32. huggingface_hub/utils/_telemetry.py +2 -4
  33. huggingface_hub/utils/tqdm.py +23 -18
  34. {huggingface_hub-0.13.3.dist-info → huggingface_hub-0.14.0.dev0.dist-info}/METADATA +5 -1
  35. huggingface_hub-0.14.0.dev0.dist-info/RECORD +61 -0
  36. {huggingface_hub-0.13.3.dist-info → huggingface_hub-0.14.0.dev0.dist-info}/entry_points.txt +3 -0
  37. huggingface_hub-0.13.3.dist-info/RECORD +0 -56
  38. {huggingface_hub-0.13.3.dist-info → huggingface_hub-0.14.0.dev0.dist-info}/LICENSE +0 -0
  39. {huggingface_hub-0.13.3.dist-info → huggingface_hub-0.14.0.dev0.dist-info}/WHEEL +0 -0
  40. {huggingface_hub-0.13.3.dist-info → huggingface_hub-0.14.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,362 @@
1
+ # coding=utf-8
2
+ # Copyright 2023-present, the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
16
+ import atexit
17
+ import inspect
18
+ import os
19
+ from functools import wraps
20
+ from typing import Callable, Dict, Optional
21
+
22
+ from .utils import experimental, is_gradio_available
23
+
24
+
25
+ if not is_gradio_available():
26
+ raise ImportError(
27
+ "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio` first."
28
+ )
29
+
30
+
31
+ import gradio as gr
32
+ from fastapi import FastAPI, Request
33
+ from fastapi.responses import JSONResponse
34
+
35
+
36
+ _global_app: Optional["WebhooksServer"] = None
37
+ _is_local = os.getenv("SYSTEM") != "spaces"
38
+
39
+
40
@experimental
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Hugging Face webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as a POST endpoint to the FastAPI router. Once all the webhooks are registered, the `run` method has to be
    called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable (a non-empty value passed here takes
            precedence). If no secret is provided, the webhook endpoints are opened without any security.

    Example:

        ```python
        import gradio as gr
        from huggingface_hub import WebhooksServer, WebhookPayload

        with gr.Blocks() as ui:
            ...

        app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

        @app.add_webhook("/say_hello")
        async def hello(payload: WebhookPayload):
            return {"message": "hello"}

        app.run()
        ```
    """

    def __init__(
        self,
        ui: Optional[gr.Blocks] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        self._ui = ui

        # A falsy `webhook_secret` argument (None or "") falls back to the `WEBHOOK_SECRET` environment variable.
        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        # Absolute route path (e.g. "/webhooks/say_hello") -> webhook function. Routes are only attached to the
        # FastAPI app when `run()` is called.
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        The decorated function is returned unchanged, so it can still be called or referenced directly.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.run()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        @wraps(FastAPI.post)
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func
            # Return the function itself so that the decorated name is not rebound to `None`.
            return func

        return _inner_post

    def run(self) -> None:
        """Starts the Gradio app with the FastAPI server and registers the webhooks.

        Each registered webhook is mounted as a POST route on the FastAPI app returned by Gradio (wrapped with a
        secret check if `webhook_secret` is set), the public URLs are printed, and the calling thread is then
        blocked via `ui.block_thread()`.
        """
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        # - as non-blocking so that webhooks can be added afterwards
        # - as shared if launch locally (to debug webhooks)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, share=_is_local)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        url = (ui.share_url or ui.local_url).strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        ui.block_thread()

    def _get_default_ui(self) -> gr.Blocks:
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Huggingface Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )
            # Build the location-specific hint separately: the settings link must be shown in both cases
            # (a bare `A + B if cond else C` would drop `A` whenever `cond` is false).
            if _is_local:
                location_hint = (
                    "\nYour app is running locally. Please look at the logs to check the full URL you need to set."
                )
            else:
                location_hint = (
                    "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                    " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
                )
            gr.Markdown("Go to https://huggingface.co/settings/webhooks to setup your webhooks." + location_hint)
        return ui
204
+
205
+
206
@experimental
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator that registers the decorated function as a webhook endpoint on a shared [`WebhooksServer`].

    Use it to get started quickly. If you need more control (custom landing page or webhook secret), use
    [`WebhooksServer`] directly. The decorator can be applied to several functions; they all end up on the same
    server.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip warning={true}>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        path (`str`, optional):
            URL path under which the webhook function is registered. Defaults to the function name. Whatever the
            value, the endpoint lives under `/webhooks`.

    Examples:
        Basic usage: the function name becomes the path, and the server is started automatically when the
        interpreter exits (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: start the server yourself through the `run` attribute attached to the decorated function
        (useful in a notebook).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually
        trigger_training.run()
        ```
    """
    # Bare form `@webhook_endpoint` (no parentheses): `path` is actually the function being decorated.
    if callable(path):
        decorated = path
        return webhook_endpoint()(decorated)

    @wraps(WebhooksServer.add_webhook)
    def _inner(func: Callable) -> Callable:
        server = _get_global_app()
        server.add_webhook(path)(func)

        # Only the first registration schedules the shared server to start at interpreter exit.
        is_first_webhook = len(server.registered_webhooks) == 1
        if is_first_webhook:
            atexit.register(server.run)

        @wraps(server.run)
        def _run_now():
            # Start right away: cancel the pending atexit start so the server is not launched twice.
            atexit.unregister(server.run)
            server.run()

        func.run = _run_now  # type: ignore
        return func

    return _inner
288
+
289
+
290
def _get_global_app() -> WebhooksServer:
    """Return the module-wide `WebhooksServer` used by `webhook_endpoint`, creating it lazily on first call."""
    global _global_app
    if _global_app is not None:
        return _global_app
    _global_app = WebhooksServer()
    return _global_app
295
+
296
+
297
+ def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
298
+ if webhook_secret is None:
299
+ print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
300
+ print(
301
+ "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
302
+ "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
303
+ )
304
+ print(
305
+ "For more details about webhook secrets, please refer to"
306
+ " https://huggingface.co/docs/hub/webhooks#webhook-secret."
307
+ )
308
+ else:
309
+ print("Webhook secret is correctly defined.")
310
+
311
+
312
+ def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
313
+ """Returns the anchor to a given webhook in the docs (experimental)"""
314
+ return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"
315
+
316
+
317
def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI relies on the route
    signature to decide which values to inject into the function, we need to hack the function signature to retrieve
    the `Request` object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret" header, the function returns a 401 error (unauthorized). If the header is sent but is
    incorrect, the function returns a 403 error (forbidden). The comparison is done in constant time.

    Inspired by https://stackoverflow.com/a/33112180.

    Args:
        func (`Callable`):
            The webhook route to protect. Both sync and async functions are supported.
        webhook_secret (`str`):
            The value the "x-webhook-secret" request header must match.

    Returns:
        `Callable`: an async function carrying `func`'s metadata, whose signature exposes a leading
        `request: Request` parameter (unless `func` already declares `request`).
    """
    import hmac

    initial_sig = inspect.signature(func)
    # "surrogatepass" keeps the str -> bytes mapping injective, so byte equality is exactly string equality.
    expected_secret = webhook_secret.encode("utf-8", "surrogatepass")

    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        # Constant-time comparison: a plain `!=` can return as soon as a character differs, which leaks timing
        # information about the secret to anyone able to send requests to the public endpoint.
        if not hmac.compare_digest(request_secret.encode("utf-8", "surrogatepass"), expected_secret):
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if "request" in initial_sig.parameters:
            kwargs["request"] = request

        # Handle both sync and async routes
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        else:
            return func(**kwargs)

    # Update signature to include request (FastAPI inspects `__signature__` to inject the `Request` object)
    if "request" not in initial_sig.parameters:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func
@@ -23,12 +23,10 @@ import sys
23
23
  from argparse import _SubParsersAction
24
24
  from typing import Dict, List, Optional
25
25
 
26
- import requests
27
-
28
26
  from huggingface_hub.commands import BaseHuggingfaceCLICommand
29
27
  from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND, SliceFileObj
30
28
 
31
- from ..utils import hf_raise_for_status, logging
29
+ from ..utils import get_session, hf_raise_for_status, logging
32
30
 
33
31
 
34
32
  logger = logging.get_logger(__name__)
@@ -172,7 +170,7 @@ class LfsUploadCommand:
172
170
  seek_from=i * chunk_size,
173
171
  read_limit=chunk_size,
174
172
  ) as data:
175
- r = requests.put(presigned_url, data=data)
173
+ r = get_session().put(presigned_url, data=data)
176
174
  hf_raise_for_status(r)
177
175
  parts.append(
178
176
  {
@@ -192,7 +190,7 @@ class LfsUploadCommand:
192
190
  )
193
191
  # Not precise but that's ok.
194
192
 
195
- r = requests.post(
193
+ r = get_session().post(
196
194
  completion_url,
197
195
  json={
198
196
  "oid": oid,
@@ -33,9 +33,6 @@ from .._login import ( # noqa: F401 # for backward compatibility # noqa: F401
33
33
  logout,
34
34
  notebook_login,
35
35
  )
36
- from .._login import (
37
- _currently_setup_credential_helpers as currently_setup_credential_helpers, # noqa: F401 # for backward compatibility
38
- )
39
36
  from ..utils import HfFolder
40
37
  from ._cli_utils import ANSI
41
38
 
@@ -8,6 +8,7 @@ from dataclasses import dataclass
8
8
  from datetime import datetime
9
9
  from typing import List, Optional
10
10
 
11
+ from .constants import REPO_TYPE_MODEL
11
12
  from .utils import parse_datetime
12
13
  from .utils._typing import Literal
13
14
 
@@ -47,6 +48,12 @@ class Discussion:
47
48
  Whether or not this is a Pull Request.
48
49
  created_at (`datetime`):
49
50
  The `datetime` of creation of the Discussion / Pull Request.
51
+ endpoint (`str`):
52
+ Endpoint of the Hub. Default is https://huggingface.co.
53
+ git_reference (`str`, *optional*):
54
+ (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
55
+ url (`str`):
56
+ (property) URL of the discussion on the Hub.
50
57
  """
51
58
 
52
59
  title: str
@@ -57,6 +64,7 @@ class Discussion:
57
64
  author: str
58
65
  is_pull_request: bool
59
66
  created_at: datetime
67
+ endpoint: str
60
68
 
61
69
  @property
62
70
  def git_reference(self) -> Optional[str]:
@@ -68,6 +76,13 @@ class Discussion:
68
76
  return f"refs/pr/{self.num}"
69
77
  return None
70
78
 
79
@property
def url(self) -> str:
    """Returns the URL of the discussion on the Hub.

    Model repos (or an unset `repo_type`) have no type prefix: `{endpoint}/{repo_id}/discussions/{num}`.
    Other repo types use the plural form of `repo_type` as prefix, e.g. `{endpoint}/datasets/{repo_id}/discussions/{num}`.
    """
    if self.repo_type is None or self.repo_type == REPO_TYPE_MODEL:
        return f"{self.endpoint}/{self.repo_id}/discussions/{self.num}"
    # `repo_type` is expected to be singular ("dataset", "space") while Hub URLs use the plural ("datasets", "spaces").
    return f"{self.endpoint}/{self.repo_type}s/{self.repo_id}/discussions/{self.num}"
85
+
71
86
 
72
87
  @dataclass
73
88
  class DiscussionWithDetails(Discussion):
@@ -112,6 +127,12 @@ class DiscussionWithDetails(Discussion):
112
127
  the merge commit, `None` otherwise.
113
128
  diff (`str`, *optional*):
114
129
  The git diff if this is a Pull Request , `None` otherwise.
130
+ endpoint (`str`):
131
+ Endpoint of the Hub. Default is https://huggingface.co.
132
+ git_reference (`str`, *optional*):
133
+ (property) Git reference to which changes can be pushed if this is a Pull Request, `None` otherwise.
134
+ url (`str`):
135
+ (property) URL of the discussion on the Hub.
115
136
  """
116
137
 
117
138
  events: List["DiscussionEvent"]
@@ -110,6 +110,9 @@ HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = (
110
110
  # Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
111
111
  HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))
112
112
 
113
+ # Disable warning when using experimental features
114
+ HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))
115
+
113
116
  # Disable sending the cached token by default is all HTTP requests to the Hub
114
117
  HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))
115
118
 
@@ -477,8 +477,7 @@ def http_get(
477
477
  if HF_HUB_ENABLE_HF_TRANSFER:
478
478
  try:
479
479
  # Download file using an external Rust-based package. Download is faster
480
- # (~2x speed-up) but support less features (no error handling, no retries,
481
- # no progress bars).
480
+ # (~2x speed-up) but support less features (no progress bars).
482
481
  from hf_transfer import download
483
482
 
484
483
  logger.debug(f"Download {url} using HF_TRANSFER.")
@@ -539,6 +538,15 @@ def http_get(
539
538
  if chunk: # filter out keep-alive new chunks
540
539
  progress.update(len(chunk))
541
540
  temp_file.write(chunk)
541
+
542
+ if total is not None and total != temp_file.tell():
543
+ raise EnvironmentError(
544
+ f"Consistency check failed: file should be of size {total} but has size"
545
+ f" {temp_file.tell()} ({displayed_name}).\nWe are sorry for the inconvenience. Please retry download and"
546
+ " pass `force_download=True, resume_download=False` as argument.\nIf the issue persists, please let us"
547
+ " know by opening an issue on https://github.com/huggingface/huggingface_hub."
548
+ )
549
+
542
550
  progress.close()
543
551
 
544
552
 
@@ -670,7 +678,7 @@ def cached_download(
670
678
  timeout=etag_timeout,
671
679
  )
672
680
  hf_raise_for_status(r)
673
- etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
681
+ etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
674
682
  # We favor a custom header indicating the etag of the linked resource, and
675
683
  # we fallback to the regular etag header.
676
684
  # If we don't have any of those, raise an error.
@@ -805,7 +813,8 @@ def _normalize_etag(etag: Optional[str]) -> Optional[str]:
805
813
  ETag: W/"<etag_value>"
806
814
  ETag: "<etag_value>"
807
815
 
808
- The hf.co hub guarantees to only send the second form.
816
+ For now, we only expect the second form from the server, but we want to be future-proof so we support both. For
817
+ more context, see `TestNormalizeEtag` tests and https://github.com/huggingface/huggingface_hub/pull/1428.
809
818
 
810
819
  Args:
811
820
  etag (`str`, *optional*): HTTP header
@@ -816,7 +825,7 @@ def _normalize_etag(etag: Optional[str]) -> Optional[str]:
816
825
  """
817
826
  if etag is None:
818
827
  return None
819
- return etag.strip('"')
828
+ return etag.lstrip("W/").strip('"')
820
829
 
821
830
 
822
831
  def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
@@ -870,7 +879,13 @@ def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
870
879
  relative_src = None
871
880
 
872
881
  try:
873
- _support_symlinks = are_symlinks_supported(os.path.dirname(os.path.commonpath([abs_src, abs_dst])))
882
+ try:
883
+ commonpath = os.path.commonpath([abs_src, abs_dst])
884
+ _support_symlinks = are_symlinks_supported(os.path.dirname(commonpath))
885
+ except ValueError:
886
+ # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/Macos.
887
+ # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
888
+ _support_symlinks = os.name != "nt"
874
889
  except PermissionError:
875
890
  # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
876
891
  # by the user via `local_dir`. Let's test symlink support there)
@@ -892,7 +907,7 @@ def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
892
907
  raise
893
908
  elif new_blob:
894
909
  logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
895
- os.replace(src, dst)
910
+ shutil.move(src, dst)
896
911
  else:
897
912
  logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
898
913
  shutil.copyfile(src, dst)
@@ -1132,11 +1147,17 @@ def hf_hub_download(
1132
1147
 
1133
1148
  # cross platform transcription of filename, to be used as a local file path.
1134
1149
  relative_filename = os.path.join(*filename.split("/"))
1150
+ if os.name == "nt":
1151
+ if relative_filename.startswith("..\\") or "\\..\\" in relative_filename:
1152
+ raise ValueError(
1153
+ f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository"
1154
+ " owner to rename this file."
1155
+ )
1135
1156
 
1136
1157
  # if user provides a commit_hash and they already have the file on disk,
1137
1158
  # shortcut everything.
1138
1159
  if REGEX_COMMIT_HASH.match(revision):
1139
- pointer_path = os.path.join(storage_folder, "snapshots", revision, relative_filename)
1160
+ pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
1140
1161
  if os.path.exists(pointer_path):
1141
1162
  if local_dir is not None:
1142
1163
  return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
@@ -1231,7 +1252,7 @@ def hf_hub_download(
1231
1252
 
1232
1253
  # Return pointer file if exists
1233
1254
  if commit_hash is not None:
1234
- pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
1255
+ pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
1235
1256
  if os.path.exists(pointer_path):
1236
1257
  if local_dir is not None:
1237
1258
  return _to_local_dir(
@@ -1260,7 +1281,7 @@ def hf_hub_download(
1260
1281
  assert etag is not None, "etag must have been retrieved from server"
1261
1282
  assert commit_hash is not None, "commit_hash must have been retrieved from server"
1262
1283
  blob_path = os.path.join(storage_folder, "blobs", etag)
1263
- pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
1284
+ pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
1264
1285
 
1265
1286
  os.makedirs(os.path.dirname(blob_path), exist_ok=True)
1266
1287
  os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
@@ -1506,8 +1527,8 @@ def get_hf_file_metadata(
1506
1527
  etag=_normalize_etag(
1507
1528
  # We favor a custom header indicating the etag of the linked resource, and
1508
1529
  # we fallback to the regular etag header.
1509
- r.headers.get("ETag")
1510
- or r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG)
1530
+ r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG)
1531
+ or r.headers.get("ETag")
1511
1532
  ),
1512
1533
  # Either from response headers (if redirected) or defaults to request url
1513
1534
  # Do not use directly `url`, as `_request_wrapper` might have followed relative
@@ -1546,7 +1567,20 @@ def _chmod_and_replace(src: str, dst: str) -> None:
1546
1567
  finally:
1547
1568
  tmp_file.unlink()
1548
1569
 
1549
- os.replace(src, dst)
1570
+ shutil.move(src, dst)
1571
+
1572
+
1573
+ def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
1574
+ # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
1575
+ snapshot_path = os.path.join(storage_folder, "snapshots")
1576
+ pointer_path = os.path.join(snapshot_path, revision, relative_filename)
1577
+ if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents:
1578
+ raise ValueError(
1579
+ "Invalid pointer path: cannot create pointer path in snapshot folder if"
1580
+ f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and"
1581
+ f" `relative_filename='{relative_filename}'`."
1582
+ )
1583
+ return pointer_path
1550
1584
 
1551
1585
 
1552
1586
  def _to_local_dir(
@@ -1556,7 +1590,14 @@ def _to_local_dir(
1556
1590
 
1557
1591
  Either symlink to blob file in cache or duplicate file depending on `use_symlinks` and file size.
1558
1592
  """
1593
+ # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
1559
1594
  local_dir_filepath = os.path.join(local_dir, relative_filename)
1595
+ if Path(os.path.abspath(local_dir)) not in Path(os.path.abspath(local_dir_filepath)).parents:
1596
+ raise ValueError(
1597
+ f"Cannot copy file '{relative_filename}' to local dir '{local_dir}': file would not be in the local"
1598
+ " directory."
1599
+ )
1600
+
1560
1601
  os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
1561
1602
  real_blob_path = os.path.realpath(path)
1562
1603