anemoi-utils 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release.
This version of anemoi-utils might be problematic.
- anemoi/utils/__main__.py +2 -3
- anemoi/utils/_version.py +2 -2
- anemoi/utils/checkpoints.py +2 -2
- anemoi/utils/commands/__init__.py +2 -3
- anemoi/utils/commands/config.py +0 -1
- anemoi/utils/compatibility.py +76 -0
- anemoi/utils/hindcasts.py +9 -9
- anemoi/utils/mars/__init__.py +3 -1
- anemoi/utils/registry.py +52 -4
- anemoi/utils/remote/__init__.py +328 -0
- anemoi/utils/remote/s3.py +386 -0
- anemoi/utils/remote/ssh.py +133 -0
- anemoi/utils/s3.py +47 -544
- {anemoi_utils-0.4.4.dist-info → anemoi_utils-0.4.6.dist-info}/METADATA +2 -1
- anemoi_utils-0.4.6.dist-info/RECORD +32 -0
- {anemoi_utils-0.4.4.dist-info → anemoi_utils-0.4.6.dist-info}/WHEEL +1 -1
- anemoi_utils-0.4.4.dist-info/RECORD +0 -28
- {anemoi_utils-0.4.4.dist-info → anemoi_utils-0.4.6.dist-info}/LICENSE +0 -0
- {anemoi_utils-0.4.4.dist-info → anemoi_utils-0.4.6.dist-info}/entry_points.txt +0 -0
- {anemoi_utils-0.4.4.dist-info → anemoi_utils-0.4.6.dist-info}/top_level.txt +0 -0
anemoi/utils/remote/s3.py
@@ -0,0 +1,386 @@
+# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+"""This module provides functions to upload, download, list and delete files and folders on S3.
+The functions of this package expect that the AWS credentials are set up in the environment
+typicaly by setting the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables or
+by creating a `~/.aws/credentials` file. It is also possible to set the `endpoint_url` in the same file
+to use a different S3 compatible service::
+
+    [default]
+    endpoint_url = https://some-storage.somewhere.world
+    aws_access_key_id = xxxxxxxxxxxxxxxxxxxxxxxx
+    aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxx
+
+Alternatively, the `endpoint_url`, and keys can be set in one of
+the `~/.config/anemoi/settings.toml`
+or `~/.config/anemoi/settings-secrets.toml` files.
+
+"""
+
+import logging
+import os
+import threading
+from copy import deepcopy
+from typing import Iterable
+
+import tqdm
+
+from ..config import load_config
+from ..humanize import bytes_to_human
+from . import BaseDownload
+from . import BaseUpload
+
+LOGGER = logging.getLogger(__name__)
+
+
+# s3_clients are not thread-safe, so we need to create a new client for each thread
+
+thread_local = threading.local()
+
+
+def s3_client(bucket, region=None):
+    import boto3
+    from botocore import UNSIGNED
+    from botocore.client import Config
+
+    if not hasattr(thread_local, "s3_clients"):
+        thread_local.s3_clients = {}
+
+    key = f"{bucket}-{region}"
+
+    boto3_config = dict(max_pool_connections=25)
+
+    if key in thread_local.s3_clients:
+        return thread_local.s3_clients[key]
+
+    boto3_config = dict(max_pool_connections=25)
+
+    if region:
+        # This is using AWS
+
+        options = {"region_name": region}
+
+        # Anonymous access
+        if not (
+            os.path.exists(os.path.expanduser("~/.aws/credentials"))
+            or ("AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ)
+        ):
+            boto3_config["signature_version"] = UNSIGNED
+
+    else:
+
+        # We may be accessing a different S3 compatible service
+        # Use anemoi.config to get the configuration
+
+        options = {}
+        config = load_config(secrets=["aws_access_key_id", "aws_secret_access_key"])
+
+        cfg = config.get("object-storage", {})
+        for k, v in cfg.items():
+            if isinstance(v, (str, int, float, bool)):
+                options[k] = v
+
+        for k, v in cfg.get(bucket, {}).items():
+            if isinstance(v, (str, int, float, bool)):
+                options[k] = v
+
+        type = options.pop("type", "s3")
+        if type != "s3":
+            raise ValueError(f"Unsupported object storage type {type}")
+
+    if "config" in options:
+        boto3_config.update(options["config"])
+        del options["config"]
+    from botocore.client import Config
+
+    options["config"] = Config(**boto3_config)
+
+    thread_local.s3_clients[key] = boto3.client("s3", **options)
+
+    return thread_local.s3_clients[key]
+
+
+class S3Upload(BaseUpload):
+
+    def get_temporary_target(self, target, pattern):
+        return target
+
+    def rename_target(self, target, temporary_target):
+        pass
+
+    def delete_target(self, target):
+        pass
+        # delete(target)
+
+    def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None):
+
+        from botocore.exceptions import ClientError
+
+        assert target.startswith("s3://")
+
+        _, _, bucket, key = target.split("/", 3)
+        s3 = s3_client(bucket)
+
+        size = os.path.getsize(source)
+
+        if verbosity > 0:
+            LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})")
+
+        try:
+            results = s3.head_object(Bucket=bucket, Key=key)
+            remote_size = int(results["ContentLength"])
+        except ClientError as e:
+            if e.response["Error"]["Code"] != "404":
+                raise
+            remote_size = None
+
+        if remote_size is not None:
+            if remote_size != size:
+                LOGGER.warning(
+                    f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})"
+                )
+            elif resume:
+                # LOGGER.info(f"{target} already exists, skipping")
+                return size
+
+        if remote_size is not None and not overwrite and not resume:
+            raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")
+
+        if verbosity > 0:
+            with tqdm.tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024, leave=False) as pbar:
+                s3.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x), Config=config)
+        else:
+            s3.upload_file(source, bucket, key, Config=config)
+
+        return size
+
+
+class S3Download(BaseDownload):
+
+    def copy(self, source, target, **kwargs):
+        assert source.startswith("s3://")
+
+        if source.endswith("/"):
+            self.transfer_folder(source=source, target=target, **kwargs)
+        else:
+            self.transfer_file(source=source, target=target, **kwargs)
+
+    def list_source(self, source):
+        yield from _list_objects(source)
+
+    def source_path(self, s3_object, source):
+        _, _, bucket, _ = source.split("/", 3)
+        return f"s3://{bucket}/{s3_object['Key']}"
+
+    def target_path(self, s3_object, source, target):
+        _, _, _, folder = source.split("/", 3)
+        local_path = os.path.join(target, os.path.relpath(s3_object["Key"], folder))
+        os.makedirs(os.path.dirname(local_path), exist_ok=True)
+        return local_path
+
+    def source_size(self, s3_object):
+        return s3_object["Size"]
+
+    def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None):
+        # from boto3.s3.transfer import TransferConfig
+
+        _, _, bucket, key = source.split("/", 3)
+        s3 = s3_client(bucket)
+
+        try:
+            response = s3.head_object(Bucket=bucket, Key=key)
+        except s3.exceptions.ClientError as e:
+            if e.response["Error"]["Code"] == "404":
+                raise ValueError(f"{source} does not exist ({bucket}, {key})")
+            raise
+
+        size = int(response["ContentLength"])
+
+        if verbosity > 0:
+            LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})")
+
+        if overwrite:
+            resume = False
+
+        if resume:
+            if os.path.exists(target):
+                local_size = os.path.getsize(target)
+                if local_size != size:
+                    LOGGER.warning(
+                        f"{target} already with different size, re-downloading (remote={size}, local={local_size})"
+                    )
+                else:
+                    # if verbosity > 0:
+                    #     LOGGER.info(f"{target} already exists, skipping")
+                    return size
+
+        if os.path.exists(target) and not overwrite:
+            raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")
+
+        if verbosity > 0:
+            with tqdm.tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024, leave=False) as pbar:
+                s3.download_file(bucket, key, target, Callback=lambda x: pbar.update(x), Config=config)
+        else:
+            s3.download_file(bucket, key, target, Config=config)
+
+        return size
+
+
+def _list_objects(target, batch=False):
+    _, _, bucket, prefix = target.split("/", 3)
+    s3 = s3_client(bucket)
+
+    paginator = s3.get_paginator("list_objects_v2")
+
+    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
+        if "Contents" in page:
+            objects = deepcopy(page["Contents"])
+            if batch:
+                yield objects
+            else:
+                yield from objects
+
+
+def _delete_folder(target) -> None:
+    _, _, bucket, _ = target.split("/", 3)
+    s3 = s3_client(bucket)
+
+    total = 0
+    for batch in _list_objects(target, batch=True):
+        LOGGER.info(f"Deleting {len(batch):,} objects from {target}")
+        s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": o["Key"]} for o in batch]})
+        total += len(batch)
+        LOGGER.info(f"Deleted {len(batch):,} objects (total={total:,})")
+
+
+def _delete_file(target) -> None:
+    from botocore.exceptions import ClientError
+
+    _, _, bucket, key = target.split("/", 3)
+    s3 = s3_client(bucket)
+
+    try:
+        s3.head_object(Bucket=bucket, Key=key)
+        exits = True
+    except ClientError as e:
+        if e.response["Error"]["Code"] != "404":
+            raise
+        exits = False
+
+    if not exits:
+        LOGGER.warning(f"{target} does not exist. Did you mean to delete a folder? Then add a trailing '/'")
+        return
+
+    LOGGER.info(f"Deleting {target}")
+    s3.delete_object(Bucket=bucket, Key=key)
+    LOGGER.info(f"{target} is deleted")
+
+
+def delete(target) -> None:
+    """Delete a file or a folder from S3.
+
+    Parameters
+    ----------
+    target : str
+        The URL of a file or a folder on S3. The url should start with 's3://'. If the URL ends with a '/' it is
+        assumed to be a folder, otherwise it is assumed to be a file.
+    """
+
+    assert target.startswith("s3://")
+
+    if target.endswith("/"):
+        _delete_folder(target)
+    else:
+        _delete_file(target)
+
+
+def list_folder(folder) -> Iterable:
+    """List the sub folders in a folder on S3.
+
+    Parameters
+    ----------
+    folder : str
+        The URL of a folder on S3. The url should start with 's3://'.
+
+    Returns
+    -------
+    list
+        A list of the subfolders names in the folder.
+    """
+
+    assert folder.startswith("s3://")
+    if not folder.endswith("/"):
+        folder += "/"
+
+    _, _, bucket, prefix = folder.split("/", 3)
+
+    s3 = s3_client(bucket)
+    paginator = s3.get_paginator("list_objects_v2")
+
+    for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
+        if "CommonPrefixes" in page:
+            yield from [folder + _["Prefix"] for _ in page.get("CommonPrefixes")]
+
+
+def object_info(target) -> dict:
+    """Get information about an object on S3.
+
+    Parameters
+    ----------
+    target : str
+        The URL of a file or a folder on S3. The url should start with 's3://'.
+
+    Returns
+    -------
+    dict
+        A dictionary with information about the object.
+    """
+
+    _, _, bucket, key = target.split("/", 3)
+    s3 = s3_client(bucket)
+
+    try:
+        return s3.head_object(Bucket=bucket, Key=key)
+    except s3.exceptions.ClientError as e:
+        if e.response["Error"]["Code"] == "404":
+            raise ValueError(f"{target} does not exist")
+        raise
+
+
+def object_acl(target) -> dict:
+    """Get information about an object's ACL on S3.
+
+    Parameters
+    ----------
+    target : str
+        The URL of a file or a folder on S3. The url should start with 's3://'.
+
+    Returns
+    -------
+    dict
+        A dictionary with information about the object's ACL.
+    """
+
+    _, _, bucket, key = target.split("/", 3)
+    s3 = s3_client()
+
+    return s3.get_object_acl(Bucket=bucket, Key=key)
+
+
+def download(source, target, *args, **kwargs):
+    from . import transfer
+
+    assert source.startswith("s3://"), f"source {source} should start with 's3://'"
+    return transfer(source, target, *args, **kwargs)
+
+
+def upload(source, target, *args, **kwargs):
+    from . import transfer
+
+    assert target.startswith("s3://"), f"target {target} should start with 's3://'"
+    return transfer(source, target, *args, **kwargs)
anemoi/utils/remote/ssh.py
@@ -0,0 +1,133 @@
+# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+import os
+import random
+import shlex
+import subprocess
+
+from ..humanize import bytes_to_human
+from . import BaseUpload
+
+LOGGER = logging.getLogger(__name__)
+
+
+def call_process(*args):
+    proc = subprocess.Popen(
+        args,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    stdout, stderr = proc.communicate()
+    if proc.returncode != 0:
+        print(stdout)
+        msg = f"{' '.join(args)} failed: {stderr}"
+        raise RuntimeError(msg)
+
+    return stdout.decode("utf-8").strip()
+
+
+class SshBaseUpload(BaseUpload):
+
+    def _parse_target(self, target):
+        assert target.startswith("ssh://"), target
+
+        target = target[6:]
+        hostname, path = target.split(":")
+
+        if "+" in hostname:
+            hostnames = hostname.split("+")
+            hostname = hostnames[random.randint(0, len(hostnames) - 1)]
+
+        return hostname, path
+
+    def get_temporary_target(self, target, pattern):
+        hostname, path = self._parse_target(target)
+        dirname, basename = os.path.split(path)
+        path = pattern.format(dirname=dirname, basename=basename)
+        return f"ssh://{hostname}:{path}"
+
+    def rename_target(self, target, new_target):
+        hostname, path = self._parse_target(target)
+        hostname, new_path = self._parse_target(new_target)
+        call_process("ssh", hostname, "mkdir", "-p", shlex.quote(os.path.dirname(new_path)))
+        call_process("ssh", hostname, "mv", shlex.quote(path), shlex.quote(new_path))
+
+    def delete_target(self, target):
+        pass
+        # hostname, path = self._parse_target(target)
+        # LOGGER.info(f"Deleting {target}")
+        # call_process("ssh", hostname, "rm", "-rf", shlex.quote(path))
+
+
+class RsyncUpload(SshBaseUpload):
+
+    def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None):
+        hostname, path = self._parse_target(target)
+
+        size = os.path.getsize(source)
+
+        if verbosity > 0:
+            LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})")
+
+        call_process("ssh", hostname, "mkdir", "-p", shlex.quote(os.path.dirname(path)))
+        call_process(
+            "rsync",
+            "-av",
+            "--partial",
+            # it would be nice to avoid two ssh calls, but the following is not possible,
+            # this is because it requires a shell command and would not be safe.
+            # # f"--rsync-path='mkdir -p {os.path.dirname(path)} && rsync'",
+            source,
+            f"{hostname}:{path}",
+        )
+        return size
+
+
+class ScpUpload(SshBaseUpload):
+
+    def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None):
+        hostname, path = self._parse_target(target)
+
+        size = os.path.getsize(source)
+
+        if verbosity > 0:
+            LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})")
+
+        remote_size = None
+        try:
+            out = call_process("ssh", hostname, "stat", "-c", "%s", shlex.quote(path))
+            remote_size = int(out)
+        except RuntimeError:
+            remote_size = None
+
+        if remote_size is not None:
+            if remote_size != size:
+                LOGGER.warning(
+                    f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})"
+                )
+            elif resume:
+                # LOGGER.info(f"{target} already exists, skipping")
+                return size
+
+        if remote_size is not None and not overwrite and not resume:
+            raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")
+
+        call_process("ssh", hostname, "mkdir", "-p", shlex.quote(os.path.dirname(path)))
+        call_process("scp", source, shlex.quote(f"{hostname}:{path}"))
+
+        return size
+
+
+def upload(source, target, **kwargs) -> None:
+    uploader = RsyncUpload()
+
+    if os.path.isdir(source):
+        uploader.transfer_folder(source=source, target=target, **kwargs)
+    else:
+        uploader.transfer_file(source=source, target=target, **kwargs)