thds.adls 3.2.20250630174944__py3-none-any.whl → 4.1.20250701190349__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of thds.adls might be problematic.

thds/adls/__init__.py CHANGED
@@ -1,16 +1,21 @@
-from thds.core import meta
+from thds import core
 
-from . import abfss, defaults, etag, fqn, named_roots, resource, source, source_tree, uri  # noqa: F401
-from .cached_up_down import download_directory, download_to_cache, upload_through_cache  # noqa: F401
+from . import abfss, defaults, etag, fqn, hashes, named_roots, source, source_tree, uri  # noqa: F401
+from .cached import download_directory, download_to_cache, upload_through_cache  # noqa: F401
 from .copy import copy_file, copy_files, wait_for_copy  # noqa: F401
 from .errors import BlobNotFoundError  # noqa: F401
 from .fqn import *  # noqa: F401,F403
 from .global_client import get_global_client, get_global_fs_client  # noqa: F401
 from .impl import *  # noqa: F401,F403
 from .ro_cache import Cache, global_cache  # noqa: F401
+from .upload import upload  # noqa: F401
 from .uri import UriIsh, parse_any, parse_uri, resolve_any, resolve_uri  # noqa: F401
 
-__version__ = meta.get_version(__name__)
-metadata = meta.read_metadata(__name__)
+__version__ = core.meta.get_version(__name__)
+metadata = core.meta.read_metadata(__name__)
 
 __basepackage__ = __name__
 __commit__ = metadata.git_commit
+
+hashes.register_hashes()
+# SPOOKY: without the above line, the hashing algorithms will not be registered with thds.core.hash_cache,
+# which will be bad for core.Source as well as uploads and downloads.
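For orientation, a minimal sketch of the import-time behavior implied by the new __init__ (assuming the package is installed; only names visible in the hunk above are referenced):

    import thds.adls as adls

    print(adls.__version__)   # resolved through thds.core's meta helpers at import time
    print(adls.__commit__)    # git commit recorded in the package metadata
    # Importing the package has already called hashes.register_hashes(), so the hash
    # algorithms are registered with thds.core.hash_cache before any upload or download runs.
    # adls.upload, adls.download_to_cache, adls.upload_through_cache, etc. are re-exported here.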
thds/adls/_upload.py CHANGED
@@ -2,15 +2,16 @@
 
 Not an officially-published API of the thds.adls library.
 """
+
 import typing as ty
 from pathlib import Path
 
 import azure.core.exceptions
-from azure.storage.blob import ContentSettings
 
-from thds.core import hostname, log
+from thds.core import hash_cache, hashing, hostname, log
 
-from .md5 import AnyStrSrc, try_md5
+from . import hashes
+from .file_properties import PropertiesP
 
 _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES = 2 * 2**20  # 2 MB is about right
 
@@ -18,19 +19,26 @@ _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES = 2 * 2**20 # 2 MB is about rig
 logger = log.getLogger(__name__)
 
 
-def _get_checksum_content_settings(data: AnyStrSrc) -> ty.Optional[ContentSettings]:
-    """Ideally, we calculate an MD5 sum for all data that we upload.
+def _try_default_hash(data: hashes.AnyStrSrc) -> ty.Optional[hashing.Hash]:
+    """Ideally, we calculate a hash/checksum for all data that we upload.
 
     The only circumstances under which we cannot do this are if the
     stream does not exist in its entirety before the upload begins.
     """
-    md5 = try_md5(data)
-    if md5:
-        return ContentSettings(content_md5=bytearray(md5))
+    hasher = hashes.default_hasher()
+    hbytes = None
+    if isinstance(data, Path):
+        hbytes = hash_cache.hash_file(data, hasher)
+    elif hashing.hash_anything(data, hasher):
+        hbytes = hasher.digest()
+
+    if hbytes:
+        return hashing.Hash(hasher.name.lower(), hbytes)
+
     return None
 
 
-def _too_small_to_skip_upload(data: AnyStrSrc, min_size_for_remote_check: int) -> bool:
+def _too_small_to_skip_upload(data: hashes.AnyStrSrc, min_size_for_remote_check: int) -> bool:
     def _len() -> int:
         if isinstance(data, Path) and data.exists():
             return data.stat().st_size
@@ -45,49 +53,58 @@ def _too_small_to_skip_upload(data: AnyStrSrc, min_size_for_remote_check: int) -
 
 class UploadDecision(ty.NamedTuple):
     upload_required: bool
-    content_settings: ty.Optional[ContentSettings]
+    metadata: ty.Dict[str, str]
 
 
-class Properties(ty.Protocol):
-    name: str
-    content_settings: ContentSettings
+def metadata_for_upload() -> ty.Dict[str, str]:
+    return {"upload_wrapper_sw": "thds.adls", "upload_hostname": hostname.friendly()}
+
 
+def _co_upload_decision_unless_file_present_with_matching_checksum(
+    data: hashes.AnyStrSrc, min_size_for_remote_check: int
+) -> ty.Generator[bool, ty.Optional[PropertiesP], UploadDecision]:
+    local_hash = _try_default_hash(data)
+    if not local_hash:
+        return UploadDecision(True, metadata_for_upload())
 
-def _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
-    data: AnyStrSrc, min_size_for_remote_check: int
-) -> ty.Generator[bool, ty.Optional[Properties], UploadDecision]:
-    local_content_settings = _get_checksum_content_settings(data)
-    if not local_content_settings:
-        return UploadDecision(True, None)
+    hash_meta = hashes.metadata_hash_dict(local_hash)
+    metadata = dict(metadata_for_upload(), **hash_meta)
     if _too_small_to_skip_upload(data, min_size_for_remote_check):
         logger.debug("Too small to bother with an early call - let's just upload...")
-        return UploadDecision(True, local_content_settings)
+        return UploadDecision(True, metadata)
+
     remote_properties = yield True
     if not remote_properties:
         logger.debug("No remote properties could be fetched so an upload is required")
-        return UploadDecision(True, local_content_settings)
-    if remote_properties.content_settings.content_md5 == local_content_settings.content_md5:
-        logger.info(f"Remote file {remote_properties.name} already exists and has matching checksum")
-        return UploadDecision(False, local_content_settings)
-    logger.debug("Remote file exists but MD5 does not match - upload required.")
-    return UploadDecision(True, local_content_settings)
+        return UploadDecision(True, metadata)
+
+    remote_hashes = hashes.extract_hashes_from_props(remote_properties)
+    for algo in remote_hashes:
+        mkey = hashes.metadata_hash_b64_key(algo)
+        if mkey in hash_meta and hashing.b64(remote_hashes[algo].bytes) == hash_meta[mkey]:
+            logger.info(f"Remote file {remote_properties.name} already exists and has matching checksum")
+            return UploadDecision(False, metadata)
+
+    print(remote_hashes, hash_meta)
+    logger.debug("Remote file exists but hash does not match - upload required.")
+    return UploadDecision(True, metadata)
 
 
 doc = """
 Returns False for upload_required if the file is large and the remote
 exists and has a known, matching checksum.
 
-Returns ContentSettings if an MD5 checksum can be calculated.
+Returns a metadata dict that should be added to any upload.
 """
 
 
-async def async_upload_decision_and_settings(
-    get_properties: ty.Callable[[], ty.Awaitable[Properties]],
-    data: AnyStrSrc,
+async def async_upload_decision_and_metadata(
+    get_properties: ty.Callable[[], ty.Awaitable[PropertiesP]],
+    data: hashes.AnyStrSrc,
     min_size_for_remote_check: int = _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES,
 ) -> UploadDecision:
     try:
-        co = _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
+        co = _co_upload_decision_unless_file_present_with_matching_checksum(
             data, min_size_for_remote_check
         )
         while True:
@@ -100,13 +117,13 @@ async def async_upload_decision_and_settings(
         return stop.value
 
 
-def upload_decision_and_settings(
-    get_properties: ty.Callable[[], Properties],
-    data: AnyStrSrc,
+def upload_decision_and_metadata(
+    get_properties: ty.Callable[[], PropertiesP],
+    data: hashes.AnyStrSrc,
     min_size_for_remote_check: int = _SKIP_ALREADY_UPLOADED_CHECK_IF_MORE_THAN_BYTES,
 ) -> UploadDecision:
     try:
-        co = _co_content_settings_for_upload_unless_file_present_with_matching_checksum(
+        co = _co_upload_decision_unless_file_present_with_matching_checksum(
             data, min_size_for_remote_check
         )
         while True:
@@ -119,9 +136,5 @@ def upload_decision_and_settings(
         return stop.value
 
 
-async_upload_decision_and_settings.__doc__ = doc
-upload_decision_and_settings.__doc__ = doc
-
-
-def metadata_for_upload() -> ty.Dict[str, str]:
-    return {"upload_wrapper_sw": "thds.adls", "upload_hostname": hostname.friendly()}
+async_upload_decision_and_metadata.__doc__ = doc
+upload_decision_and_metadata.__doc__ = doc
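To make the shape of the renamed API concrete, here is a minimal sketch of the synchronous entry point. The module's own docstring notes it is not an officially-published API; file_client and get_props below are hypothetical, and returning None when the remote file is absent is an assumption based on the "if not remote_properties" branch above.

    from pathlib import Path

    import azure.core.exceptions

    from thds.adls._upload import upload_decision_and_metadata

    def get_props():
        # hypothetical helper: fetch remote properties, or None if the file does not exist yet
        try:
            return file_client.get_file_properties()  # file_client assumed to be a DataLakeFileClient
        except azure.core.exceptions.ResourceNotFoundError:
            return None

    decision = upload_decision_and_metadata(get_props, Path("big_input.parquet"))
    if decision.upload_required:
        ...  # perform the upload, attaching decision.metadata (wrapper/hostname plus hash keys) to the blob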
@@ -1 +1 @@
-from . import download  # noqa
+from . import download, upload  # noqa: F401
@@ -7,161 +7,121 @@
 # very end of the download), so for local users who don't have huge bandwidth, it's likely
 # a better user experience to disable this globally.
 import asyncio
-import json
-import os
 import subprocess
 import typing as ty
-import urllib.parse
-from contextlib import contextmanager
+from contextlib import nullcontext
+from dataclasses import dataclass
 from pathlib import Path
 
 from azure.storage.filedatalake import DataLakeFileClient
 
-from thds.core import cache, config, log
+from thds.core import config, log
 
-from .. import _progress, conf, uri
+from .. import conf
+from . import login, progress, system_resources
 
-DONT_USE_AZCOPY = config.item("dont_use", default=True, parse=config.tobool)
-
-_AZCOPY_LOGIN_WORKLOAD_IDENTITY = "azcopy login --login-type workload".split()
-_AZCOPY_LOGIN_LOCAL_STATUS = "azcopy login status".split()
-# device login is an interactive process involving a web browser,
-# which is not acceptable for large scale automation.
-# So instead of logging in, we check to see if you _are_ logged in,
-# and if you are, we try using azcopy in the future.
+DONT_USE_AZCOPY = config.item("dont_use", default=False, parse=config.tobool)
+MIN_FILE_SIZE = config.item("min_file_size", default=20 * 10**6, parse=int)  # 20 MB
 
 logger = log.getLogger(__name__)
 
 
-class DownloadRequest(ty.NamedTuple):
-    """Use one or the other, but not both, to write the results."""
-
-    writer: ty.IO[bytes]
+@dataclass
+class DownloadRequest:
     temp_path: Path
+    size_bytes: int
 
 
-@cache.locking  # only run this once per process.
-def _good_azcopy_login() -> bool:
-    if DONT_USE_AZCOPY():
-        return False
-
-    try:
-        subprocess.run(_AZCOPY_LOGIN_WORKLOAD_IDENTITY, check=True, capture_output=True)
-        logger.info("Will use azcopy for downloads in this process...")
-        return True
-
-    except (subprocess.CalledProcessError, FileNotFoundError):
-        pass
-    try:
-        subprocess.run(_AZCOPY_LOGIN_LOCAL_STATUS, check=True)
-        logger.info("Will use azcopy for downloads in this process...", dl=None)
-        return True
-    except FileNotFoundError:
-        logger.info("azcopy is not installed or not on your PATH, so we cannot speed up downloads")
-    except subprocess.CalledProcessError as cpe:
-        logger.warning(
-            "You are not logged in with azcopy, so we cannot speed up downloads."
-            f" Run `azcopy login` to fix this. Return code was {cpe.returncode}"
-        )
-    return False
-
-
-def _azcopy_download_command(dl_file_client: DataLakeFileClient, path: Path) -> ty.List[str]:
-    return ["azcopy", "copy", dl_file_client.url, str(path), "--output-type=json"]
-
+@dataclass
+class SdkDownloadRequest(DownloadRequest):
+    """Use one or the other, but not both, to write the results."""
 
-class AzCopyMessage(ty.TypedDict):
-    TotalBytesEnumerated: str
-    TotalBytesTransferred: str
+    writer: ty.IO[bytes]
 
 
-class AzCopyJsonLine(ty.TypedDict):
-    MessageType: str
-    MessageContent: AzCopyMessage
+def _is_big_enough_for_azcopy(size_bytes: int) -> bool:
+    return size_bytes >= MIN_FILE_SIZE()
 
 
-def _parse_azcopy_json_output(line: str) -> AzCopyJsonLine:
-    outer_msg = json.loads(line)
-    return AzCopyJsonLine(
-        MessageType=outer_msg["MessageType"],
-        MessageContent=json.loads(outer_msg["MessageContent"]),
+def should_use_azcopy(file_size_bytes: int) -> bool:
+    return (
+        _is_big_enough_for_azcopy(file_size_bytes)
+        and not DONT_USE_AZCOPY()
+        and login.good_azcopy_login()
     )
 
 
-@contextmanager
-def _track_azcopy_progress(http_url: str) -> ty.Iterator[ty.Callable[[str], None]]:
-    """Context manager that tracks progress from AzCopy JSON lines. This works for both async and sync impls."""
-    tracker = _progress.get_global_download_tracker()
-    adls_uri = urllib.parse.unquote(str(uri.parse_uri(http_url)))
-
-    def track(line: str):
-        if not line:
-            return
-
-        try:
-            prog = _parse_azcopy_json_output(line)
-            if prog["MessageType"] == "Progress":
-                tracker(adls_uri, total_written=int(prog["MessageContent"]["TotalBytesTransferred"]))
-        except json.JSONDecodeError:
-            pass
-
-    yield track
-
-
-def _restrict_mem() -> dict:
-    return dict(os.environ, AZCOPY_BUFFER_GB="0.3")
+def _azcopy_download_command(dl_file_client: DataLakeFileClient, path: Path) -> ty.List[str]:
+    # turns out azcopy checks md5 by default - but we we do our own checking, sometimes with faster methods,
+    # and their checking _dramatically_ slows downloads on capable machines, so we disable it.
+    return ["azcopy", "copy", dl_file_client.url, str(path), "--output-type=json", "--check-md5=NoCheck"]
 
 
 def sync_fastpath(
     dl_file_client: DataLakeFileClient,
     download_request: DownloadRequest,
 ) -> None:
-    if _good_azcopy_login():
+    if not isinstance(download_request, SdkDownloadRequest):
+        logger.debug("Downloading %s using azcopy", dl_file_client.url)
         try:
             # Run the copy
             process = subprocess.Popen(
                 _azcopy_download_command(dl_file_client, download_request.temp_path),
                 stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
                 text=True,
-                env=_restrict_mem(),
+                env=system_resources.restrict_usage(),
             )
             assert process.stdout
-            with _track_azcopy_progress(dl_file_client.url) as track:
+            with progress.azcopy_tracker(dl_file_client.url, download_request.size_bytes) as track:
                 for line in process.stdout:
                     track(line)
+
+            process.wait()
+            if process.returncode != 0:
+                raise subprocess.SubprocessError(f"AzCopy failed with return code {process.returncode}")
+            assert (
+                download_request.temp_path.exists()
+            ), f"AzCopy did not create the file at {download_request.temp_path}"
             return  # success
 
-        except (subprocess.SubprocessError, FileNotFoundError):
+        except (subprocess.CalledProcessError, FileNotFoundError):
            logger.warning("Falling back to Python SDK for download")
 
-    dl_file_client.download_file(
-        max_concurrency=conf.DOWNLOAD_FILE_MAX_CONCURRENCY(),
-        connection_timeout=conf.CONNECTION_TIMEOUT(),
-    ).readinto(download_request.writer)
+    logger.debug("Downloading %s using Python SDK", dl_file_client.url)
+    if hasattr(download_request, "writer"):
+        writer_cm = nullcontext(download_request.writer)
+    else:
+        writer_cm = open(download_request.temp_path, "wb")  # type: ignore[assignment]
+    with writer_cm as writer:
+        dl_file_client.download_file(
+            max_concurrency=conf.DOWNLOAD_FILE_MAX_CONCURRENCY(),
+            connection_timeout=conf.CONNECTION_TIMEOUT(),
+        ).readinto(writer)
 
 
 async def async_fastpath(
     dl_file_client: DataLakeFileClient,
     download_request: DownloadRequest,
 ) -> None:
-    # technically it would be 'better' to do this login in an async subproces,
+    # technically it would be 'better' to do this login in an async subprocess,
     # but it involves a lot of boilerplate, _and_ we have no nice way to cache
     # the value, which is going to be computed one per process and never again.
     # So we'll just block the async loop for a couple of seconds one time...
-    if _good_azcopy_login():
+    if not isinstance(download_request, SdkDownloadRequest):
+        logger.debug("Downloading %s using azcopy", dl_file_client.url)
         try:
             # Run the copy
             copy_proc = await asyncio.create_subprocess_exec(
                 *_azcopy_download_command(dl_file_client, download_request.temp_path),
                 stdout=asyncio.subprocess.PIPE,
-                stderr=asyncio.subprocess.PIPE,
-                env=_restrict_mem(),
+                stderr=asyncio.subprocess.STDOUT,
+                env=system_resources.restrict_usage(),
             )
             assert copy_proc.stdout
 
             # Feed lines to the tracker asynchronously
-            with _track_azcopy_progress(dl_file_client.url) as track:
+            with progress.azcopy_tracker(dl_file_client.url, download_request.size_bytes) as track:
                 while True:
                     line = await copy_proc.stdout.readline()
                     if not line:  # EOF
@@ -178,9 +138,15 @@ async def async_fastpath(
         except (subprocess.SubprocessError, FileNotFoundError):
             logger.warning("Falling back to Python SDK for download")
 
-    reader = await dl_file_client.download_file(  # type: ignore[misc]
-        # TODO - check above type ignore
-        max_concurrency=conf.DOWNLOAD_FILE_MAX_CONCURRENCY(),
-        connection_timeout=conf.CONNECTION_TIMEOUT(),
-    )
-    await reader.readinto(download_request.writer)
+    logger.debug("Downloading %s using Async Python SDK", dl_file_client.url)
+    if hasattr(download_request, "writer"):
+        writer_cm = nullcontext(download_request.writer)
+    else:
+        writer_cm = open(download_request.temp_path, "wb")  # type: ignore[assignment]
+    with writer_cm as writer:
+        reader = await dl_file_client.download_file(  # type: ignore[misc]
+            # TODO - check above type ignore
+            max_concurrency=conf.DOWNLOAD_FILE_MAX_CONCURRENCY(),
+            connection_timeout=conf.CONNECTION_TIMEOUT(),
+        )
+        await reader.readinto(writer)
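A sketch of how the two request shapes appear intended to be used: a plain DownloadRequest lets sync_fastpath try azcopy first (falling back to the SDK), while an SdkDownloadRequest carries an open writer and takes the SDK path directly. Gating on should_use_azcopy is the caller's job here and is an assumption; dl_file_client is an existing DataLakeFileClient.

    from pathlib import Path

    tmp = Path("/tmp/blob.partial")
    size_bytes = 150 * 10**6  # remote size known in advance, e.g. from file properties

    if should_use_azcopy(size_bytes):
        request = DownloadRequest(temp_path=tmp, size_bytes=size_bytes)  # azcopy, with SDK fallback
    else:
        request = SdkDownloadRequest(temp_path=tmp, size_bytes=size_bytes, writer=tmp.open("wb"))
    sync_fastpath(dl_file_client, request)  # dl_file_client assumed to exist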
@@ -0,0 +1,39 @@
+import subprocess
+
+from thds.core import cache, log, scope
+
+_AZCOPY_LOGIN_WORKLOAD_IDENTITY = "azcopy login --login-type workload".split()
+_AZCOPY_LOGIN_LOCAL_STATUS = "azcopy login status".split()
+# device login is an interactive process involving a web browser,
+# which is not acceptable for large scale automation.
+# So instead of logging in, we check to see if you _are_ logged in,
+# and if you are, we try using azcopy in the future.
+logger = log.getLogger(__name__)
+
+
+@cache.locking  # only run this once per process.
+@scope.bound
+def good_azcopy_login() -> bool:
+    scope.enter(log.logger_context(dl=None))
+    try:
+        subprocess.run(_AZCOPY_LOGIN_WORKLOAD_IDENTITY, check=True, capture_output=True)
+        logger.info("Azcopy login with workload identity, so we can use it for large file transfers")
+        return True
+
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        pass
+    try:
+        subprocess.run(_AZCOPY_LOGIN_LOCAL_STATUS, check=True)
+        logger.info("Azcopy login with local token, so we can use it for large file transfers")
+        return True
+
+    except FileNotFoundError:
+        logger.info(
+            "azcopy is not installed or not on your PATH, so we cannot speed up large file transfers"
+        )
+    except subprocess.CalledProcessError as cpe:
+        logger.warning(
+            "You are not logged in with azcopy, so we cannot speed up large file transfers."
+            f" Run `azcopy login` to fix this. Return code was {cpe.returncode}"
+        )
+    return False
@@ -0,0 +1,49 @@
+import json
+import typing as ty
+import urllib.parse
+from contextlib import contextmanager
+
+from .. import _progress, uri
+
+
+class AzCopyMessage(ty.TypedDict):
+    TotalBytesEnumerated: str
+    TotalBytesTransferred: str
+
+
+class AzCopyJsonLine(ty.TypedDict):
+    MessageType: str
+    MessageContent: AzCopyMessage
+
+
+def _parse_azcopy_json_output(line: str) -> AzCopyJsonLine:
+    outer_msg = json.loads(line)
+    return AzCopyJsonLine(
+        MessageType=outer_msg["MessageType"],
+        MessageContent=json.loads(outer_msg["MessageContent"]),
+    )
+
+
+@contextmanager
+def azcopy_tracker(http_url: str, size_bytes: int) -> ty.Iterator[ty.Callable[[str], None]]:
+    """Context manager that tracks progress from AzCopy JSON lines. This works for both async and sync impls."""
+    tracker = _progress.get_global_download_tracker()
+    adls_uri = urllib.parse.unquote(str(uri.parse_uri(http_url)))
+    if size_bytes:
+        tracker.add(adls_uri, total=size_bytes)
+
+    def track(line: str):
+        if not size_bytes:
+            return  # no size, no progress
+
+        if not line:
+            return
+
+        try:
+            prog = _parse_azcopy_json_output(line)
+            if prog["MessageType"] == "Progress":
+                tracker(adls_uri, total_written=int(prog["MessageContent"]["TotalBytesTransferred"]))
+        except json.JSONDecodeError:
+            pass
+
+    yield track
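As the double json.loads in _parse_azcopy_json_output suggests, each azcopy --output-type=json line nests a JSON string inside JSON. An illustrative (made-up) line and how it parses:

    import json

    line = json.dumps(
        {
            "MessageType": "Progress",
            "MessageContent": json.dumps(
                {"TotalBytesEnumerated": "1048576", "TotalBytesTransferred": "524288"}
            ),
        }
    )
    parsed = _parse_azcopy_json_output(line)
    assert parsed["MessageType"] == "Progress"
    assert int(parsed["MessageContent"]["TotalBytesTransferred"]) == 524288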
@@ -0,0 +1,26 @@
+import os
+from functools import lru_cache
+
+from thds.core import cpus, log
+
+logger = log.getLogger(__name__)
+
+
+@lru_cache
+def restrict_usage() -> dict:
+    num_cpus = cpus.available_cpu_count()
+
+    env = dict(os.environ)
+    if "AZCOPY_BUFFER_GB" not in os.environ:
+        likely_mem_gb_available = num_cpus * 4  # assume 4 GB per CPU core is available
+        # o3 suggested 15% of the total available memory...
+        env["AZCOPY_BUFFER_GB"] = str(likely_mem_gb_available * 0.15)
+    if "AZCOPY_CONCURRENCY" not in os.environ:
+        env["AZCOPY_CONCURRENCY"] = str(int(num_cpus * 2))
+
+    logger.info(
+        "AZCOPY_BUFFER_GB == %s and AZCOPY_CONCURRENCY == %s",
+        env["AZCOPY_BUFFER_GB"],
+        env["AZCOPY_CONCURRENCY"],
+    )
+    return env
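Worked numbers for the heuristic above, assuming neither environment variable is already set and 8 CPUs are available:

    num_cpus = 8                        # example value from cpus.available_cpu_count()
    buffer_gb = num_cpus * 4 * 0.15     # 32 GB assumed available -> AZCOPY_BUFFER_GB = "4.8"
    concurrency = int(num_cpus * 2)     # AZCOPY_CONCURRENCY = "16"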
@@ -0,0 +1,95 @@
+import subprocess
+import typing as ty
+from pathlib import Path
+
+from thds.core import config
+
+from .. import uri
+from . import login, progress, system_resources
+
+DONT_USE_AZCOPY = config.item("dont_use", default=False, parse=config.tobool)
+MIN_FILE_SIZE = config.item("min_file_size", default=20 * 10**6, parse=int)  # 20 MB
+
+
+def build_azcopy_upload_command(
+    source_path: Path,
+    dest: uri.UriIsh,
+    *,
+    content_type: str = "",
+    metadata: ty.Mapping[str, str] = dict(),  # noqa: B006
+    overwrite: bool = True,
+) -> list[str]:
+    """
+    Build azcopy upload command as a list of strings.
+
+    Args:
+        source_path: Path to local file to upload
+        dest_url: Full Azure blob URL (e.g., https://account.blob.core.windows.net/container/blob)
+        content_type: MIME content type
+        metadata: Mapping of metadata key-value pairs
+        overwrite: Whether to overwrite existing blob
+
+    Returns:
+        List of strings suitable for subprocess.run()
+    """
+
+    cmd = ["azcopy", "copy", str(source_path), uri.to_blob_windows_url(dest)]
+
+    if overwrite:
+        cmd.append("--overwrite=true")
+
+    if content_type:
+        cmd.append(f"--content-type={content_type}")
+
+    if metadata:
+        # Format metadata as key1=value1;key2=value2
+        metadata_str = ";".join(f"{k}={v}" for k, v in metadata.items())
+        cmd.append(f"--metadata={metadata_str}")
+
+    cmd.append("--output-type=json")  # for progress tracking
+
+    return cmd
+
+
+def _is_big_enough_for_azcopy(size_bytes: int) -> bool:
+    """
+    Determine if a file is big enough to warrant using azcopy for upload.
+
+    Args:
+        size_bytes: Size of the file in bytes
+
+    Returns:
+        True if the file is big enough, False otherwise
+    """
+    return size_bytes >= MIN_FILE_SIZE()
+
+
+def should_use_azcopy(file_size_bytes: int) -> bool:
+    return (
+        _is_big_enough_for_azcopy(file_size_bytes)
+        and not DONT_USE_AZCOPY()
+        and login.good_azcopy_login()
+    )
+
+
+def run(
+    cmd: ty.Sequence[str],
+    dest: uri.UriIsh,
+    size_bytes: int,
+) -> None:
+    # Run the copy
+    process = subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        env=system_resources.restrict_usage(),
+    )
+    assert process.stdout
+    with progress.azcopy_tracker(uri.to_blob_windows_url(dest), size_bytes) as track:
+        for line in process.stdout:
+            track(line)
+
+    process.wait()
+    if process.returncode != 0:
+        raise subprocess.SubprocessError(f"AzCopy failed with return code {process.returncode}")
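Tying the new upload helpers together, a minimal sketch using only names from the hunk above; the destination value is a placeholder for whatever uri.UriIsh form uri.to_blob_windows_url accepts, and the metadata dict stands in for the one produced by upload_decision_and_metadata:

    from pathlib import Path

    src = Path("model.bin")
    dest = "<destination uri.UriIsh>"  # placeholder, not a real URI format
    size_bytes = src.stat().st_size

    if should_use_azcopy(size_bytes):
        cmd = build_azcopy_upload_command(
            src,
            dest,
            content_type="application/octet-stream",
            metadata={"upload_wrapper_sw": "thds.adls"},  # normally the dict from upload_decision_and_metadata
        )
        run(cmd, dest, size_bytes)  # shells out to azcopy, streaming its JSON progress lines to the tracker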