thds.adls 3.0.20250116223841__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


thds/adls/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from thds.core import meta
+
+ from . import abfss, defaults, etag, fqn, named_roots, resource, source, uri  # noqa: F401
+ from .cached_up_down import download_directory, download_to_cache, upload_through_cache  # noqa: F401
+ from .errors import BlobNotFoundError  # noqa: F401
+ from .fqn import *  # noqa: F401,F403
+ from .global_client import get_global_client, get_global_fs_client  # noqa: F401
+ from .impl import *  # noqa: F401,F403
+ from .ro_cache import Cache, global_cache  # noqa: F401
+ from .uri import UriIsh, parse_any, parse_uri, resolve_any, resolve_uri  # noqa: F401
+
+ __version__ = meta.get_version(__name__)
+ metadata = meta.read_metadata(__name__)
+ __basepackage__ = __name__
+ __commit__ = metadata.git_commit
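
A minimal sketch of the public surface re-exported above, assuming adls://{account}/{container}/{path} URI strings (the form used by the mappings in thds/adls/dbfs.py below); the account, container, and blob names are hypothetical:

    from thds import adls

    # Hypothetical URI; parse_any and get_global_fs_client are the same helpers
    # that cached_up_down.py uses internally.
    fqn = adls.parse_any("adls://myaccount/mycontainer/some/blob.txt")
    fs_client = adls.get_global_fs_client(fqn.sa, fqn.container)
    print(adls.__version__)
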
thds/adls/_progress.py ADDED
@@ -0,0 +1,193 @@
+ """An app-global progress reporter that attempts to reduce the number
+ of progress reports using either a time delay or by using fancy
+ progress bars.
+ """
+
+ import os
+ import typing as ty
+ from functools import reduce
+ from timeit import default_timer
+
+ from thds.core import log
+
+ logger = log.getLogger(__name__)
+ _1MB = 2**20
+ _UPDATE_INTERVAL_S = 5
+ _SUPPORTS_CR = not bool(os.getenv("CI"))
+ # CI does not support carriage returns.
+ # if we find other cases that don't, we can add them here.
+
+
+ class ProgressState(ty.NamedTuple):
+     start: float
+     total: int
+     n: int
+
+
+ def _dumb_report_progress(desc: str, state: ProgressState):
+     if not state.total:
+         logger.info(f"{desc} complete!")
+         return
+     if not state.n:
+         return  # don't report when nothing has happened yet.
+
+     start, total, n_bytes = state
+     pct = 100 * (n_bytes / total)
+     elapsed = default_timer() - start
+     rate_s = f" at {n_bytes/_1MB/elapsed:,.1f} MiB/s"
+     logger.info(f"{desc}: {n_bytes:,} / {total:,} bytes ({pct:.1f}%){rate_s} in {elapsed:.1f}s")
+
+
+ def _sum_ps(ps: ty.Iterable[ProgressState]) -> ProgressState:
+     return reduce(
+         lambda x, y: ProgressState(min(x.start, y.start), x.total + y.total, y.n + x.n),
+         ps,
+         ProgressState(default_timer(), 0, 0),
+     )
+
+
+ def _blobs(n: list) -> str:
+     if not n:
+         return ""
+     return f" {len(n)} blob" + ("" if len(n) == 1 else "s")
+
+
+ class _Reporter(ty.Protocol):
+     def __call__(self, states: ty.List[ProgressState]):
+         ...
+
+
+ class DumbReporter:
+     def __init__(self, desc: str):
+         self._desc = desc
+         self._started = default_timer()
+         self._last_reported = self._started
+
+     def __call__(self, states: ty.List[ProgressState]):
+         now = default_timer()
+         # two cases that require a report:
+         # 1. it's been a long enough time (update interval) since the last report.
+         if now - self._last_reported > _UPDATE_INTERVAL_S:
+             _dumb_report_progress(self._desc + f" {_blobs(states)}", _sum_ps(states))
+             self._last_reported = now
+         # 2. a download finished _and_ that specific download took longer overall than our update interval.
+         else:
+             for state in states:
+                 if (
+                     state.total
+                     and state.n >= state.total  # download finished
+                     and (now - state.start) > _UPDATE_INTERVAL_S  # and it took a while
+                 ):
+                     # report individually for each download that finished.
+                     _dumb_report_progress(self._desc + f" {_blobs([state])}", state)
+                     # notably, we do not delay the next 'standard' report because of downloads finishing.
+
+
+ class TqdmReporter:
+     """Falls back to DumbReporter if tqdm is not installed."""
+
+     def __init__(self, desc: str):
+         self._desc = desc
+         self._bar = None
+         self._dumb = DumbReporter(desc)
+
+     def __call__(self, states: ty.List[ProgressState]):
+         try:
+             from tqdm import tqdm  # type: ignore
+
+             bar = self._bar
+             state = _sum_ps(states)
+             if not bar and state.total > 0:
+                 bar = tqdm(
+                     total=state.total,
+                     delay=_UPDATE_INTERVAL_S,
+                     mininterval=_UPDATE_INTERVAL_S,
+                     initial=state.n,
+                     unit="byte",
+                     unit_scale=True,
+                 )  # type: ignore
+             if bar:
+                 # if there are zero active states (which is possible),
+                 # n and total will be zero after sum, and we don't
+                 # want to set zeros on an existing non-zero bar.
+                 bar.total = state.total or bar.total
+                 new_n = state.n or bar.n
+                 bar.update(new_n - bar.n)
+                 bar.desc = f"{self._desc}{_blobs(states)}"
+                 if _SUPPORTS_CR:
+                     bar.refresh()
+
+                 if bar.n >= bar.total:
+                     bar.close()
+                     bar = None
+
+             self._bar = bar
+         except ModuleNotFoundError:
+             self._dumb(states)
+
+
+ class Tracker:
+     def __init__(self, reporter: _Reporter):
+         self._progresses: ty.Dict[str, ProgressState] = dict()
+         self._reporter = reporter
+
+     def add(self, key: str, total: int) -> ty.Tuple["Tracker", str]:
+         if total < 0:
+             total = 0
+         self._progresses[key] = ProgressState(default_timer(), total, 0)
+         self._reporter(list(self._progresses.values()))
+         return self, key
+
+     def __call__(self, key: str, written: int):
+         assert written >= 0, f"cannot write negative bytes: {written}"
+         try:
+             start, total, n = self._progresses[key]
+             self._progresses[key] = ProgressState(start, total, n + written)
+             self._reporter(list(self._progresses.values()))
+             if self._progresses[key].n >= total:
+                 del self._progresses[key]
+         except KeyError:
+             self._reporter(list(self._progresses.values()))
+
+
+ _GLOBAL_DN_TRACKER = Tracker(TqdmReporter("thds.adls downloading"))
+ _GLOBAL_UP_TRACKER = Tracker(TqdmReporter("thds.adls uploading"))
+ T = ty.TypeVar("T", bound=ty.IO)
+
+
+ def _proxy_io(io_type: str, stream: T, key: str, total_len: int) -> T:
+     assert io_type in ("read", "write"), io_type
+
+     try:
+         old_io = getattr(stream, io_type)
+         total_len = total_len or len(stream)  # type: ignore
+     except (AttributeError, TypeError):
+         return stream
+
+     if io_type == "read":
+         tracker, _ = _GLOBAL_UP_TRACKER.add(key, total_len)
+     else:
+         tracker, _ = _GLOBAL_DN_TRACKER.add(key, total_len)
+
+     def io(data_or_len: ty.Union[bytes, int]):
+         r = old_io(data_or_len)
+         io_len = (
+             total_len
+             if data_or_len == -1
+             else (len(data_or_len) if isinstance(data_or_len, bytes) else data_or_len)
+         )
+         tracker(key, io_len)
+         return r
+
+     setattr(stream, io_type, io)
+     return stream
+
+
+ def report_download_progress(stream: T, key: str, total: int = 0) -> T:
+     if not total:  # if we don't know how big a download is, we can't report progress.
+         return stream
+     return _proxy_io("write", stream, key, total)
+
+
+ def report_upload_progress(stream: T, key: str, total: int = 0) -> T:
+     return _proxy_io("read", stream, key, total)
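
The two public helpers wrap a stream's read or write method and route byte counts through the module-global trackers. A minimal sketch of wrapping a download sink, assuming the sink allows its write attribute to be replaced (a plain io.BytesIO does); the key and size are hypothetical:

    import io

    from thds.adls._progress import report_download_progress

    sink = io.BytesIO()
    expected_size = 10 * 2**20  # a known total is required to report download progress

    wrapped = report_download_progress(sink, key="mycontainer/some/blob.bin", total=expected_size)
    wrapped.write(b"\x00" * 2**20)  # each write is reported to the global download tracker
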
thds/adls/_upload.py ADDED
@@ -0,0 +1,127 @@
+ """Just utilities for deciding whether or not to upload.
+
+ Not an officially-published API of the thds.adls library.
+ """
+ import typing as ty
+ from pathlib import Path
+
+ import azure.core.exceptions
+ from azure.storage.blob import ContentSettings
+
+ from thds.core import hostname, log
+
+ from .md5 import AnyStrSrc, try_md5
+
+ _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES = 2 * 2**20  # 2 MB is about right
+
+
+ logger = log.getLogger(__name__)
+
+
+ def _get_checksum_content_settings(data: AnyStrSrc) -> ty.Optional[ContentSettings]:
+     """Ideally, we calculate an MD5 sum for all data that we upload.
+
+     The only circumstance under which we cannot do this is if the
+     stream does not exist in its entirety before the upload begins.
+     """
+     md5 = try_md5(data)
+     if md5:
+         return ContentSettings(content_md5=md5)
+     return None
+
+
+ def _too_small_to_skip_upload(data: AnyStrSrc, min_size_for_remote_check: int) -> bool:
+     def _len() -> int:
+         if isinstance(data, Path) and data.exists():
+             return data.stat().st_size
+         try:
+             return len(data)  # type: ignore
+         except TypeError as te:
+             logger.debug(f"failed to get length? {repr(te)} for {data}")
+             return min_size_for_remote_check + 1
+
+     return _len() < min_size_for_remote_check
+
+
+ class UploadDecision(ty.NamedTuple):
+     upload_required: bool
+     content_settings: ty.Optional[ContentSettings]
+
+
+ class Properties(ty.Protocol):
+     name: str
+     content_settings: ContentSettings
+
+
+ def _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
+     data: AnyStrSrc, min_size_for_remote_check: int
+ ) -> ty.Generator[bool, ty.Optional[Properties], UploadDecision]:
+     local_content_settings = _get_checksum_content_settings(data)
+     if not local_content_settings:
+         return UploadDecision(True, None)
+     if _too_small_to_skip_upload(data, min_size_for_remote_check):
+         logger.debug("Too small to bother with an early call - let's just upload...")
+         return UploadDecision(True, local_content_settings)
+     remote_properties = yield True
+     if not remote_properties:
+         logger.debug("No remote properties could be fetched so an upload is required")
+         return UploadDecision(True, local_content_settings)
+     if remote_properties.content_settings.content_md5 == local_content_settings.content_md5:
+         logger.info(f"Remote file {remote_properties.name} already exists and has matching checksum")
+         return UploadDecision(False, local_content_settings)
+     logger.debug("Remote file exists but MD5 does not match - upload required.")
+     return UploadDecision(True, local_content_settings)
+
+
+ doc = """
+ Returns False for upload_required if the file is large and the remote
+ exists and has a known, matching checksum.
+
+ Returns ContentSettings if an MD5 checksum can be calculated.
+ """
+
+
+ async def async_upload_decision_and_settings(
+     get_properties: ty.Callable[[], ty.Awaitable[Properties]],
+     data: AnyStrSrc,
+     min_size_for_remote_check: int = _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES,
+ ) -> UploadDecision:
+     try:
+         co = _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
+             data, min_size_for_remote_check
+         )
+         while True:
+             co.send(None)
+             try:
+                 co.send(await get_properties())
+             except azure.core.exceptions.ResourceNotFoundError:
+                 co.send(None)
+     except StopIteration as stop:
+         return stop.value
+
+
+ def upload_decision_and_settings(
+     get_properties: ty.Callable[[], Properties],
+     data: AnyStrSrc,
+     min_size_for_remote_check: int = _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES,
+ ) -> UploadDecision:
+     try:
+         co = _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
+             data, min_size_for_remote_check
+         )
+         while True:
+             co.send(None)
+             try:
+                 co.send(get_properties())
+             except azure.core.exceptions.ResourceNotFoundError:
+                 co.send(None)
+     except StopIteration as stop:
+         return stop.value
+
+
+ async_upload_decision_and_settings.__doc__ = doc
+ upload_decision_and_settings.__doc__ = doc
+
+
+ def metadata_for_upload() -> ty.Dict[str, str]:
+     return {"upload_wrapper_sw": "thds.adls", "upload_hostname": hostname.friendly()}
thds/adls/abfss.py ADDED
@@ -0,0 +1,24 @@
+ """Translate ADLS URIs to ABFSS URIs (for use with Spark/Hadoop)."""
+ from .fqn import AdlsFqn
+
+ ABFSS_SCHEME = "abfss://"
+
+
+ class NotAbfssUri(ValueError):
+     pass
+
+
+ def from_adls_fqn(fqn: AdlsFqn) -> str:
+     return f"{ABFSS_SCHEME}{fqn.container}@{fqn.sa}.dfs.core.windows.net/{fqn.path.lstrip('/')}"
+
+
+ def from_adls_uri(uri: str) -> str:
+     return from_adls_fqn(AdlsFqn.parse(uri))
+
+
+ def to_adls_fqn(abfss_uri: str) -> AdlsFqn:
+     if not abfss_uri.startswith(ABFSS_SCHEME):
+         raise NotAbfssUri(f"URI does not start with {ABFSS_SCHEME!r}: {abfss_uri!r}")
+     container, rest = abfss_uri[len(ABFSS_SCHEME) :].split("@", 1)
+     sa, path = rest.split(".dfs.core.windows.net/")
+     return AdlsFqn.of(sa, container, path)
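
A quick round-trip sketch; the storage account, container, and path are hypothetical, and AdlsFqn.of is assumed to take (storage_account, container, path), as to_adls_fqn above suggests:

    from thds.adls import abfss
    from thds.adls.fqn import AdlsFqn

    fqn = AdlsFqn.of("myaccount", "mycontainer", "raw/2025/data.parquet")
    spark_uri = abfss.from_adls_fqn(fqn)
    # -> "abfss://mycontainer@myaccount.dfs.core.windows.net/raw/2025/data.parquet"
    back = abfss.to_adls_fqn(spark_uri)  # should equal fqn for simple paths
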
thds/adls/cached_up_down.py ADDED
@@ -0,0 +1,48 @@
+ from pathlib import Path
+
+ from .download import download_or_use_verified
+ from .fqn import AdlsFqn
+ from .global_client import get_global_fs_client
+ from .impl import ADLSFileSystem
+ from .resource.up_down import AdlsHashedResource, upload
+ from .ro_cache import global_cache
+ from .uri import UriIsh, parse_any
+
+
+ def download_to_cache(fqn_or_uri: UriIsh, md5b64: str = "") -> Path:
+     """Downloads directly to the cache and returns a Path to the read-only file.
+
+     This will allow you to download a file 'into' the cache even if
+     you provide no MD5 and the remote file properties do not include
+     one. However, future attempts to reuse the cache will force a
+     re-download if no MD5 is available at that time.
+     """
+     fqn = parse_any(fqn_or_uri)
+     cache_path = global_cache().path(fqn)
+     download_or_use_verified(
+         get_global_fs_client(fqn.sa, fqn.container), fqn.path, cache_path, md5b64, cache=global_cache()
+     )
+     return cache_path
+
+
+ def upload_through_cache(dest: UriIsh, src_path: Path) -> AdlsHashedResource:
+     """Return an AdlsHashedResource, since by definition an upload through the cache must have a known checksum.
+
+     Uses the global client, which is pretty much always what you want.
+     """
+     assert src_path.is_file(), "src_path must be a file."
+     resource = upload(dest, src_path, write_through_cache=global_cache())
+     assert resource, "MD5 should always be calculable for a local path."
+     return resource
+
+
+ def download_directory(fqn: AdlsFqn) -> Path:
+     """Download a directory from an AdlsFqn.
+
+     If you know you only need to download a single file, use download_to_cache.
+     """
+     fs = ADLSFileSystem(fqn.sa, fqn.container)
+     cached_dir_root = global_cache().path(fqn)
+     fs.fetch_directory(fqn.path, cached_dir_root)
+     assert cached_dir_root.is_dir(), "Directory should have been downloaded to the cache."
+     return cached_dir_root
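
A minimal sketch of the cache round trip using the top-level re-exports; the destination URI and local file are hypothetical:

    from pathlib import Path

    from thds.adls import download_to_cache, upload_through_cache

    dest = "adls://myaccount/mycontainer/datasets/example.csv"
    resource = upload_through_cache(dest, Path("example.csv"))  # returns an AdlsHashedResource

    # A later download of the same blob can be verified against the cached copy.
    local_copy: Path = download_to_cache(dest)
    print(local_copy)
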
thds/adls/conf.py ADDED
@@ -0,0 +1,33 @@
+ """This is where fine-tuning environment variables are defined."""
+ from thds.core import config
+
+ # These defaults were tested to perform well (~200 MB/sec) on a 2 core
+ # machine on Kubernetes. Larger numbers did not do any better, but
+ # these numbers did roughly 4x as well as the defaults, which are
+ # concurrency=1 and chunk_get_size=32 MB.
+ #
+ # As always, your mileage may vary.
+ #
+ # For more info, see docs at
+ # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blob-download-python#specify-data-transfer-options-on-download
+ #
+ # Also see
+ # azure.storage.filedatalake._shared.base_client.create_configuration
+ # for actual details...
+ #
+ DOWNLOAD_FILE_MAX_CONCURRENCY = config.item("download_file_max_concurrency", 4, parse=int)
+ MAX_CHUNK_GET_SIZE = config.item("max_chunk_get_size", 2**20 * 64, parse=int)  # 64MB
+ MAX_SINGLE_GET_SIZE = config.item(
+     "max_single_get_size", 2**20 * 64, parse=lambda i: max(MAX_CHUNK_GET_SIZE(), int(i))
+ )  # 64MB
+ MAX_SINGLE_PUT_SIZE = config.item(
+     "max_single_put_size", 2**20 * 64, parse=lambda i: max(MAX_CHUNK_GET_SIZE(), int(i))
+ )  # 64MB
+
+ # these are for upload
+ # these achieved 380 MB/sec on a 2 core machine on Kubernetes
+ MAX_BLOCK_SIZE = config.item("max_block_put_size", 2**20 * 64, parse=int)  # 64 MB
+ UPLOAD_FILE_MAX_CONCURRENCY = config.item("upload_file_max_concurrency", 10, parse=int)
+ UPLOAD_CHUNK_SIZE = config.item("upload_chunk_size", 2**20 * 100, parse=int)  # 100 MB
+
+ CONNECTION_TIMEOUT = config.item("connection_timeout", 2000, parse=int)  # seconds
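
These items appear to be zero-argument callables that return the current value (the parse= lambdas above call MAX_CHUNK_GET_SIZE() that way); a sketch of reading them, with thds.core.config's override mechanism left out since it is not shown in this diff:

    from thds.adls import conf

    print(conf.DOWNLOAD_FILE_MAX_CONCURRENCY())       # 4 by default
    print(conf.MAX_CHUNK_GET_SIZE() // 2**20, "MiB")  # 64 by default
    print(conf.CONNECTION_TIMEOUT(), "seconds")       # 2000 by default
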
thds/adls/dbfs.py ADDED
@@ -0,0 +1,60 @@
+ import typing as ty
+
+ from .fqn import AdlsFqn, AdlsRoot, join, parse_fqn
+
+ DBFS_SCHEME = "dbfs:/"
+
+ ADLS_TO_SPARK_MAPPING = {
+     "adls://uaapdatascience/data/": "/mnt/datascience/data/",
+     "adls://thdsdatasets/prod-datasets/": "/mnt/datascience/datasets/",
+     "adls://uaapdatascience/hive/": "/mnt/datascience/hive/",
+     "adls://thdsscratch/tmp/": "/mnt/datascience/scratch/",
+ }
+ ADLS_TO_DBFS_MAPPING = {k: join(DBFS_SCHEME, v) for k, v in ADLS_TO_SPARK_MAPPING.items()}
+ # Spark read/write implicitly adds a 'dbfs:/' prefix.
+ SPARK_TO_ADLS_MAPPING = {v: k for k, v in ADLS_TO_SPARK_MAPPING.items()}
+ DBFS_TO_ADLS_MAPPING = {join(DBFS_SCHEME, k): v for k, v in SPARK_TO_ADLS_MAPPING.items()}
+
+
+ def to_adls_root(root_uri: str) -> AdlsRoot:
+     try:
+         return AdlsRoot.parse(
+             DBFS_TO_ADLS_MAPPING[root_uri]
+             if root_uri.startswith(DBFS_SCHEME)
+             else SPARK_TO_ADLS_MAPPING[root_uri]
+         )
+     except KeyError:
+         raise ValueError(f"URI '{root_uri}' does not have a defined ADLS root!")
+
+
+ def to_adls_fqn(fully_qualified_name: str) -> AdlsFqn:
+     mapping = (
+         DBFS_TO_ADLS_MAPPING if fully_qualified_name.startswith(DBFS_SCHEME) else SPARK_TO_ADLS_MAPPING
+     )
+
+     try:
+         dbfs_root, adls_root = next(
+             ((k, v) for k, v in mapping.items() if fully_qualified_name.startswith(k))
+         )
+     except StopIteration:
+         raise ValueError(f"{fully_qualified_name} does not have a defined ADLS path!")
+
+     return parse_fqn(join(adls_root, fully_qualified_name.split(dbfs_root)[1]))
+
+
+ def to_uri(adls_path: ty.Union[AdlsRoot, AdlsFqn], spark: bool = True) -> str:
+     def get_root_uri(adls_root: AdlsRoot) -> str:
+         try:
+             return (
+                 ADLS_TO_SPARK_MAPPING[str(adls_root)] if spark else ADLS_TO_DBFS_MAPPING[str(adls_root)]
+             )
+         except KeyError:
+             raise ValueError(f"{str(adls_root)} does not have a corresponding dbfs root!")
+
+     if isinstance(adls_path, AdlsRoot):
+         return get_root_uri(adls_path)
+
+     try:
+         return join(get_root_uri(adls_path.root()), adls_path.path)
+     except ValueError:
+         raise ValueError(f"{str(adls_path)} does not have a corresponding dbfs path!")
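
A sketch of translating between the mount-style Spark paths and ADLS using only entries from ADLS_TO_SPARK_MAPPING above; the file path is hypothetical and the exact join/parse behavior lives in fqn.py, which is not part of this hunk:

    from thds.adls import dbfs

    fqn = dbfs.to_adls_fqn("/mnt/datascience/data/teams/metrics.parquet")
    # maps onto the adls://uaapdatascience/data/ root

    spark_path = dbfs.to_uri(fqn)             # back to "/mnt/datascience/data/..."
    dbfs_uri = dbfs.to_uri(fqn, spark=False)  # the dbfs:/-prefixed variant
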
thds/adls/defaults.py ADDED
@@ -0,0 +1,26 @@
+ """Prefer using named_containers for new code."""
+
+ from thds.core.env import Env
+
+ from . import fqn, named_roots
+
+ try:
+     import thds.adls._thds_defaults  # noqa: F401
+ except ImportError:
+     pass
+
+
+ def env_root(env: Env = "") -> fqn.AdlsRoot:
+     """In many cases, you may want to call this with no arguments
+     to default to using the THDS_ENV environment variable.
+     """
+     return named_roots.require(env)
+
+
+ def env_root_uri(env: Env = "") -> str:
+     return str(env_root(env))
+
+
+ def mops_root() -> str:
+     """Returns a URI corresponding to the location where mops materialization should be put."""
+     return str(named_roots.require("mops"))
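
A sketch of the helpers above, assuming a THDS_ENV-style environment and that the corresponding named roots are registered elsewhere (e.g. by the optional _thds_defaults import); "dev" is a hypothetical environment name:

    from thds.adls import defaults

    root = defaults.env_root()  # no argument: defaults to the THDS_ENV environment variable
    dev_uri = defaults.env_root_uri("dev")
    mops_uri = defaults.mops_root()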