plexus-python-common 1.0.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plexus/common/__init__.py +6 -0
- plexus/common/carto/OSMFile.py +259 -0
- plexus/common/carto/OSMNode.py +25 -0
- plexus/common/carto/OSMTags.py +101 -0
- plexus/common/carto/OSMWay.py +24 -0
- plexus/common/carto/__init__.py +11 -0
- plexus/common/pose.py +107 -0
- plexus/common/proj.py +305 -0
- plexus/common/utils/__init__.py +0 -0
- plexus/common/utils/apiutils.py +31 -0
- plexus/common/utils/bagutils.py +215 -0
- plexus/common/utils/config.py +61 -0
- plexus/common/utils/datautils.py +200 -0
- plexus/common/utils/jsonutils.py +92 -0
- plexus/common/utils/ormutils.py +1428 -0
- plexus/common/utils/s3utils.py +799 -0
- plexus/common/utils/shutils.py +234 -0
- plexus/common/utils/sqlutils.py +9 -0
- plexus/common/utils/strutils.py +382 -0
- plexus/common/utils/testutils.py +49 -0
- plexus_python_common-1.0.31.dist-info/METADATA +38 -0
- plexus_python_common-1.0.31.dist-info/RECORD +24 -0
- plexus_python_common-1.0.31.dist-info/WHEEL +5 -0
- plexus_python_common-1.0.31.dist-info/top_level.txt +1 -0
plexus/common/utils/s3utils.py
@@ -0,0 +1,799 @@
import concurrent.futures
import contextlib
import dataclasses
import datetime
import functools
import io
import mimetypes
import os
import os.path
import shutil
import tempfile
import typing
import zipfile
import zlib
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Any, Literal

import boto3
import fsspec
import fsspec.utils
from cloudpathlib import CloudPath, S3Client, S3Path
from iker.common.utils.sequtils import chunk_between, head, last
from iker.common.utils.shutils import glob_match, listfile, path_depth
from iker.common.utils.strutils import is_empty, trim_to_none
from rich.progress import BarColumn, DownloadColumn, Progress, TaskID, TextColumn, TransferSpeedColumn

__all__ = [
    "S3ObjectMeta",
    "s3_make_client",
    "s3_list_objects",
    "s3_listfile",
    "s3_cp_download",
    "s3_cp_upload",
    "s3_sync_download",
    "s3_sync_upload",
    "s3_pull_text",
    "s3_push_text",
    "S3TransferCallbackClient",
    "s3_make_progress_callback",
    "s3_make_progressed_client",
    "s3_archive_listfile",
    "s3_archive_open_member",
    "s3_archive_open_members",
]


@dataclasses.dataclass(frozen=True, eq=True)
class S3ObjectMeta(object):
    key: str
    last_modified: datetime.datetime
    size: int


if typing.TYPE_CHECKING:
    def s3_make_client(
            access_key_id: str = None,
            secret_access_key: str = None,
            region_name: str = None,
            endpoint_url: str = None,
    ) -> contextlib.AbstractContextManager[S3Client]: ...


@contextlib.contextmanager
def s3_make_client(
        access_key_id: str = None,
        secret_access_key: str = None,
        region_name: str = None,
        endpoint_url: str = None,
) -> Generator[S3Client, None, None]:
    """
    Creates an S3 client as a context manager for safe resource handling.

    :param access_key_id: AWS access key ID.
    :param secret_access_key: AWS secret access key.
    :param region_name: AWS service region name.
    :param endpoint_url: AWS service endpoint URL.
    :return: An instance of ``S3Client``.
    """
    session = boto3.Session(aws_access_key_id=trim_to_none(access_key_id),
                            aws_secret_access_key=trim_to_none(secret_access_key),
                            region_name=trim_to_none(region_name))
    client = S3Client(boto3_session=session, endpoint_url=trim_to_none(endpoint_url))
    try:
        yield client
    finally:
        if hasattr(client, "close"):
            client.close()


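# Usage sketch for ``s3_make_client`` (illustrative only; the region and endpoint below are
# hypothetical placeholders, e.g. a local MinIO endpoint):
#
#     with s3_make_client(region_name="us-east-1", endpoint_url="http://localhost:9000") as client:
#         ...  # use ``client`` with the helpers below; the client is closed automatically on exit
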
def s3_list_objects(client: S3Client, bucket: str, prefix: str, limit: int = None) -> Generator[S3ObjectMeta]:
    """
    Lists all objects from the given S3 ``bucket`` and ``prefix``.

    :param client: An instance of ``S3Client``.
    :param bucket: Bucket name.
    :param prefix: Object keys prefix.
    :param limit: Maximum number of objects to return (``None`` for all).
    :return: An iterable of ``S3ObjectMeta`` objects representing the S3 objects.
    """
    continuation_token = None
    count = 0
    while True:
        if is_empty(continuation_token):
            response = client.client.list_objects_v2(MaxKeys=1000, Bucket=bucket, Prefix=prefix)
        else:
            response = client.client.list_objects_v2(MaxKeys=1000,
                                                     Bucket=bucket,
                                                     Prefix=prefix,
                                                     ContinuationToken=continuation_token)

        contents = response.get("Contents", [])
        count += len(contents)
        if limit is not None and count > limit:
            contents = contents[:limit - count]

        yield from (S3ObjectMeta(key=e["Key"], last_modified=e["LastModified"], size=e["Size"]) for e in contents)

        if not response.get("IsTruncated") or (limit is not None and count >= limit):
            break

        continuation_token = response.get("NextContinuationToken")


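# Usage sketch for ``s3_list_objects`` (illustrative; "example-bucket" is a hypothetical name):
#
#     with s3_make_client() as client:
#         metas = list(s3_list_objects(client, "example-bucket", "logs/2024/", limit=2500))
#
# Each ``list_objects_v2`` call returns at most 1000 keys, so the helper follows
# ``NextContinuationToken`` until the prefix is exhausted or ``limit`` objects have been yielded.
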
def s3_listfile(
        client: S3Client,
        bucket: str,
        prefix: str,
        *,
        include_patterns: list[str] | None = None,
        exclude_patterns: list[str] | None = None,
        depth: int = 0,
) -> Generator[S3ObjectMeta]:
    """
    Lists all objects from the given S3 ``bucket`` and ``prefix``, filtered by patterns and directory depth.

    :param client: An instance of ``S3Client``.
    :param bucket: Bucket name.
    :param prefix: Object keys prefix.
    :param include_patterns: Inclusive glob patterns applied to filenames.
    :param exclude_patterns: Exclusive glob patterns applied to filenames.
    :param depth: Maximum depth of subdirectories to include in the scan (``0`` for unlimited depth).
    :return: An iterable of ``S3ObjectMeta`` objects representing the filtered S3 objects.
    """

    # We add trailing slash "/" to the prefix if it is absent
    if not prefix.endswith("/"):
        prefix = prefix + "/"

    def filter_object_meta(object_meta: S3ObjectMeta) -> bool:
        if 0 < depth <= path_depth(prefix, os.path.dirname(object_meta.key)):
            return False
        if len(glob_match([os.path.basename(object_meta.key)], include_patterns, exclude_patterns)) == 0:
            return False
        return True

    yield from filter(filter_object_meta, s3_list_objects(client, bucket, prefix))


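# Usage sketch for ``s3_listfile`` (illustrative; bucket, prefix, and patterns are hypothetical):
#
#     with s3_make_client() as client:
#         for meta in s3_listfile(client, "example-bucket", "datasets/v1",
#                                 include_patterns=["*.json"], exclude_patterns=["*_draft*"],
#                                 depth=1):
#             print(meta.key, meta.size)
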
def s3_cp_download(client: S3Client, bucket: str, key: str, file_path: str | os.PathLike[str]):
    """
    Downloads an object from the given S3 ``bucket`` and ``key`` to a local file path.

    :param client: An instance of ``S3Client``.
    :param bucket: Bucket name.
    :param key: Object key.
    :param file_path: Local file path to save the object.
    """
    client.client.download_file(bucket, key, file_path)


def s3_cp_upload(client: S3Client, file_path: str | os.PathLike[str], bucket: str, key: str):
    """
    Uploads a local file to the given S3 ``bucket`` and ``key``.

    :param client: An instance of ``S3Client``.
    :param file_path: Local file path to upload.
    :param bucket: Bucket name.
    :param key: Object key for the uploaded file.
    """
    t, _ = mimetypes.MimeTypes().guess_type(file_path)
    client.client.upload_file(file_path,
                              bucket,
                              key,
                              ExtraArgs={"ContentType": "binary/octet-stream" if t is None else t})


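# Usage sketch for the single-object helpers (illustrative; names are hypothetical):
#
#     with s3_make_client() as client:
#         s3_cp_upload(client, "/tmp/report.pdf", "example-bucket", "reports/report.pdf")
#         s3_cp_download(client, "example-bucket", "reports/report.pdf", "/tmp/report-copy.pdf")
#
# The upload guesses the Content-Type from the file extension and falls back to
# "binary/octet-stream" when no type can be inferred.
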
def s3_sync_download(
        client: S3Client,
        bucket: str,
        prefix: str,
        dir_path: str | os.PathLike[str],
        *,
        max_workers: int = None,
        include_patterns: list[str] = None,
        exclude_patterns: list[str] = None,
        depth: int = 0,
):
    """
    Recursively downloads all objects from the given S3 ``bucket`` and ``prefix`` to a local directory path, using a thread pool.

    :param client: An instance of ``S3Client``.
    :param bucket: Bucket name.
    :param prefix: Object keys prefix.
    :param dir_path: Local directory path to save objects.
    :param max_workers: Maximum number of worker threads.
    :param include_patterns: Inclusive glob patterns applied to filenames.
    :param exclude_patterns: Exclusive glob patterns applied to filenames.
    :param depth: Maximum depth of subdirectories to include in the scan (``0`` for unlimited depth).
    """

    # We add trailing slash "/" to the prefix if it is absent
    if not prefix.endswith("/"):
        prefix = prefix + "/"

    objects = s3_listfile(client,
                          bucket,
                          prefix,
                          include_patterns=include_patterns,
                          exclude_patterns=exclude_patterns,
                          depth=depth)

    def download_file(key: str):
        file_path = os.path.join(dir_path, key[len(prefix):])
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        s3_cp_download(client, bucket, key, file_path)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_file, obj.key) for obj in objects]
        done_futures, not_done_futures = concurrent.futures.wait(futures,
                                                                 return_when=concurrent.futures.FIRST_EXCEPTION)
        if len(not_done_futures) > 0:
            for future in not_done_futures:
                future.cancel()
        for future in done_futures:
            exc = future.exception()
            if exc is not None:
                raise exc
        if len(not_done_futures) > 0:
            raise RuntimeError("download did not complete due to errors in some threads")


def s3_sync_upload(
        client: S3Client,
        dir_path: str | os.PathLike[str],
        bucket: str,
        prefix: str,
        *,
        max_workers: int = None,
        include_patterns: list[str] = None,
        exclude_patterns: list[str] = None,
        depth: int = 0,
):
    """
    Recursively uploads all files from a local directory to the given S3 ``bucket`` and ``prefix``, using a thread pool.

    :param client: An instance of ``S3Client``.
    :param dir_path: Local directory path to upload from.
    :param bucket: Bucket name.
    :param prefix: Object keys prefix for uploaded files.
    :param max_workers: Maximum number of worker threads.
    :param include_patterns: Inclusive glob patterns applied to filenames.
    :param exclude_patterns: Exclusive glob patterns applied to filenames.
    :param depth: Maximum depth of subdirectories to include in the scan (``0`` for unlimited depth).
    """

    # We add trailing slash "/" to the prefix if it is absent
    if not prefix.endswith("/"):
        prefix = prefix + "/"

    file_paths = listfile(dir_path,
                          include_patterns=include_patterns,
                          exclude_patterns=exclude_patterns,
                          depth=depth)

    def upload_file(file_path: str):
        s3_cp_upload(client, file_path, bucket, prefix + os.path.relpath(file_path, dir_path))

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(upload_file, file_path) for file_path in file_paths]
        done_futures, not_done_futures = concurrent.futures.wait(futures,
                                                                 return_when=concurrent.futures.FIRST_EXCEPTION)
        if len(not_done_futures) > 0:
            for future in not_done_futures:
                future.cancel()
        for future in done_futures:
            exc = future.exception()
            if exc is not None:
                raise exc
        if len(not_done_futures) > 0:
            raise RuntimeError("upload did not complete due to errors in some threads")


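# Usage sketch for the directory sync helpers (illustrative; bucket and paths are hypothetical):
#
#     with s3_make_client() as client:
#         s3_sync_download(client, "example-bucket", "drives/2024-01-01", "/data/drives",
#                          include_patterns=["*.bag"], max_workers=8)
#         s3_sync_upload(client, "/data/processed", "example-bucket", "processed/2024-01-01")
#
# Both helpers wait with FIRST_EXCEPTION, cancel transfers that have not started yet, and
# re-raise the exception recorded by the first failed transfer.
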
def s3_pull_text(client: S3Client, bucket: str, key: str, encoding: str = None) -> str:
    """
    Downloads and decodes text content stored as an object in the given S3 ``bucket`` and ``key``.

    :param client: An instance of ``S3Client``.
    :param bucket: Bucket name.
    :param key: Object key storing the text.
    :param encoding: String encoding to use (defaults to UTF-8).
    :return: The decoded text content.
    """
    with tempfile.TemporaryFile() as fp:
        client.client.download_fileobj(bucket, key, fp)
        fp.seek(0)
        return fp.read().decode(encoding or "utf-8")


def s3_push_text(client: S3Client, text: str, bucket: str, key: str, encoding: str = None):
    """
    Uploads the given text as an object to the specified S3 ``bucket`` and ``key``.

    :param client: An instance of ``S3Client``.
    :param text: Text content to upload.
    :param bucket: Bucket name.
    :param key: Object key to store the text.
    :param encoding: String encoding to use (defaults to UTF-8).
    """
    with tempfile.TemporaryFile() as fp:
        fp.write(text.encode(encoding or "utf-8"))
        fp.seek(0)
        client.client.upload_fileobj(fp, bucket, key)


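# Usage sketch for the small-text helpers (illustrative; names are hypothetical):
#
#     with s3_make_client() as client:
#         s3_push_text(client, '{"status": "ok"}', "example-bucket", "status/latest.json")
#         payload = s3_pull_text(client, "example-bucket", "status/latest.json")
#
# Both helpers stage the payload through a temporary file rather than holding the whole
# transfer buffer in memory.
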
TransferDirection = Literal["download", "upload"]
TransferState = Literal["start", "update", "stop"]


@contextlib.contextmanager
def make_transfer_callback(
        callback: Callable[[CloudPath, TransferDirection, TransferState, int], None],
        path: Path | CloudPath,
        direction: TransferDirection,
):
    if callback is None:
        yield None
        return

    callback(path, direction, "start", 0)
    try:
        yield functools.partial(callback, path, direction, "update")
    finally:
        callback(path, direction, "stop", 0)


class S3TransferCallbackClient(S3Client):
    def __init__(
            self,
            *args,
            transfer_callback: Callable[[Path | CloudPath, TransferDirection, TransferState, int], None],
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.transfer_callback = transfer_callback

    def _download_file(self, cloud_path: S3Path, local_path: str | os.PathLike[str]) -> Path:
        local_path = Path(local_path)

        obj = self.s3.Object(cloud_path.bucket, cloud_path.key)

        with make_transfer_callback(self.transfer_callback, cloud_path, "download") as callback:
            obj.download_file(
                str(local_path),
                Config=self.boto3_transfer_config,
                ExtraArgs=self.boto3_dl_extra_args,
                Callback=callback,
            )
        return local_path

    def _upload_file(self, local_path: str | os.PathLike[str], cloud_path: S3Path) -> S3Path:
        local_path = Path(local_path)

        obj = self.s3.Object(cloud_path.bucket, cloud_path.key)

        extra_args = self.boto3_ul_extra_args.copy()

        if self.content_type_method is not None:
            content_type, content_encoding = self.content_type_method(str(local_path))
            if content_type is not None:
                extra_args["ContentType"] = content_type
            if content_encoding is not None:
                extra_args["ContentEncoding"] = content_encoding

        with make_transfer_callback(self.transfer_callback, local_path, "upload") as callback:
            obj.upload_file(
                str(local_path),
                Config=self.boto3_transfer_config,
                ExtraArgs=extra_args,
                Callback=callback,
            )
        return cloud_path


def s3_make_progress_callback(
        progress: Progress,
) -> Callable[[Path | CloudPath, TransferDirection, TransferState, int], None]:
    task_ids: dict[Path | CloudPath, TaskID] = {}

    def progress_callback(path: Path | CloudPath, direction: TransferDirection, state: TransferState, bytes_sent: int):
        if state == "start":
            size = path.stat().st_size
            task_ids[path] = progress.add_task(direction, total=size, filename=path.name)
        elif state == "stop":
            if path in task_ids:
                progress.remove_task(task_ids[path])
                del task_ids[path]
        else:
            progress.update(task_ids[path], advance=bytes_sent)

    return progress_callback


if typing.TYPE_CHECKING:
    def s3_make_progressed_client(
            access_key_id: str = None,
            secret_access_key: str = None,
            region_name: str = None,
            endpoint_url: str = None,
    ) -> contextlib.AbstractContextManager[S3Client]: ...


@contextlib.contextmanager
def s3_make_progressed_client(
        access_key_id: str = None,
        secret_access_key: str = None,
        region_name: str = None,
        endpoint_url: str = None,
) -> Generator[S3Client]:
    """
    Creates an S3 client with progress callback as a context manager for safe resource handling.

    :param access_key_id: AWS access key ID.
    :param secret_access_key: AWS secret access key.
    :param region_name: AWS service region name.
    :param endpoint_url: AWS service endpoint URL.
    :return: An instance of ``S3TransferCallbackClient``.
    """
    with Progress(
            TextColumn("[blue]{task.fields[filename]}"),
            BarColumn(),
            DownloadColumn(),
            TransferSpeedColumn(),
    ) as progress:
        session = boto3.Session(aws_access_key_id=trim_to_none(access_key_id),
                                aws_secret_access_key=trim_to_none(secret_access_key),
                                region_name=trim_to_none(region_name))
        yield S3TransferCallbackClient(boto3_session=session,
                                       endpoint_url=trim_to_none(endpoint_url),
                                       transfer_callback=s3_make_progress_callback(progress))


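# Usage sketch for the progress-reporting client (illustrative; the S3 URI is hypothetical and the
# sketch assumes cloudpathlib's ``S3Path.download_to``). Transfers made through ``S3Path`` objects
# bound to this client report into the rich progress bar via ``_download_file``/``_upload_file``:
#
#     with s3_make_progressed_client() as client:
#         S3Path("s3://example-bucket/big/archive.zip", client=client).download_to("/tmp/archive.zip")
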
def s3_options_from_s3_client(client: S3Client) -> dict[str, Any]:
    """
    Extracts S3 connection options from an existing S3Client instance for use with ``fsspec``.

    :param client: An instance of ``S3Client``.
    :return: A dictionary of S3 connection options.
    """
    if client.sess is None:
        return {}

    s3_options: dict[str, Any] = {}

    credentials = client.sess.get_credentials()
    if credentials is not None:
        if credentials.access_key:
            s3_options["key"] = credentials.access_key
        if credentials.secret_key:
            s3_options["secret"] = credentials.secret_key
        if credentials.token:
            s3_options["token"] = credentials.token

    client_kwargs = {}
    if client.sess.region_name or client.client.meta.region_name:
        client_kwargs["region_name"] = client.sess.region_name or client.client.meta.region_name
    if client.client.meta.endpoint_url:
        client_kwargs["endpoint_url"] = client.client.meta.endpoint_url

    if client_kwargs:
        s3_options["client_kwargs"] = client_kwargs

    return s3_options


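# Usage sketch for bridging to ``fsspec`` (illustrative; the URL is hypothetical). The returned
# mapping uses s3fs-style option names ("key", "secret", "token", "client_kwargs"):
#
#     with s3_make_client() as client:
#         fs = fsspec.filesystem("s3", **s3_options_from_s3_client(client))
#         size = fs.size("s3://example-bucket/big/archive.zip")
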
def s3_archive_listfile(
        client: S3Client,
        archive_url: str,
        members: list[str] | None = None,
) -> tuple[int, list[zipfile.ZipInfo], list[str]]:
    """
    Lists members of a ZIP archive stored in S3, optionally filtering by specific member names. When filtering,
    if a member is a directory, it must end with a trailing slash ("/") to be recognized as such, and all files
    under that directory will be included in the results.

    Example usage:
    >>> archive_size, member_infos, missed_members = s3_archive_listfile(client, archive_url, members=["file1.txt", "dir1/"])
    >>> for info in member_infos:
    ...     print(info.filename, info.file_size)
    >>> if missed_members:
    ...     print("Members not found:", missed_members)

    :param client: An instance of ``S3Client``.
    :param archive_url: The URL of the ZIP archive in S3.
    :param members: Optional list of member names to filter; if ``None``, all members are returned.
    :return: A tuple containing:
        - The size of the archive in bytes.
        - A list of ``zipfile.ZipInfo`` objects for the included members.
        - A list of member names that were not found in the archive.
    """
    s3_options = s3_options_from_s3_client(client)

    storage_opts = fsspec.utils.infer_storage_options(archive_url)
    protocol = storage_opts.get("protocol")
    if protocol != "s3":
        raise ValueError(f"unsupported protocol '{protocol}', only 's3' is supported")

    fs = fsspec.filesystem("s3", **s3_options)
    archive_size = fs.size(archive_url)

    with fsspec.open(archive_url, "rb", s3=s3_options) as s3_fh, zipfile.ZipFile(s3_fh) as archive:
        member_zip_infos = archive.infolist()

        if members is None:
            return archive_size, [info for info in member_zip_infos if not info.is_dir()], []

        # Build a tree structure of members for efficient lookup
        # Directories have ZipInfo and a nested dict; files have ZipInfo and None
        # Directory members are recognized by names ending with a trailing slash ("/")
        # Example:
        # {
        #     "dir1/": (ZipInfo, {
        #         "file1.txt": (ZipInfo, None),
        #         "subdir/": (ZipInfo, {
        #             "file2.txt": (ZipInfo, None)
        #         })
        #     }),
        #     "file3.txt": (ZipInfo, None)
        # }
        members_tree: dict[str, tuple[zipfile.ZipInfo, dict | None]] = {}

        def build_members_tree(info: zipfile.ZipInfo):
            *parts, last_part = info.filename.rstrip("/").split("/")
            current = members_tree
            for part in parts:
                _, current = current.setdefault(part + "/", (None, {}))
            if info.is_dir():
                current[last_part + "/"] = info, {}
            else:
                current[last_part] = info, None

        # Sort by filename to ensure directories are created before their contents
        for info in sorted(member_zip_infos, key=lambda x: x.filename):
            build_members_tree(info)

        def search_members_tree(member: str) -> tuple[zipfile.ZipInfo | None, dict | None]:
            *parts, last_part = member.rstrip("/").split("/")
            current = members_tree
            for part in parts:
                _, current = current.get(part + "/", (None, None))
                if current is None:
                    return None, None
            if member.endswith("/"):  # Directory member recognized by trailing slash
                return current.get(last_part + "/", (None, None))
            else:
                return current.get(last_part, (None, None))

        def collect_member_infos(tree: dict[str, tuple[zipfile.ZipInfo, dict | None]]) -> Generator[zipfile.ZipInfo]:
            for _, (member_info, member_tree) in tree.items():
                if member_info is None:
                    continue
                if member_info.is_dir():
                    yield from collect_member_infos(member_tree)
                else:
                    yield member_info

        included_member_zip_infos = []
        missed_members = []

        for member in members:
            member_info, member_tree = search_members_tree(member)
            if member_info is None:
                missed_members.append(member)
                continue
            if not member_info.is_dir():
                included_member_zip_infos.append(member_info)
            else:
                included_member_zip_infos.extend(collect_member_infos(member_tree or {}))

        return archive_size, included_member_zip_infos, missed_members


if typing.TYPE_CHECKING:
    def s3_archive_open_member(
            client: S3Client,
            archive_url: str,
            member: str,
            mode: Literal["r", "rb"] = "r",
    ) -> contextlib.AbstractContextManager[typing.IO]: ...


@contextlib.contextmanager
def s3_archive_open_member(
        client: S3Client,
        archive_url: str,
        member: str,
        mode: Literal["r", "rb"] = "r",
) -> Generator[typing.IO, None, None]:
    """
    Opens a specific member file from a ZIP archive stored in S3.

    :param client: An instance of ``S3Client``.
    :param archive_url: The URL of the ZIP archive in S3.
    :param member: The member file name to open from the archive.
    :param mode: File mode for opening the member ("r" for text, "rb" for binary).

    :return: A file-like object for the specified member within the ZIP archive.
    """
    if mode not in ("r", "rb"):
        raise ValueError("mode must be either 'r' or 'rb'")

    s3_options = s3_options_from_s3_client(client)

    storage_opts = fsspec.utils.infer_storage_options(archive_url)
    protocol = storage_opts.get("protocol")
    if protocol != "s3":
        raise ValueError(f"unsupported protocol '{protocol}', only 's3' is supported")

    with fsspec.open(f"zip://{member}::{archive_url}", mode, s3=s3_options) as s3_fh:
        yield s3_fh


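# Usage sketch for reading a single archive member (illustrative; the URL and member name are
# hypothetical). The member is streamed through fsspec's ``zip://...::s3://...`` chaining, so the
# whole archive does not have to be downloaded up front:
#
#     with s3_make_client() as client:
#         with s3_archive_open_member(client, "s3://example-bucket/big/archive.zip",
#                                     "metadata/info.json") as fh:
#             info = fh.read()
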
ZIPFILE_HEADER_MIN_SIZE = 30
ZIPFILE_HEADER_FN_LEN_OFFSET = 26
ZIPFILE_HEADER_EX_LEN_OFFSET = 28


def zipfile_hdr_fn_len_slice(offset: int) -> slice:
    return slice(offset + ZIPFILE_HEADER_FN_LEN_OFFSET, offset + ZIPFILE_HEADER_FN_LEN_OFFSET + 2)


def zipfile_hdr_ex_len_slice(offset: int) -> slice:
    return slice(offset + ZIPFILE_HEADER_EX_LEN_OFFSET, offset + ZIPFILE_HEADER_EX_LEN_OFFSET + 2)


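# Background for the constants above (informational): in the ZIP format, each member is preceded
# by a fixed 30-byte local file header; the two-byte file-name length sits at offset 26 and the
# two-byte extra-field length at offset 28, both little-endian. The slices above select exactly
# those two fields so that, within a ranged read, the compressed data of a member can be located
# at header_offset + 30 + fn_len + ex_len.
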
def s3_archive_open_members(
        client: S3Client,
        archive_url: str,
        members: list[str] | None = None,
        mode: Literal["r", "rb"] = "r",
        *,
        threshold: float = 0.5,
        central_directory_overhead: int = 64 * 1024,
        member_header_overhead: int = 128,
        use_chunked_reads: bool = False,
) -> Generator[tuple[str, Callable[[], typing.IO]], None, None]:
    """
    Chooses the best transfer strategy (ranged requests per member vs full archive transfer)
    based on the estimated transfer size ratio and yields callables to open each requested member.

    The callables return file-like objects for each member when invoked. Due to lazy evaluation,
    the actual data transfer occurs when the member is opened by the corresponding callable.
    Thus, the callables must be used immediately after being yielded, to avoid issues with temporary
    file lifetimes.

    Example usage:

    >>> for member, opener in s3_archive_open_members(client, archive_url, members):
    ...     with opener() as fh:
    ...         data = fh.read()

    Incorrect usage that may lead to errors due to temporary file cleanup:

    >>> openers = []
    >>> for member, opener in s3_archive_open_members(client, archive_url, members):
    ...     openers.append((member, opener))
    >>> for member, opener in openers:
    ...     with opener() as fh:  # May fail if temporary files have been cleaned up
    ...         data = fh.read()

    :param client: An instance of ``S3Client``.
    :param archive_url: S3 URL to the ZIP archive.
    :param members: List of member names to stream.
    :param mode: File mode for opening members ("r" for text, "rb" for binary).
    :param threshold: If (estimated ranged transfer bytes / archive bytes) <= threshold,
        use ranged per-member access; otherwise download the whole archive.
    :param central_directory_overhead: Estimated bytes added once to the ranged-transfer estimate to
        account for reading the archive's central directory.
    :param member_header_overhead: Estimated per-member overhead (local file header plus padding) added
        to the ranged-transfer estimate and to each member's read window.
    :param use_chunked_reads: If ``True`` and ranged access is chosen, group adjacent members into single ranged reads.

    :return: An iterable of ``(member name, opener)`` pairs; each opener returns a file-like object for
        that member.
    """
    if mode not in ("r", "rb"):
        raise ValueError("mode must be either 'r' or 'rb'")
    if member_header_overhead < ZIPFILE_HEADER_MIN_SIZE:
        raise ValueError(f"member_header_overhead must be at least {ZIPFILE_HEADER_MIN_SIZE} bytes")

    s3_options = s3_options_from_s3_client(client)

    archive_size, member_zip_infos, missed_members = s3_archive_listfile(client, archive_url, members)

    if len(missed_members) > 0:
        raise FileNotFoundError(f"members not found in archive '{missed_members}'")

    estimated_ranged_total_size = central_directory_overhead + sum(info.compress_size + member_header_overhead
                                                                   for info in member_zip_infos)

    # Avoid division by zero; prefer ranged if archive size is zero (degenerate case)
    use_ranged = (estimated_ranged_total_size / archive_size) <= threshold if archive_size > 0 else True

    @dataclasses.dataclass(frozen=True)
    class MemberChunk(object):
        name: str
        header_offset: int
        compress_size: int
        compress_type: int

        @property
        def end(self) -> int:
            return self.header_offset + self.compress_size + member_header_overhead

    if use_ranged and use_chunked_reads:
        # Open archive once to read central directory and gather ZipInfo for requested members.
        # We will group adjacent members (by local header offsets) and issue one ranged read per group,
        # then extract each member from the group's bytes to avoid many small ranged requests.
        with fsspec.open(archive_url, "rb", s3=s3_options) as s3_fh, zipfile.ZipFile(s3_fh) as archive:

            chunks = [MemberChunk(info.filename, info.header_offset, info.compress_size, info.compress_type)
                      for info in member_zip_infos]

            chunks_groups = chunk_between(sorted(chunks, key=lambda x: x.header_offset),
                                          chunk_func=lambda x, y: y.header_offset > x.end + member_header_overhead)

            # For each group, create openers for members inside the group.
            for group in chunks_groups:
                group_offset = head(group).header_offset
                group_size = last(group).end - group_offset

                # Read group's bytes from remote (single ranged read)
                s3_fh.seek(group_offset)
                group_bytes = s3_fh.read(group_size)

                def make_opener(chunk: MemberChunk) -> Callable[[], typing.IO]:

                    def opener() -> typing.IO:
                        index = chunk.header_offset - group_offset

                        if index + ZIPFILE_HEADER_MIN_SIZE > len(group_bytes):
                            raise IOError("unexpected short read of member header")

                        fn_len = int.from_bytes(group_bytes[zipfile_hdr_fn_len_slice(index)], "little")
                        ex_len = int.from_bytes(group_bytes[zipfile_hdr_ex_len_slice(index)], "little")

                        raw_data_begin = index + ZIPFILE_HEADER_MIN_SIZE + fn_len + ex_len
                        raw_data_end = raw_data_begin + chunk.compress_size

                        if raw_data_end > len(group_bytes):
                            raise IOError("unexpected short read of compressed data")

                        raw_data = group_bytes[raw_data_begin:raw_data_end]

                        if chunk.compress_type == zipfile.ZIP_STORED:
                            pass
                        elif chunk.compress_type == zipfile.ZIP_DEFLATED:
                            raw_data = zlib.decompress(raw_data, -zlib.MAX_WBITS)
                        else:
                            raise NotImplementedError(f"unsupported compression '{chunk.compress_type}'")

                        if mode == "r":
                            return io.TextIOWrapper(io.BytesIO(raw_data), encoding="utf-8")
                        else:
                            return io.BytesIO(raw_data)

                    return opener

                yield from ((chunk.name, make_opener(chunk)) for chunk in group)
        return

    if use_ranged:
        for info in member_zip_infos:
            opener = functools.partial(s3_archive_open_member, client, archive_url, info.filename, mode)
            yield info.filename, opener
        return

    # Download full archive once and serve members from it (read member bytes into memory)
    with fsspec.open(archive_url, "rb", s3=s3_options) as s3_fh, tempfile.TemporaryFile() as temp_fh:
        shutil.copyfileobj(s3_fh, temp_fh)
        temp_fh.seek(0)
        with zipfile.ZipFile(temp_fh) as archive:
            for info in member_zip_infos:
                try:
                    if mode == "r":
                        opener = lambda fn=info.filename: io.TextIOWrapper(archive.open(fn), encoding="utf-8")
                    else:
                        opener = lambda fn=info.filename: archive.open(fn)
                    yield info.filename, opener
                except KeyError as e:
                    # Shouldn't happen due to earlier check, but guard anyway
                    raise FileNotFoundError(info.filename) from e