thds.adls 3.0.20250116223841__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of thds.adls might be problematic. Click here for more details.
- thds/adls/__init__.py +15 -0
- thds/adls/_progress.py +193 -0
- thds/adls/_upload.py +127 -0
- thds/adls/abfss.py +24 -0
- thds/adls/cached_up_down.py +48 -0
- thds/adls/conf.py +33 -0
- thds/adls/dbfs.py +60 -0
- thds/adls/defaults.py +26 -0
- thds/adls/download.py +394 -0
- thds/adls/download_lock.py +57 -0
- thds/adls/errors.py +44 -0
- thds/adls/etag.py +6 -0
- thds/adls/file_properties.py +13 -0
- thds/adls/fqn.py +169 -0
- thds/adls/global_client.py +78 -0
- thds/adls/impl.py +1111 -0
- thds/adls/md5.py +60 -0
- thds/adls/meta.json +8 -0
- thds/adls/named_roots.py +26 -0
- thds/adls/py.typed +0 -0
- thds/adls/resource/__init__.py +36 -0
- thds/adls/resource/core.py +79 -0
- thds/adls/resource/file_pointers.py +54 -0
- thds/adls/resource/up_down.py +245 -0
- thds/adls/ro_cache.py +126 -0
- thds/adls/shared_credential.py +107 -0
- thds/adls/source.py +66 -0
- thds/adls/tools/download.py +35 -0
- thds/adls/tools/ls.py +38 -0
- thds/adls/uri.py +38 -0
- thds.adls-3.0.20250116223841.dist-info/METADATA +16 -0
- thds.adls-3.0.20250116223841.dist-info/RECORD +35 -0
- thds.adls-3.0.20250116223841.dist-info/WHEEL +5 -0
- thds.adls-3.0.20250116223841.dist-info/entry_points.txt +3 -0
- thds.adls-3.0.20250116223841.dist-info/top_level.txt +1 -0
thds/adls/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from thds.core import meta
|
|
2
|
+
|
|
3
|
+
from . import abfss, defaults, etag, fqn, named_roots, resource, source, uri # noqa: F401
|
|
4
|
+
from .cached_up_down import download_directory, download_to_cache, upload_through_cache # noqa: F401
|
|
5
|
+
from .errors import BlobNotFoundError # noqa: F401
|
|
6
|
+
from .fqn import * # noqa: F401,F403
|
|
7
|
+
from .global_client import get_global_client, get_global_fs_client # noqa: F401
|
|
8
|
+
from .impl import * # noqa: F401,F403
|
|
9
|
+
from .ro_cache import Cache, global_cache # noqa: F401
|
|
10
|
+
from .uri import UriIsh, parse_any, parse_uri, resolve_any, resolve_uri # noqa: F401
|
|
11
|
+
|
|
12
|
+
# The import name doubles as the base-package identifier.
__basepackage__ = __name__

# Version and build metadata for this package, looked up via thds.core.meta.
__version__ = meta.get_version(__name__)
metadata = meta.read_metadata(__name__)
__commit__ = metadata.git_commit
|
thds/adls/_progress.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""An app-global progress reporter that attempts to reduce the number
|
|
2
|
+
of progress reports using either a time delay or by using fancy
|
|
3
|
+
progress bars.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import typing as ty
|
|
8
|
+
from functools import reduce
|
|
9
|
+
from timeit import default_timer
|
|
10
|
+
|
|
11
|
+
from thds.core import log
|
|
12
|
+
|
|
13
|
+
logger = log.getLogger(__name__)

_1MB = 1 << 20  # bytes per MiB
_UPDATE_INTERVAL_S = 5  # minimum seconds between periodic progress reports

# Carriage-return redraws are disabled whenever the CI environment variable is
# set to a non-empty value: CI does not support carriage returns.
# If we find other cases that don't, we can add them here.
_SUPPORTS_CR = not os.environ.get("CI")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ProgressState(ty.NamedTuple):
    """Immutable snapshot of a single transfer's progress."""

    start: float  # default_timer() reading when tracking of this transfer began
    total: int  # number of bytes expected
    n: int  # number of bytes transferred so far
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _dumb_report_progress(desc: str, state: ProgressState):
    """Emit one plain log line describing ``state``.

    A falsy ``total`` is logged as completion. A falsy ``n`` (nothing
    transferred yet) produces no output at all.
    """
    total = state.total
    if not total:
        logger.info(f"{desc} complete!")
        return

    n_bytes = state.n
    if not n_bytes:
        return  # nothing has moved yet, so there is nothing worth saying.

    elapsed = default_timer() - state.start
    pct = 100 * (n_bytes / total)
    rate_s = f" at {n_bytes/_1MB/elapsed:,.1f} MiB/s"
    logger.info(f"{desc}: {n_bytes:,} / {total:,} bytes ({pct:.1f}%){rate_s} in {elapsed:.1f}s")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _sum_ps(ps: ty.Iterable[ProgressState]) -> ProgressState:
    """Fold many progress states into a single aggregate.

    The aggregate's ``start`` is the earliest start seen (or "now" when ``ps``
    is empty); ``total`` and ``n`` are summed.
    """
    earliest = default_timer()
    total_bytes = 0
    done_bytes = 0
    for progress in ps:
        earliest = min(earliest, progress.start)
        total_bytes += progress.total
        done_bytes += progress.n
    return ProgressState(earliest, total_bytes, done_bytes)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _blobs(n: list) -> str:
|
|
50
|
+
if not n:
|
|
51
|
+
return ""
|
|
52
|
+
return f" {len(n)} blob" + ("" if len(n) == 1 else "s")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _Reporter(ty.Protocol):
    """Structural type for anything that can be handed the active progress states."""

    def __call__(self, states: ty.List[ProgressState]):
        """Report on ``states`` (callers in this module ignore the return value)."""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DumbReporter:
    """Plain-logging reporter that throttles output to one aggregate report per interval.

    Also used as the fallback when tqdm is unavailable.
    """

    def __init__(self, desc: str):
        self._desc = desc
        self._started = default_timer()
        self._last_reported = self._started

    def __call__(self, states: ty.List[ProgressState]):
        now = default_timer()
        # two cases that require a report:
        # 1. it's been a long enough time (update interval) since the last report.
        if now - self._last_reported > _UPDATE_INTERVAL_S:
            # _blobs() supplies its own leading space, so no extra separator is
            # added here (previously this produced a double space).
            _dumb_report_progress(self._desc + _blobs(states), _sum_ps(states))
            self._last_reported = now
        # 2. a download finished _and_ that specific download took longer overall than our update interval.
        else:
            for state in states:
                if (
                    state.total
                    and state.n >= state.total  # download finished
                    and (now - state.start) > _UPDATE_INTERVAL_S  # and it took a while
                ):
                    # report individually for each download that finished.
                    _dumb_report_progress(self._desc + _blobs([state]), state)
                    # notably, we do not delay the next 'standard' report because of downloads finishing.
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class TqdmReporter:
    """Show aggregate progress of all active transfers as a single tqdm bar.

    Falls back to DumbReporter if tqdm is not installed.
    """

    def __init__(self, desc: str):
        self._desc = desc
        self._bar = None
        self._dumb = DumbReporter(desc)

    def __call__(self, states: ty.List[ProgressState]):
        # Only the import is guarded, so unrelated ModuleNotFoundErrors from
        # inside tqdm calls are no longer silently turned into dumb reporting.
        try:
            from tqdm import tqdm  # type: ignore
        except ModuleNotFoundError:
            self._dumb(states)
            return

        bar = self._bar
        state = _sum_ps(states)
        if not bar and state.total > 0:
            bar = tqdm(
                total=state.total,
                delay=_UPDATE_INTERVAL_S,
                mininterval=_UPDATE_INTERVAL_S,
                initial=state.n,
                unit="byte",
                unit_scale=True,
            )  # type: ignore
        if bar:
            # if there are zero active states (which is possible),
            # n and total will be zero after sum, and we don't
            # want to set zeros on an existing non-zero bar.
            # This must be an assignment: the total grows as transfers are added,
            # otherwise the bar closes as soon as n passes the first transfer's size.
            bar.total = state.total or bar.total
            new_n = state.n or bar.n
            bar.update(new_n - bar.n)
            bar.desc = f"{self._desc}{_blobs(states)}"
            if _SUPPORTS_CR:
                bar.refresh()

            if bar.n >= bar.total:
                bar.close()
                bar = None

        self._bar = bar
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class Tracker:
    """Track per-key transfer progress and forward every change to a reporter.

    A key is dropped once its transferred byte count reaches its total.
    """

    def __init__(self, reporter: _Reporter):
        self._progresses: ty.Dict[str, ProgressState] = dict()
        self._reporter = reporter

    def add(self, key: str, total: int) -> ty.Tuple["Tracker", str]:
        """Start (or restart) tracking ``key`` with ``total`` expected bytes.

        Negative totals are clamped to 0. Returns ``(self, key)``.
        """
        self._progresses[key] = ProgressState(default_timer(), max(total, 0), 0)
        self._reporter(list(self._progresses.values()))
        return self, key

    def __call__(self, key: str, written: int):
        """Record ``written`` more bytes for ``key`` and notify the reporter.

        Unknown keys (never added, or already finished and removed) just
        trigger a report of the current states.
        """
        assert written >= 0, f"cannot write negative bytes: {written}"
        current = self._progresses.get(key)
        if current is None:
            self._reporter(list(self._progresses.values()))
            return

        start, total, n = current
        updated = ProgressState(start, total, n + written)
        self._progresses[key] = updated
        self._reporter(list(self._progresses.values()))
        if updated.n >= total:
            # finished: stop tracking so it no longer contributes to aggregates.
            del self._progresses[key]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
T = ty.TypeVar("T", bound=ty.IO)

# Process-wide trackers shared by every proxied stream: downloads are tracked
# through wrapped ``write`` calls, uploads through wrapped ``read`` calls.
_GLOBAL_DN_TRACKER = Tracker(TqdmReporter("thds.adls downloading"))
_GLOBAL_UP_TRACKER = Tracker(TqdmReporter("thds.adls uploading"))
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _proxy_io(io_type: str, stream: T, key: str, total_len: int) -> T:
    """Replace ``stream.read`` or ``stream.write`` with a progress-reporting wrapper.

    ``"read"`` reports to the global upload tracker and ``"write"`` to the
    global download tracker. A falsy ``total_len`` falls back to
    ``len(stream)``. If the stream lacks the method, or that fallback fails,
    the stream is returned unchanged.

    Returns the same ``stream`` object, patched in place.
    """
    assert io_type in ("read", "write"), io_type

    try:
        old_io = getattr(stream, io_type)
        total_len = total_len or len(stream)  # type: ignore
    except (AttributeError, TypeError):
        return stream

    if io_type == "read":
        tracker, _ = _GLOBAL_UP_TRACKER.add(key, total_len)
    else:
        tracker, _ = _GLOBAL_DN_TRACKER.add(key, total_len)

    def _n_bytes(arg: ty.Any, result: ty.Any) -> int:
        """How many bytes one wrapped call actually moved."""
        if io_type == "read" and isinstance(result, (bytes, bytearray, memoryview, str)):
            return len(result)  # actual amount read; may be short at EOF
        if arg is None or (isinstance(arg, int) and arg < 0):
            return total_len  # read() / read(None) / read(-1) mean "everything"
        if isinstance(arg, int):
            return arg
        return len(arg)  # written buffer: bytes, bytearray, memoryview, str, ...

    def io(data_or_len: ty.Union[bytes, int, None] = -1):
        r = old_io(data_or_len)
        tracker(key, _n_bytes(data_or_len, r))
        return r

    setattr(stream, io_type, io)
    return stream
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def report_download_progress(stream: T, key: str, total: int = 0) -> T:
    """Wrap ``stream.write`` so that download progress is reported under ``key``.

    Without a known ``total`` there is nothing to measure progress against,
    so the stream is handed back untouched.
    """
    return _proxy_io("write", stream, key, total) if total else stream
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def report_upload_progress(stream: T, key: str, total: int = 0) -> T:
    """Wrap ``stream.read`` so that upload progress is reported under ``key``.

    Unlike downloads, a zero ``total`` is passed through; _proxy_io then
    tries ``len(stream)``.
    """
    io_kind = "read"
    return _proxy_io(io_kind, stream, key, total)
|
thds/adls/_upload.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Just utilities for deciding whether or not to upload.
|
|
2
|
+
|
|
3
|
+
Not an officially-published API of the thds.adls library.
|
|
4
|
+
"""
|
|
5
|
+
import typing as ty
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import azure.core.exceptions
|
|
9
|
+
from azure.storage.blob import ContentSettings
|
|
10
|
+
|
|
11
|
+
from thds.core import hostname, log
|
|
12
|
+
|
|
13
|
+
from .md5 import AnyStrSrc, try_md5
|
|
14
|
+
|
|
15
|
+
# Data at or above this many bytes is eligible for a remote-checksum check
# before uploading; anything smaller is simply uploaded.
_SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES = 2 << 20  # 2 MB is about right


logger = log.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _get_checksum_content_settings(data: AnyStrSrc) -> ty.Optional[ContentSettings]:
    """Build ContentSettings carrying the MD5 of ``data``, or return None.

    Ideally every upload carries an MD5. It cannot be computed only when the
    stream does not exist in its entirety before the upload begins, in which
    case ``try_md5`` gives back a falsy value and None is returned.
    """
    checksum = try_md5(data)
    return ContentSettings(content_md5=checksum) if checksum else None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _too_small_to_skip_upload(data: AnyStrSrc, min_size_for_remote_check: int) -> bool:
    """True when ``data`` is smaller than ``min_size_for_remote_check``.

    Existing Paths are measured with ``stat``; everything else with ``len``.
    If the size cannot be determined, it is treated as large (not too small).
    """
    if isinstance(data, Path) and data.exists():
        size = data.stat().st_size
    else:
        try:
            size = len(data)  # type: ignore
        except TypeError as te:
            logger.debug(f"failed to get length? {repr(te)} for {data}")
            size = min_size_for_remote_check + 1
    return size < min_size_for_remote_check
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class UploadDecision(ty.NamedTuple):
    """Result of the pre-upload check."""

    # False only when the remote already has a matching MD5.
    upload_required: bool
    # MD5-bearing settings to upload with; None if no local MD5 could be computed.
    content_settings: ty.Optional[ContentSettings]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Properties(ty.Protocol):
    """The subset of remote properties the upload check relies on."""

    name: str  # appears in the "already exists" log message
    content_settings: ContentSettings  # its content_md5 is compared with the local MD5
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
    data: AnyStrSrc, min_size_for_remote_check: int
) -> ty.Generator[bool, ty.Optional[Properties], UploadDecision]:
    """Driver-agnostic coroutine deciding whether ``data`` must be uploaded.

    Yields ``True`` at most once, and only when it needs the remote
    properties; the driver then sends the Properties (or None if the remote
    is absent). The UploadDecision is the generator's return value.
    """
    local = _get_checksum_content_settings(data)
    if not local:
        return UploadDecision(True, None)
    if _too_small_to_skip_upload(data, min_size_for_remote_check):
        logger.debug("Too small to bother with an early call - let's just upload...")
        return UploadDecision(True, local)

    remote = yield True  # ask the driver for the remote properties
    if not remote:
        logger.debug("No remote properties could be fetched so an upload is required")
    elif remote.content_settings.content_md5 == local.content_md5:
        logger.info(f"Remote file {remote.name} already exists and has matching checksum")
        return UploadDecision(False, local)
    else:
        logger.debug("Remote file exists but MD5 does not match - upload required.")
    return UploadDecision(True, local)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
doc = """
|
|
77
|
+
Returns False for upload_required if the file is large and the remote
|
|
78
|
+
exists and has a known, matching checksum.
|
|
79
|
+
|
|
80
|
+
Returns ContentSettings if an MD5 checksum can be calculated.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def async_upload_decision_and_settings(
    get_properties: ty.Callable[[], ty.Awaitable[Properties]],
    data: AnyStrSrc,
    min_size_for_remote_check: int = _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES,
) -> UploadDecision:
    # Async driver for the decision coroutine. ``get_properties`` is awaited only
    # if the coroutine asks for it; a missing remote (ResourceNotFoundError) is
    # passed to the coroutine as None.
    decider = _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
        data, min_size_for_remote_check
    )
    try:
        while True:
            decider.send(None)  # run until it requests properties, or finishes
            try:
                remote = await get_properties()
            except azure.core.exceptions.ResourceNotFoundError:
                remote = None
            decider.send(remote)
    except StopIteration as finished:
        return finished.value
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def upload_decision_and_settings(
    get_properties: ty.Callable[[], Properties],
    data: AnyStrSrc,
    min_size_for_remote_check: int = _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES,
) -> UploadDecision:
    # Synchronous driver for the decision coroutine. ``get_properties`` is called
    # only if the coroutine asks for it; a missing remote (ResourceNotFoundError)
    # is passed to the coroutine as None.
    decider = _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
        data, min_size_for_remote_check
    )
    try:
        while True:
            decider.send(None)  # run until it requests properties, or finishes
            try:
                remote = get_properties()
            except azure.core.exceptions.ResourceNotFoundError:
                remote = None
            decider.send(remote)
    except StopIteration as finished:
        return finished.value
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Both variants share one docstring so the sync and async contracts stay in step.
for _decision_fn in (async_upload_decision_and_settings, upload_decision_and_settings):
    _decision_fn.__doc__ = doc
del _decision_fn
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def metadata_for_upload() -> ty.Dict[str, str]:
    """Describe who performed an upload: the wrapping library and the host name."""
    return dict(
        upload_wrapper_sw="thds.adls",
        upload_hostname=hostname.friendly(),
    )
|
thds/adls/abfss.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Translate ADLS URIs to ABFSS URIs (for use with Spark/Hadoop)."""
|
|
2
|
+
from .fqn import AdlsFqn
|
|
3
|
+
|
|
4
|
+
# URI prefix (scheme plus "//") used by Spark/Hadoop ABFSS paths.
ABFSS_SCHEME = "abfss://"
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NotAbfssUri(ValueError):
    """Raised when a string cannot be interpreted as an ABFSS URI."""
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def from_adls_fqn(fqn: AdlsFqn) -> str:
    """Render ``fqn`` as ``abfss://<container>@<sa>.dfs.core.windows.net/<path>``."""
    host = f"{fqn.sa}.dfs.core.windows.net"
    relative_path = fqn.path.lstrip("/")
    return f"{ABFSS_SCHEME}{fqn.container}@{host}/{relative_path}"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def from_adls_uri(uri: str) -> str:
    """Convert an ADLS URI string to its ABFSS equivalent."""
    parsed = AdlsFqn.parse(uri)
    return from_adls_fqn(parsed)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def to_adls_fqn(abfss_uri: str) -> AdlsFqn:
    """Parse ``abfss://<container>@<sa>.dfs.core.windows.net/<path>`` into an AdlsFqn.

    Raises:
        NotAbfssUri: if the URI does not start with the abfss scheme, or is
            missing the ``@`` or ``.dfs.core.windows.net/`` separators.
    """
    if not abfss_uri.startswith(ABFSS_SCHEME):
        raise NotAbfssUri(f"URI does not start with {ABFSS_SCHEME!r}: {abfss_uri!r}")
    # partition (not split) so malformed URIs get a clear NotAbfssUri, and a
    # path that happens to contain the host suffix is kept intact.
    container, at, rest = abfss_uri[len(ABFSS_SCHEME) :].partition("@")
    if not at:
        raise NotAbfssUri(f"URI is missing '<container>@': {abfss_uri!r}")
    sa, host_sep, path = rest.partition(".dfs.core.windows.net/")
    if not host_sep:
        raise NotAbfssUri(f"URI is missing '.dfs.core.windows.net/': {abfss_uri!r}")
    return AdlsFqn.of(sa, container, path)
|
|
thds/adls/cached_up_down.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from .download import download_or_use_verified
|
|
4
|
+
from .fqn import AdlsFqn
|
|
5
|
+
from .global_client import get_global_fs_client
|
|
6
|
+
from .impl import ADLSFileSystem
|
|
7
|
+
from .resource.up_down import AdlsHashedResource, upload
|
|
8
|
+
from .ro_cache import global_cache
|
|
9
|
+
from .uri import UriIsh, parse_any
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def download_to_cache(fqn_or_uri: UriIsh, md5b64: str = "") -> Path:
    """Download straight into the global cache and return the read-only cached Path.

    A file may be downloaded 'into' the cache even when neither the caller
    nor the remote file properties supply an MD5. However, future attempts
    to reuse that cached copy will force a re-download unless an MD5 is
    available at that time.
    """
    fqn = parse_any(fqn_or_uri)
    cache_path = global_cache().path(fqn)
    fs_client = get_global_fs_client(fqn.sa, fqn.container)
    download_or_use_verified(fs_client, fqn.path, cache_path, md5b64, cache=global_cache())
    return cache_path
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def upload_through_cache(dest: UriIsh, src_path: Path) -> AdlsHashedResource:
    """Upload a local file to ``dest``, writing it through the global cache.

    An AdlsHashedResource is returned because an upload through the cache
    must, by definition, have a known checksum. The global client is used,
    which is pretty much always what you want.
    """
    assert src_path.is_file(), "src_path must be a file."
    hashed = upload(dest, src_path, write_through_cache=global_cache())
    assert hashed, "MD5 should always be calculable for a local path."
    return hashed
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def download_directory(fqn: AdlsFqn) -> Path:
    """Fetch the directory at ``fqn`` into the global cache and return its local root.

    If you know you only need a single file, use download_to_cache instead.
    """
    file_system = ADLSFileSystem(fqn.sa, fqn.container)
    local_root = global_cache().path(fqn)
    file_system.fetch_directory(fqn.path, local_root)
    assert local_root.is_dir(), "Directory should have been downloaded to the cache."
    return local_root
|
thds/adls/conf.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""This is where fine-tuning environment variables are defined."""
|
|
2
|
+
from thds.core import config
|
|
3
|
+
|
|
4
|
+
# These defaults were tested to perform well (~200 MB/sec) on a 2 core
# machine on Kubernetes. Larger numbers did not do any better, but
# these numbers did roughly 4x as well as the defaults, which are
# concurrency=1 and chunk_get_size=32 MB.
#
# As always, your mileage may vary.
#
# For more info, see docs at
# https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blob-download-python#specify-data-transfer-options-on-download
#
# Also see
# azure.storage.filedatalake._shared.base_client.create_configuration
# for actual details...
#
# Each item below is registered with thds.core.config under the given name;
# the second argument is its default and ``parse`` converts configured values.
DOWNLOAD_FILE_MAX_CONCURRENCY = config.item("download_file_max_concurrency", 4, parse=int)
MAX_CHUNK_GET_SIZE = config.item("max_chunk_get_size", 2**20 * 64, parse=int)  # 64MB
# The parse function clamps configured values so they are never below MAX_CHUNK_GET_SIZE().
MAX_SINGLE_GET_SIZE = config.item(
    "max_single_get_size", 2**20 * 64, parse=lambda i: max(MAX_CHUNK_GET_SIZE(), int(i))
)  # 64MB
# NOTE(review): this upload setting is clamped against MAX_CHUNK_GET_SIZE (a
# download setting) rather than MAX_BLOCK_SIZE; possibly copied from
# MAX_SINGLE_GET_SIZE above — confirm the intended lower bound.
MAX_SINGLE_PUT_SIZE = config.item(
    "max_single_put_size", 2**20 * 64, parse=lambda i: max(MAX_CHUNK_GET_SIZE(), int(i))
)  # 64MB

# these are for upload
# these achieved 380 MB/sec on a 2 core machine on Kubernetes
MAX_BLOCK_SIZE = config.item("max_block_put_size", 2**20 * 64, parse=int)  # 64 MB
UPLOAD_FILE_MAX_CONCURRENCY = config.item("upload_file_max_concurrency", 10, parse=int)
UPLOAD_CHUNK_SIZE = config.item("upload_chunk_size", 2**20 * 100, parse=int)  # 100 MB

CONNECTION_TIMEOUT = config.item("connection_timeout", 2000, parse=int)  # seconds
|
thds/adls/dbfs.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import typing as ty
|
|
2
|
+
|
|
3
|
+
from .fqn import AdlsFqn, AdlsRoot, join, parse_fqn
|
|
4
|
+
|
|
5
|
+
DBFS_SCHEME = "dbfs:/"

# ADLS root URI -> Spark mount path, for every mounted storage root.
ADLS_TO_SPARK_MAPPING = {
    "adls://uaapdatascience/data/": "/mnt/datascience/data/",
    "adls://thdsdatasets/prod-datasets/": "/mnt/datascience/datasets/",
    "adls://uaapdatascience/hive/": "/mnt/datascience/hive/",
    "adls://thdsscratch/tmp/": "/mnt/datascience/scratch/",
}
ADLS_TO_DBFS_MAPPING = dict(
    (adls_root, join(DBFS_SCHEME, mount)) for adls_root, mount in ADLS_TO_SPARK_MAPPING.items()
)
# Reverse lookups. The Spark flavour carries no scheme because
# Spark read/write implicitly adds a 'dbfs:/' prefix.
SPARK_TO_ADLS_MAPPING = dict((mount, adls_root) for adls_root, mount in ADLS_TO_SPARK_MAPPING.items())
DBFS_TO_ADLS_MAPPING = dict(
    (join(DBFS_SCHEME, mount), adls_root) for mount, adls_root in SPARK_TO_ADLS_MAPPING.items()
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def to_adls_root(root_uri: str) -> AdlsRoot:
    """Map an exact mount root (Spark path or ``dbfs:/`` URI) to its AdlsRoot.

    Raises:
        ValueError: if ``root_uri`` is not one of the known mount roots.
    """
    mapping = DBFS_TO_ADLS_MAPPING if root_uri.startswith(DBFS_SCHEME) else SPARK_TO_ADLS_MAPPING
    # Only the lookup is guarded, so errors from AdlsRoot.parse are not
    # misreported as a missing mapping.
    try:
        adls_root_uri = mapping[root_uri]
    except KeyError:
        raise ValueError(f"URI '{root_uri}' does not have a defined ADLS root!") from None
    return AdlsRoot.parse(adls_root_uri)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def to_adls_fqn(fully_qualified_name: str) -> AdlsFqn:
    """Translate a Spark or ``dbfs:/`` path under a known mount into an AdlsFqn.

    Raises:
        ValueError: if the path is not under any known mount root.
    """
    mapping = (
        DBFS_TO_ADLS_MAPPING if fully_qualified_name.startswith(DBFS_SCHEME) else SPARK_TO_ADLS_MAPPING
    )

    match = next(
        ((k, v) for k, v in mapping.items() if fully_qualified_name.startswith(k)),
        None,
    )
    if match is None:
        raise ValueError(f"{fully_qualified_name} does not have a defined ADLS path!")
    dbfs_root, adls_root = match

    # Strip the matched prefix by length; split(dbfs_root)[1] would truncate the
    # remainder if the root string happened to occur again later in the path.
    return parse_fqn(join(adls_root, fully_qualified_name[len(dbfs_root) :]))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def to_uri(adls_path: ty.Union[AdlsRoot, AdlsFqn], spark: bool = True) -> str:
    """Translate an ADLS root or FQN into its mounted path.

    ``spark=True`` (the default) uses the Spark mount mapping; otherwise the
    dbfs-prefixed mapping is used.

    Raises:
        ValueError: if the root has no known mount.
    """

    def get_root_uri(adls_root: AdlsRoot) -> str:
        table = ADLS_TO_SPARK_MAPPING if spark else ADLS_TO_DBFS_MAPPING
        try:
            return table[str(adls_root)]
        except KeyError:
            raise ValueError(f"{str(adls_root)} does not have a corresponding dbfs root!")

    if not isinstance(adls_path, AdlsRoot):
        try:
            return join(get_root_uri(adls_path.root()), adls_path.path)
        except ValueError:
            raise ValueError(f"{str(adls_path)} does not have a corresponding dbfs path!")
    return get_root_uri(adls_path)
|
thds/adls/defaults.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Prefer using named_roots for new code."""
|
|
2
|
+
|
|
3
|
+
from thds.core.env import Env
|
|
4
|
+
|
|
5
|
+
from . import fqn, named_roots
|
|
6
|
+
|
|
7
|
+
# Optional internal module: if it is not installed, the import is simply skipped.
# NOTE(review): presumably imported for its side effects (e.g. registering
# named roots) — confirm in thds.adls._thds_defaults.
try:
    import thds.adls._thds_defaults  # noqa: F401
except ImportError:
    pass
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def env_root(env: Env = "") -> fqn.AdlsRoot:
    """Return the named ADLS root for ``env``, via ``named_roots.require``.

    Calling with no arguments is often what you want: it defaults to using
    the THDS_ENV environment variable.
    """
    root = named_roots.require(env)
    return root
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def env_root_uri(env: Env = "") -> str:
    """String form of ``env_root(env)``."""
    root = env_root(env)
    return str(root)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def mops_root() -> str:
    """URI of the "mops" named root, where mops materialization should be put."""
    mops = named_roots.require("mops")
    return str(mops)
|