scmrepo 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scmrepo might be problematic. Click here for more details.
- scmrepo/base.py +5 -1
- scmrepo/fs.py +4 -3
- scmrepo/git/__init__.py +13 -0
- scmrepo/git/backend/base.py +43 -0
- scmrepo/git/backend/dulwich/__init__.py +67 -1
- scmrepo/git/backend/gitpython.py +56 -1
- scmrepo/git/backend/pygit2/__init__.py +201 -43
- scmrepo/git/backend/pygit2/callbacks.py +10 -1
- scmrepo/git/backend/pygit2/filter.py +65 -0
- scmrepo/git/config.py +35 -0
- scmrepo/git/lfs/__init__.py +8 -0
- scmrepo/git/lfs/client.py +223 -0
- scmrepo/git/lfs/exceptions.py +5 -0
- scmrepo/git/lfs/fetch.py +162 -0
- scmrepo/git/lfs/object.py +15 -0
- scmrepo/git/lfs/pointer.py +109 -0
- scmrepo/git/lfs/progress.py +61 -0
- scmrepo/git/lfs/smudge.py +51 -0
- scmrepo/git/lfs/storage.py +74 -0
- scmrepo/git/objects.py +3 -2
- {scmrepo-1.4.0.dist-info → scmrepo-1.5.0.dist-info}/METADATA +4 -2
- scmrepo-1.5.0.dist-info/RECORD +37 -0
- {scmrepo-1.4.0.dist-info → scmrepo-1.5.0.dist-info}/WHEEL +1 -1
- scmrepo-1.4.0.dist-info/RECORD +0 -26
- {scmrepo-1.4.0.dist-info → scmrepo-1.5.0.dist-info}/LICENSE +0 -0
- {scmrepo-1.4.0.dist-info → scmrepo-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from contextlib import AbstractContextManager
|
|
3
|
+
from functools import wraps
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, Optional
|
|
5
|
+
|
|
6
|
+
import aiohttp
|
|
7
|
+
from dvc_http import HTTPFileSystem
|
|
8
|
+
from dvc_http.retry import ReadOnlyRetryClient
|
|
9
|
+
from dvc_objects.executors import batch_coros
|
|
10
|
+
from dvc_objects.fs import localfs
|
|
11
|
+
from dvc_objects.fs.callbacks import DEFAULT_CALLBACK
|
|
12
|
+
from dvc_objects.fs.utils import as_atomic
|
|
13
|
+
from fsspec.asyn import sync_wrapper
|
|
14
|
+
from funcy import cached_property
|
|
15
|
+
|
|
16
|
+
from ..credentials import Credential, CredentialNotFoundError
|
|
17
|
+
from .exceptions import LFSError
|
|
18
|
+
from .pointer import Pointer
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from dvc_objects.fs.callbacks import Callback
|
|
22
|
+
|
|
23
|
+
from .storage import LFSStorage
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _LFSClient(ReadOnlyRetryClient):
    """Retry client subclass reserved for LFS traffic.

    At present it adds no behavior of its own: ``_request`` simply defers to
    the parent class.
    """

    async def _request(self, *args, **kwargs):
        """Forward every argument unchanged to the parent ``_request``."""
        response = await super()._request(  # pylint: disable=no-member
            *args, **kwargs
        )
        return response
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# pylint: disable=abstract-method
|
|
34
|
+
class _LFSFileSystem(HTTPFileSystem):
    """HTTP filesystem flavour used for talking to LFS servers."""

    def __init__(self, *args, **kwargs):
        """Pass all arguments straight through to ``HTTPFileSystem``."""
        super().__init__(*args, **kwargs)

    def _prepare_credentials(self, **config):
        """Return an empty credential mapping regardless of ``config``."""
        return {}

    async def get_client(self, **kwargs):
        """Build the retrying HTTP client used for requests.

        ``kwargs`` must contain ``connect_timeout`` and ``read_timeout``
        (a ``KeyError`` is raised otherwise); ``ssl_verify`` is optional.
        Those three keys are consumed here, everything else is forwarded to
        ``ReadOnlyRetryClient`` together with the retry options, timeout and
        connector built below.
        """
        from aiohttp_retry import ExponentialRetry
        from dvc_http import make_context

        retry_options = ExponentialRetry(
            attempts=self.SESSION_RETRIES,
            factor=self.SESSION_BACKOFF_FACTOR,
            max_timeout=self.REQUEST_TIMEOUT,
            exceptions={aiohttp.ClientError},
        )
        kwargs["retry_options"] = retry_options

        # aiohttp applies a 300 second total timeout by default, which is
        # too short for DVC when large data blobs are involved. The total
        # timeout is therefore disabled; only connecting to the server and
        # waiting for each new portion of data are bounded.
        connect_timeout = kwargs.pop("connect_timeout")
        read_timeout = kwargs.pop("read_timeout")
        kwargs["timeout"] = aiohttp.ClientTimeout(
            total=None,
            connect=connect_timeout,
            sock_connect=connect_timeout,
            sock_read=read_timeout,
        )

        ssl_context = make_context(kwargs.pop("ssl_verify", None))
        kwargs["connector"] = aiohttp.TCPConnector(
            # Force cleanup of closed SSL transports.
            # See https://github.com/iterative/dvc/issues/7414
            enable_cleanup_closed=True,
            ssl=ssl_context,
        )

        return ReadOnlyRetryClient(**kwargs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _authed(f: Callable[..., Awaitable]):
|
|
76
|
+
"""Set credentials and retry the given coroutine if needed."""
|
|
77
|
+
|
|
78
|
+
# pylint: disable=protected-access
|
|
79
|
+
@wraps(f) # type: ignore[arg-type]
|
|
80
|
+
async def wrapper(self, *args, **kwargs):
|
|
81
|
+
try:
|
|
82
|
+
return await f(self, *args, **kwargs)
|
|
83
|
+
except aiohttp.ClientResponseError as exc:
|
|
84
|
+
if exc.status != 401:
|
|
85
|
+
raise
|
|
86
|
+
session = await self._set_session()
|
|
87
|
+
if session.auth:
|
|
88
|
+
raise
|
|
89
|
+
auth = self._get_auth()
|
|
90
|
+
if auth is None:
|
|
91
|
+
raise
|
|
92
|
+
self._session._auth = auth
|
|
93
|
+
return await f(self, *args, **kwargs)
|
|
94
|
+
|
|
95
|
+
return wrapper
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class LFSClient(AbstractContextManager):
    """Naive read-only LFS HTTP client."""

    JSON_CONTENT_TYPE = "application/vnd.git-lfs+json"

    def __init__(
        self,
        url: str,
        git_url: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Args:
            url: LFS server URL.
            git_url: URL of the Git remote; passed to the credential lookup
                in ``_get_auth``.
            headers: Extra HTTP headers sent with each batch API request.
        """
        self.url = url
        self.git_url = git_url
        self.headers: Dict[str, str] = headers or {}

    def __exit__(self, *args, **kwargs):
        self.close()

    @cached_property
    def fs(self) -> "_LFSFileSystem":
        """Filesystem wrapper used for HTTP access (created once, then cached)."""
        return _LFSFileSystem()

    @property
    def httpfs(self) -> "HTTPFileSystem":
        """The ``fs`` attribute of the wrapper returned by ``self.fs``."""
        return self.fs.fs

    @property
    def loop(self):
        """The ``loop`` attribute of ``httpfs``."""
        return self.httpfs.loop

    @classmethod
    def from_git_url(cls, git_url: str) -> "LFSClient":
        """Create a client for the default LFS endpoint of a Git remote URL.

        The endpoint is ``<git_url>/info/lfs`` when the URL ends in ``.git``
        and ``<git_url>.git/info/lfs`` otherwise. Trailing slashes are
        dropped first so that ``https://host/repo/`` does not turn into
        ``https://host/repo/.git/info/lfs``. ``git_url`` itself is stored
        unmodified for credential lookup.
        """
        base = git_url.rstrip("/")
        suffix = "/info/lfs" if base.endswith(".git") else ".git/info/lfs"
        return cls(f"{base}{suffix}", git_url=git_url)

    def close(self):
        """No-op; present so the client can be used as a context manager."""

    def _get_auth(self) -> Optional[aiohttp.BasicAuth]:
        """Return basic-auth credentials for ``git_url``, or None.

        None is returned when the credential lookup raises
        ``CredentialNotFoundError`` or yields no username/password pair.
        """
        try:
            creds = Credential(url=self.git_url).fill()
            if creds.username and creds.password:
                return aiohttp.BasicAuth(creds.username, creds.password)
        except CredentialNotFoundError:
            pass
        return None

    async def _set_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session of the underlying filesystem."""
        return await self.fs.fs.set_session()

    @_authed
    async def _batch_request(
        self,
        objects: Iterable[Pointer],
        upload: bool = False,
        ref: Optional[str] = None,
        hash_algo: str = "sha256",
    ) -> Dict[str, Any]:
        """Send LFS API /objects/batch request.

        Args:
            objects: Pointers whose ``oid`` and ``size`` are listed in the
                request.
            upload: Request the "upload" operation instead of "download".
            ref: Optional server ref name the objects belong to.
            hash_algo: Value sent as ``hash_algo`` in the request body.

        Returns:
            The decoded JSON body of the server response.
        """
        url = f"{self.url}/objects/batch"
        body: Dict[str, Any] = {
            "operation": "upload" if upload else "download",
            "transfers": ["basic"],
            "objects": [{"oid": obj.oid, "size": obj.size} for obj in objects],
            "hash_algo": hash_algo,
        }
        if ref:
            # The batch API defines "ref" as a single object with a "name"
            # property, not as a list of such objects.
            body["ref"] = {"name": ref}
        session = await self._set_session()
        headers = dict(self.headers)
        # The batch API asks clients to send both Accept and Content-Type
        # with the LFS media type. A caller-supplied Accept header wins.
        headers.setdefault("Accept", self.JSON_CONTENT_TYPE)
        headers["Content-Type"] = self.JSON_CONTENT_TYPE
        async with session.post(
            url,
            headers=headers,
            json=body,
        ) as resp:
            data = await resp.json()
        return data

    @_authed
    async def _download(
        self,
        storage: "LFSStorage",
        objects: Iterable[Pointer],
        callback: "Callback" = DEFAULT_CALLBACK,
        **kwargs,
    ):
        """Download the given objects into ``storage``.

        A batch request is issued for ``objects`` (extra ``kwargs`` are
        forwarded to ``_batch_request``); every object in the response that
        carries a download URL is then fetched to the path
        ``storage.oid_to_path`` reports for it.

        Raises:
            LFSError: If the server selects a transfer adapter other than
                "basic".
            BaseException: The first failure reported by any individual
                download is re-raised.
        """

        async def _get_one(from_path: str, to_path: str, **kwargs):
            get_coro = callback.wrap_and_branch_coro(
                self.httpfs._get_file  # pylint: disable=protected-access
            )
            # Download to a temporary file that is moved into place on
            # success, so a partial object never appears at ``to_path``.
            with as_atomic(localfs, to_path, create_parents=True) as tmp_file:
                return await get_coro(
                    from_path,
                    tmp_file,
                    **kwargs,
                )

        resp_data = await self._batch_request(objects, **kwargs)
        # "transfer" is optional in the response; when the server omits it
        # the basic adapter is implied.
        if resp_data.get("transfer", "basic") != "basic":
            raise LFSError("Unsupported LFS transfer type")
        coros = []
        for data in resp_data.get("objects", []):
            obj = Pointer(data["oid"], data["size"])
            download = data.get("actions", {}).get("download", {})
            url = download.get("href")
            if not url:
                logger.debug("No download URL for LFS object '%s'", obj)
                continue
            headers = download.get("header", {})
            to_path = storage.oid_to_path(obj.oid)
            coros.append(_get_one(url, to_path, headers=headers))
        for result in await batch_coros(
            coros, batch_size=self.fs.jobs, return_exceptions=True
        ):
            if isinstance(result, BaseException):
                raise result

    # Synchronous entry point wrapping the ``_download`` coroutine.
    download = sync_wrapper(_download)
|
scmrepo/git/lfs/fetch.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import fnmatch
|
|
2
|
+
import io
|
|
3
|
+
import os
|
|
4
|
+
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Set
|
|
5
|
+
|
|
6
|
+
from scmrepo.exceptions import InvalidRemote, SCMError
|
|
7
|
+
|
|
8
|
+
from .pointer import HEADERS, Pointer
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from scmrepo.git import Git
|
|
12
|
+
from scmrepo.git.config import Config
|
|
13
|
+
from scmrepo.progress import GitProgressEvent
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def fetch(
    scm: "Git",
    revs: Optional[List[str]] = None,
    remote: Optional[str] = None,
    include: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    progress: Optional[Callable[["GitProgressEvent"], None]] = None,
):
    """Fetch the LFS objects referenced at ``revs`` that are missing locally.

    NOTE: This currently does not support fetching objects from the worktree.

    Args:
        scm: Repository to inspect and whose ``lfs_storage`` receives the
            objects.
        revs: Revisions to scan; ``["HEAD"]`` is used when empty or None.
        remote: Remote to fetch from. If it cannot be resolved
            (``InvalidRemote``), a non-empty value is used as the URL itself.
        include: Optional path patterns limiting which files are scanned.
        exclude: Optional path patterns removed from the scan.
        progress: Optional callback forwarded to ``lfs_storage.fetch``.

    Raises:
        InvalidRemote: If no remote was given and none could be resolved.
    """
    targets = revs if revs else ["HEAD"]

    missing: Set[Pointer] = set()
    for rev in targets:
        for pointer in _collect_objects(scm, rev, include, exclude):
            if not scm.lfs_storage.exists(pointer):
                missing.add(pointer)

    if not missing:
        # Everything is already in local storage.
        return

    try:
        url = get_fetch_url(scm, remote=remote)
    except InvalidRemote:
        if not remote:
            raise
        # treat remote as a raw Git remote
        url = remote

    scm.lfs_storage.fetch(url, missing, progress=progress)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_fetch_url(scm: "Git", remote: Optional[str] = None):  # noqa: C901
    """Return LFS fetch URL for the specified repository.

    Lookup order:

    1. ``lfs.url`` from the Git config, then from ``.lfsconfig`` in the
       repository root (if that file exists).
    2. ``remote.<name>.lfsurl`` from the Git config, then from
       ``.lfsconfig``.
    3. The regular Git URL of the remote (``scm.get_remote_url``).

    When ``remote`` is not given, the remote name is the one reported by
    ``scm.active_branch_remote()``, else ``remote.lfsdefault`` from the Git
    config, else ``"origin"``.
    """
    unset = object()

    def _lookup(config, section, name):
        # A KeyError from ``get`` means the option is not set.
        try:
            return config.get(section, name)
        except KeyError:
            return unset

    git_config = scm.get_config()

    # check lfs.url (can be set in git config and .lfsconfig)
    found = _lookup(git_config, ("lfs",), "url")
    if found is not unset:
        return found

    lfs_config: Optional["Config"]
    try:
        lfs_config = scm.get_config(os.path.join(scm.root_dir, ".lfsconfig"))
    except FileNotFoundError:
        lfs_config = None

    if lfs_config:
        found = _lookup(lfs_config, ("lfs",), "url")
        if found is not unset:
            return found

    # use:
    #   current tracking-branch remote
    #   or remote.lfsdefault (can only be set in git config)
    #   or "origin"
    # in that order
    if not remote:
        try:
            remote = scm.active_branch_remote()
        except SCMError:
            pass
    if not remote:
        configured_default = _lookup(git_config, ("remote",), "lfsdefault")
        remote = "origin" if configured_default is unset else configured_default

    # check remote.*.lfsurl (can be set in git config and .lfsconfig)
    assert remote is not None
    found = _lookup(git_config, ("remote", remote), "lfsurl")
    if found is not unset:
        return found
    if lfs_config:
        found = _lookup(lfs_config, ("remote", remote), "lfsurl")
        if found is not unset:
            return found

    # return default Git fetch URL for this remote
    return scm.get_remote_url(remote)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _collect_objects(
    scm: "Git",
    rev: str,
    include: Optional[List[str]],
    exclude: Optional[List[str]],
) -> Iterator[Pointer]:
    """Yield the LFS pointers stored in the tree of ``rev``.

    Every path found in the revision (after include/exclude filtering) whose
    ``filter`` attribute is ``lfs`` is opened raw; if its content starts
    with a known pointer header it is parsed and yielded. Files that cannot
    be read or parsed (``ValueError``/``OSError``) are skipped silently.
    """
    tree_fs = scm.get_fs(rev)
    known_headers = tuple(HEADERS)
    for path in _filter_paths(tree_fs.find("/"), include, exclude):
        # Attribute lookup is done on the path without its leading slash.
        if scm.check_attr(path.lstrip("/"), "filter", source=rev) != "lfs":
            continue
        try:
            with tree_fs.open(path, "rb", raw=True) as raw_obj, io.BufferedReader(
                raw_obj
            ) as reader:
                # peek() keeps the stream position so load() sees the
                # pointer from its first byte.
                head = reader.peek(100)
                if head.startswith(known_headers):
                    yield Pointer.load(reader)
        except (ValueError, OSError):
            pass
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _filter_paths(
|
|
121
|
+
paths: Iterable[str], include: Optional[List[str]], exclude: Optional[List[str]]
|
|
122
|
+
) -> Iterator[str]:
|
|
123
|
+
filtered = set()
|
|
124
|
+
if include:
|
|
125
|
+
for pattern in include:
|
|
126
|
+
filtered.update(fnmatch.filter(paths, pattern))
|
|
127
|
+
else:
|
|
128
|
+
filtered.update(paths)
|
|
129
|
+
if exclude:
|
|
130
|
+
for pattern in exclude:
|
|
131
|
+
filtered.difference_update(fnmatch.filter(paths, pattern))
|
|
132
|
+
yield from filtered
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
if __name__ == "__main__":
    # Minimal `git lfs fetch` CLI implementation
    import argparse
    import sys

    from scmrepo.git import Git  # noqa: F811

    cli = argparse.ArgumentParser(
        description=(
            "Download Git LFS objects at the given refs from the specified remote."
        ),
    )
    # Both positionals are optional on the command line and fall back to
    # their defaults when omitted.
    positionals = (
        ("remote", "?", "origin", "Remote to fetch from. Defaults to 'origin'."),
        ("refs", "*", ["HEAD"], "Refs or commits to fetch. Defaults to 'HEAD'."),
    )
    for dest, nargs, default, help_text in positionals:
        cli.add_argument(dest, nargs=nargs, default=default, help=help_text)
    options = cli.parse_args()

    with Git(".") as scm_:  # pylint: disable=E0601
        print("fetch: fetching reference", ", ".join(options.refs), file=sys.stderr)
        fetch(scm_, revs=options.refs, remote=options.remote)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from dataclasses import dataclass, fields
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass(frozen=True)
class LFSObject:
    """Immutable description of an LFS object: its OID and size."""

    oid: str
    size: int

    def __str__(self) -> str:
        return self.oid

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "LFSObject":
        """Build an instance from a mapping, ignoring unknown keys.

        Args:
            d: Mapping that must contain the dataclass fields (``oid`` and
                ``size``); any other keys are discarded.

        Raises:
            TypeError: If a required field is missing from ``d``.
        """
        # fields() returns Field objects, so keys must be compared against
        # their names; comparing against the Field objects themselves never
        # matches and would drop every key.
        field_names = {field.name for field in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in field_names})
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import io
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import IO, BinaryIO, TextIO, Tuple
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
LFS_VERSION = "https://git-lfs.github.com/spec/v1"
LEGACY_LFS_VERSION = "https://hawser.github.com/spec/v1"
ALLOWED_VERSIONS = (LFS_VERSION, LEGACY_LFS_VERSION)
# Byte prefixes that identify the first line of a pointer file.
HEADERS = [(b"version " + version.encode("utf-8")) for version in ALLOWED_VERSIONS]


def _get_kv(line: str) -> Tuple[str, str]:
    """Split a pointer-file line into its key and value.

    Raises:
        ValueError: If the stripped line does not contain both a key and a
            value separated by whitespace.
    """
    key, value = line.strip().split(maxsplit=1)
    return key, value


@dataclass
class Pointer:
    """An LFS pointer: object id, size and any additional key/value pairs."""

    oid: str
    size: int

    def __init__(self, oid: str, size: int, **kwargs):
        """
        Args:
            oid: Hex SHA-256 digest of the object content.
            size: Size of the object content in bytes.
            **kwargs: Additional pointer keys, written back out by ``dump``.
        """
        self.oid = oid
        self.size = size
        self._dict = kwargs

    def __hash__(self):
        # Hash exactly the fields that the dataclass-generated __eq__
        # compares, so that equal pointers always hash equally.
        return hash((self.oid, self.size))

    @classmethod
    def build(cls, fobj: BinaryIO) -> "Pointer":
        """Build a pointer for the content readable from ``fobj``."""
        digest = hashlib.sha256()
        size = 0
        # Hash in chunks so that large objects are not loaded into memory
        # in one piece.
        chunk_size = 1024 * 1024
        for chunk in iter(lambda: fobj.read(chunk_size), b""):
            digest.update(chunk)
            size += len(chunk)
        return cls(digest.hexdigest(), size)

    @classmethod
    def load(cls, fobj: IO) -> "Pointer":
        """Load the specified pointer file.

        Args:
            fobj: Text or binary stream positioned at the start of the
                pointer; binary streams are decoded as UTF-8.

        Raises:
            ValueError: If the version line, the ``oid`` entry or the
                ``size`` entry is missing or malformed.
        """

        if isinstance(fobj, io.TextIOBase):  # type: ignore[unreachable]
            text_obj: TextIO = fobj  # type: ignore[unreachable]

        else:
            text_obj = io.TextIOWrapper(fobj, encoding="utf-8")

        cls.check_version(text_obj.readline())
        d = dict(_get_kv(line) for line in text_obj.readlines())
        try:
            value = d.pop("oid")
            hash_method, oid = value.split(":", maxsplit=1)
            if hash_method != "sha256":
                raise ValueError(f"Invalid LFS hash method '{hash_method}'")
        except (KeyError, ValueError) as e:
            # A missing key is reported the same way as a malformed one so
            # callers only have to handle ValueError.
            raise ValueError("Invalid LFS pointer oid") from e
        try:
            value = d.pop("size")
            size = int(value)
        except (KeyError, ValueError) as e:
            raise ValueError("Invalid LFS pointer size") from e

        return cls(oid, size, **d)

    @staticmethod
    def check_version(line: str):
        """Validate the first line of a pointer file.

        Raises:
            ValueError: If the line is not ``version <url>`` with one of the
                allowed version URLs.
        """
        try:
            key, value = _get_kv(line)
            if key != "version":
                raise ValueError("LFS pointer file must start with 'version'")
            if value not in ALLOWED_VERSIONS:
                raise ValueError(f"Unsupported LFS pointer version '{value}'")
        except (ValueError, OSError) as e:
            raise ValueError("Invalid LFS pointer file") from e

    def dump(self) -> str:
        """Return the pointer file text.

        The version line comes first, followed by all remaining keys in
        sorted order, each on its own newline-terminated line.
        """
        d = {
            "oid": f"sha256:{self.oid}",
            "size": self.size,
        }
        d.update(self._dict)
        return "\n".join(
            [f"version {LFS_VERSION}"] + [f"{key} {d[key]}" for key in sorted(d)] + [""]
        )

    def to_bytes(self) -> bytes:
        """Return the pointer file text encoded as UTF-8."""
        return self.dump().encode("utf-8")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
if __name__ == "__main__":
    # Minimal `git lfs pointer` CLI implementation
    import argparse
    import sys

    cli = argparse.ArgumentParser(description="Build generated pointer files.")
    cli.add_argument("--file", help="A local file to build the pointer from.")
    options = cli.parse_args()

    source_path = options.file
    if not source_path:
        sys.exit("Nothing to do")

    # The banner goes to stderr so stdout carries only the pointer text.
    print(f"Git LFS pointer for {source_path}\n", file=sys.stderr)
    with open(source_path, mode="rb") as fobj_:
        pointer_ = Pointer.build(fobj_)
    print(pointer_.dump(), end="")
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Any, BinaryIO, Callable, Dict, Optional, Union
|
|
2
|
+
|
|
3
|
+
from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
|
|
4
|
+
|
|
5
|
+
from scmrepo.progress import GitProgressEvent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LFSCallback(Callback):
    """Callback subclass to generate Git/LFS style progress."""

    def __init__(
        self,
        *args,
        git_progress: Optional[Callable[[GitProgressEvent], None]] = None,
        direction: str = "Downloading",
        **kwargs,
    ):
        """
        Args:
            git_progress: Optional function receiving a ``GitProgressEvent``
                after each ``call``.
            direction: Verb used as the first word of the reported phase.
            *args, **kwargs: Forwarded to ``Callback``.
        """
        super().__init__(*args, **kwargs)
        self.direction = direction
        self.git_progress = git_progress

    def call(self, *args, **kwargs):
        """Run the base ``call`` and then report progress to Git."""
        super().call(*args, **kwargs)
        self._update_git()

    def _update_git(self):
        """Send the current value/size to ``git_progress``, if one is set."""
        report = self.git_progress
        if report:
            event = GitProgressEvent(
                phase=f"{self.direction} LFS objects",
                completed=self.value,
                total=self.size,
            )
            report(event)

    def branch(
        self,
        path_1: Union[str, BinaryIO],
        path_2: str,
        kwargs: Dict[str, Any],
        child: Optional[Callback] = None,
    ):
        """Branch to a child callback, choosing a default when none is given.

        Without an explicit ``child``, a ``TqdmCallback`` described by the
        path being transferred is used when ``git_progress`` is set, and
        ``DEFAULT_CALLBACK`` otherwise.
        """
        if not child:
            if self.git_progress:
                description = path_1 if isinstance(path_1, str) else path_2
                child = TqdmCallback(bytes=True, desc=description)
            else:
                child = DEFAULT_CALLBACK
        return super().branch(path_1, path_2, kwargs, child=child)

    @classmethod
    def as_lfs_callback(
        cls,
        git_progress: Optional[Callable[[GitProgressEvent], None]] = None,
        **kwargs,
    ):
        """Return an ``LFSCallback`` for ``git_progress``.

        ``DEFAULT_CALLBACK`` is returned instead when ``git_progress`` is
        None.
        """
        if git_progress is not None:
            return cls(git_progress=git_progress, **kwargs)
        return DEFAULT_CALLBACK
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import logging
|
|
3
|
+
from typing import TYPE_CHECKING, BinaryIO, Optional
|
|
4
|
+
|
|
5
|
+
from .pointer import HEADERS, Pointer
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from .storage import LFSStorage
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def smudge(
    storage: "LFSStorage", fobj: BinaryIO, url: Optional[str] = None
) -> BinaryIO:
    """Wrap the specified binary IO stream and run LFS smudge if necessary.

    Args:
        storage: LFS storage used to open the object a pointer refers to.
        fobj: Raw stream that may contain an LFS pointer.
        url: Optional URL passed to ``storage.open`` as ``fetch_url``.

    Returns:
        A buffered reader over ``fobj`` when the content does not start with
        a pointer header; otherwise the opened LFS object, or an in-memory
        stream of the raw pointer data if the object could not be opened.
    """
    reader = io.BufferedReader(fobj)  # type: ignore[arg-type]
    head = reader.peek(100)
    if not any(head.startswith(header) for header in HEADERS):
        return reader

    # read the pointer data into memory since the raw stream is unseekable
    # and we may need to return it in fallback case
    raw_pointer = reader.read()
    try:
        pointer = Pointer.load(io.BytesIO(raw_pointer))
        lfs_obj: Optional[BinaryIO] = storage.open(
            pointer, mode="rb", fetch_url=url
        )
    except (ValueError, OSError):
        logger.warning("Could not open LFS object, falling back to raw pointer")
        return io.BytesIO(raw_pointer)
    return lfs_obj if lfs_obj else io.BytesIO(raw_pointer)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
if __name__ == "__main__":
    # Minimal `git lfs smudge` CLI implementation
    import sys

    from scmrepo.git import Git

    if sys.stdin.isatty():
        # An interactive terminal cannot supply the pointer data.
        tty_message = (
            "Cannot read from STDIN: "
            "This command should be run by the Git 'smudge' filter"
        )
        sys.exit(tty_message)

    repo_ = Git()
    try:
        with smudge(repo_.lfs_storage, sys.stdin.buffer) as fobj_:
            content_ = fobj_.read()
            sys.stdout.buffer.write(content_)
    finally:
        repo_.close()
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import errno
|
|
2
|
+
import os
|
|
3
|
+
from typing import TYPE_CHECKING, BinaryIO, Callable, Collection, Optional, Union
|
|
4
|
+
|
|
5
|
+
from .pointer import Pointer
|
|
6
|
+
from .progress import LFSCallback
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from scmrepo.git import Git
|
|
10
|
+
from scmrepo.progress import GitProgressEvent
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LFSStorage:
    """Directory of LFS objects stored under ``<path>/objects``."""

    def __init__(self, path: str):
        """
        Args:
            path: Root directory of the storage.
        """
        self.path = path

    def fetch(
        self,
        url: str,
        objects: Collection[Pointer],
        progress: Optional[Callable[["GitProgressEvent"], None]] = None,
    ):
        """Download ``objects`` into this storage.

        Args:
            url: Git remote URL from which the LFS client is created.
            objects: Pointers to download.
            progress: Optional Git-style progress callback.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to keep the HTTP dependencies lazy —
        # confirm.
        from .client import LFSClient

        with LFSCallback.as_lfs_callback(progress) as cb:
            cb.set_size(len(objects))
            with LFSClient.from_git_url(url) as client:
                client.download(self, objects, callback=cb)

    def oid_to_path(self, oid: str):
        """Return the path of ``oid``: ``objects/<oid[0:2]>/<oid[2:4]>/<oid>``."""
        return os.path.join(self.path, "objects", oid[0:2], oid[2:4], oid)

    def exists(self, obj: Union[Pointer, str]):
        """Return True if the object (a pointer or a bare OID) is stored."""
        oid = obj if isinstance(obj, str) else obj.oid
        path = self.oid_to_path(oid)
        return os.path.exists(path)

    def open(
        self,
        obj: Union[Pointer, str],
        fetch_url: Optional[str] = None,
        **kwargs,
    ) -> BinaryIO:
        """Open a stored object, fetching it first when possible.

        Args:
            obj: Pointer or bare OID of the object.
            fetch_url: If given and ``obj`` is a ``Pointer``, a missing
                object is fetched from this URL before a second attempt.
            **kwargs: Passed to the builtin ``open``.

        Raises:
            FileNotFoundError: If the object is missing and cannot be
                fetched.
        """
        oid = obj if isinstance(obj, str) else obj.oid
        path = self.oid_to_path(oid)
        try:
            return open(path, **kwargs)  # pylint: disable=unspecified-encoding
        except FileNotFoundError:
            if not fetch_url or not isinstance(obj, Pointer):
                raise
            try:
                self.fetch(fetch_url, [obj])
            except Exception as exc:
                # Only ordinary failures become "file not found".
                # KeyboardInterrupt and SystemExit must propagate instead of
                # being reported as a missing file.
                raise FileNotFoundError(
                    errno.ENOENT, os.strerror(errno.ENOENT), path
                ) from exc
            return open(path, **kwargs)  # pylint: disable=unspecified-encoding

    def close(self):
        """No-op."""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_storage_path(scm: "Git") -> str:
    """Return the LFS storage directory for the specified repository.

    The ``lfs.storage`` config value is returned as-is when it is an
    absolute path and is joined onto the Git directory when it is relative.
    When the option is not set, ``<git dir>/lfs`` is used.
    """

    config = scm.get_config()
    git_dir = scm._get_git_dir(scm.root_dir)  # pylint: disable=protected-access
    try:
        configured = config.get(("lfs",), "storage")
    except KeyError:
        # Option not set: fall back to the default location.
        return os.path.join(git_dir, "lfs")
    if os.path.isabs(configured):
        return configured
    return os.path.join(git_dir, configured)
|