dh-cli 0.8.1__tar.gz → 0.8.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dh_cli-0.8.1 → dh_cli-0.8.3}/PKG-INFO +1 -1
- {dh_cli-0.8.1 → dh_cli-0.8.3}/pyproject.toml +1 -1
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/finalize.py +51 -1
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/commands.py +4 -1
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/cost_report.py +45 -11
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/test_submit_merge.py +5 -15
- dh_cli-0.8.3/tests/test_finalize_boltz_tar.py +257 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/.gitignore +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/LICENSE +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/README.md +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/_identity.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/aws_batch.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/boltz.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/cancel.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/clean.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/embed_t5.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/list_jobs.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/local.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/logs.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/orca.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/protmpnn.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/protmpnn_to_boltz.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/retry.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/status.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/submit.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/train.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/wait_for.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/fasta_utils.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/h5_utils.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/job_id.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/manifest.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/s3_transport.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/pricing.yaml +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/cloud_commands.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/codeartifact.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/api_client.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/auth.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/engine_commands.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/progress.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/ssh_config.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/studio_commands.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/github_commands.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/deploy.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/local.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/test.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/tf.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/users.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/main.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/utility_commands.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/warehouse.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/test_aws_batch_resources.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/test_submit_cpu_only.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/conftest.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/A_cache_write.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/B_cache_read.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/C_plain.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/D_cursor_user.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/E_service_role.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/F_legacy_shared.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/G_unknown_model.json +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_build_report.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_classify_arn.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cli_exit_codes.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cost_calc.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cost_command.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cur_reconciliation.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_key_command.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_render_formats.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_resolve_base_model.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_s3_walker.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/__init__.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/conftest.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_engine_role_cannot_read_github_pat.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_identity.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_login.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_login_error_paths.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_login_security.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_logout.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_rotate.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_status.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/hz/test_init.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/hz/test_suites.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/hz/test_users.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/test_cloud_gcp.py +0 -0
- {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/test_finalize_protmpnn.py +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Finalize command for combining results and cleaning up."""
|
|
2
2
|
|
|
3
3
|
import shutil
|
|
4
|
+
import tarfile
|
|
4
5
|
import tempfile
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
|
|
@@ -94,7 +95,10 @@ def finalize(job_id, output, force, keep_intermediates, full_output, skip_dedup,
|
|
|
94
95
|
s3_temp_dir = Path(tempfile.mkdtemp())
|
|
95
96
|
s3_output_prefix = f"{manifest.s3_prefix}output/"
|
|
96
97
|
click.echo("Downloading outputs from S3...")
|
|
97
|
-
|
|
98
|
+
if manifest.pipeline == "boltz":
|
|
99
|
+
_download_boltz_s3_output(s3_output_prefix, s3_temp_dir)
|
|
100
|
+
else:
|
|
101
|
+
download_directory(s3_output_prefix, s3_temp_dir)
|
|
98
102
|
output_dir = s3_temp_dir
|
|
99
103
|
else:
|
|
100
104
|
output_dir = job_dir / "output"
|
|
@@ -175,6 +179,52 @@ def finalize(job_id, output, force, keep_intermediates, full_output, skip_dedup,
|
|
|
175
179
|
click.echo(f"Job directory preserved: {job_dir}")
|
|
176
180
|
|
|
177
181
|
|
|
182
|
+
def _download_boltz_s3_output(s3_output_prefix: str, local_dir: Path) -> None:
|
|
183
|
+
"""Download Boltz tar outputs from S3 and extract into local_dir.
|
|
184
|
+
|
|
185
|
+
Workers produce one `boltz_results_<name>.tar` per prediction plus
|
|
186
|
+
`boltz_*.done` marker objects. This helper downloads only those keys
|
|
187
|
+
(ignoring anything else under the prefix), extracts each tar in place,
|
|
188
|
+
and removes the tar file afterwards so the resulting layout matches
|
|
189
|
+
what `_finalize_boltz` expects on Primordial.
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
s3_output_prefix: S3 URI prefix like s3://bucket/jobs/<id>/output/
|
|
193
|
+
local_dir: Local directory to extract into
|
|
194
|
+
"""
|
|
195
|
+
from ..s3_transport import _get_client, parse_s3_uri
|
|
196
|
+
|
|
197
|
+
bucket, prefix_key = parse_s3_uri(s3_output_prefix)
|
|
198
|
+
client = _get_client()
|
|
199
|
+
local_dir.mkdir(parents=True, exist_ok=True)
|
|
200
|
+
|
|
201
|
+
paginator = client.get_paginator("list_objects_v2")
|
|
202
|
+
tar_count = 0
|
|
203
|
+
done_count = 0
|
|
204
|
+
for page in paginator.paginate(Bucket=bucket, Prefix=prefix_key):
|
|
205
|
+
for obj in page.get("Contents", []):
|
|
206
|
+
key = obj["Key"]
|
|
207
|
+
relative = key[len(prefix_key) :]
|
|
208
|
+
if not relative:
|
|
209
|
+
continue
|
|
210
|
+
basename = Path(relative).name
|
|
211
|
+
if basename.startswith("boltz_results_") and basename.endswith(".tar"):
|
|
212
|
+
tar_path = local_dir / basename
|
|
213
|
+
client.download_file(bucket, key, str(tar_path))
|
|
214
|
+
try:
|
|
215
|
+
with tarfile.open(tar_path, mode="r") as tf:
|
|
216
|
+
tf.extractall(local_dir)
|
|
217
|
+
finally:
|
|
218
|
+
tar_path.unlink(missing_ok=True)
|
|
219
|
+
tar_count += 1
|
|
220
|
+
elif basename.startswith("boltz_") and basename.endswith(".done"):
|
|
221
|
+
done_path = local_dir / basename
|
|
222
|
+
client.download_file(bucket, key, str(done_path))
|
|
223
|
+
done_count += 1
|
|
224
|
+
|
|
225
|
+
click.echo(f" Downloaded {tar_count} prediction tars, {done_count} done markers")
|
|
226
|
+
|
|
227
|
+
|
|
178
228
|
def _check_completion(job_id: str, base_path: str, output_dir: Path | None = None) -> list[int]:
|
|
179
229
|
"""Check which chunks are incomplete (no .done marker).
|
|
180
230
|
|
|
@@ -304,8 +304,11 @@ def bedrock_cost(
|
|
|
304
304
|
sys.exit(1)
|
|
305
305
|
|
|
306
306
|
import boto3
|
|
307
|
+
from botocore.config import Config
|
|
307
308
|
|
|
308
|
-
|
|
309
|
+
# Match the thread pool used by walk_logs so urllib3 doesn't block
|
|
310
|
+
# or warn when many parallel GETs are in flight.
|
|
311
|
+
s3 = boto3.client("s3", config=Config(max_pool_connections=32))
|
|
309
312
|
|
|
310
313
|
my_handle: Optional[str] = None
|
|
311
314
|
if me:
|
|
@@ -27,6 +27,7 @@ from __future__ import annotations
|
|
|
27
27
|
import datetime as dt
|
|
28
28
|
import gzip
|
|
29
29
|
import json
|
|
30
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
30
31
|
from dataclasses import dataclass, field
|
|
31
32
|
from pathlib import Path
|
|
32
33
|
from typing import Any, Iterable, Iterator
|
|
@@ -443,14 +444,46 @@ def walk_logs(
|
|
|
443
444
|
region: str,
|
|
444
445
|
start: dt.date,
|
|
445
446
|
end: dt.date,
|
|
447
|
+
max_workers: int = 32,
|
|
446
448
|
) -> Iterator[dict]:
|
|
449
|
+
"""Yield every invocation record in `[start, end]` (inclusive, UTC days).
|
|
450
|
+
|
|
451
|
+
Object GETs are parallelised with a thread pool because each day's
|
|
452
|
+
prefix holds hundreds of tiny (~400-byte) gzipped objects and
|
|
453
|
+
per-request latency dominates wall time. Records within a single
|
|
454
|
+
object are yielded in their original NDJSON order; records *across*
|
|
455
|
+
objects may be reordered — downstream aggregation (`build_report`)
|
|
456
|
+
is order-insensitive.
|
|
457
|
+
|
|
458
|
+
`max_workers` caps in-flight S3 GETs per day. The caller's
|
|
459
|
+
`s3_client` should be configured with `max_pool_connections` >=
|
|
460
|
+
`max_workers` (see `botocore.config.Config`) to avoid urllib3
|
|
461
|
+
connection-pool contention.
|
|
462
|
+
"""
|
|
447
463
|
paginator = s3_client.get_paginator("list_objects_v2")
|
|
448
464
|
seen_keys: set[str] = set()
|
|
465
|
+
|
|
466
|
+
def _fetch_and_parse(key: str) -> list[dict]:
|
|
467
|
+
body = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
|
|
468
|
+
decompressed = gzip.decompress(body)
|
|
469
|
+
out: list[dict] = []
|
|
470
|
+
# Each object is one or more JSON records separated by
|
|
471
|
+
# newlines (NDJSON). Older Bedrock traffic produced
|
|
472
|
+
# one-record objects; multi-record objects appeared in
|
|
473
|
+
# our bucket on 2026-04-20. Parse line-by-line so both
|
|
474
|
+
# shapes work, and tolerate a trailing newline.
|
|
475
|
+
for line in decompressed.splitlines():
|
|
476
|
+
if not line.strip():
|
|
477
|
+
continue
|
|
478
|
+
out.append(json.loads(line))
|
|
479
|
+
return out
|
|
480
|
+
|
|
449
481
|
for day in _iter_days(start, end):
|
|
450
482
|
prefix = (
|
|
451
483
|
f"invocation-logs/AWSLogs/{account}/BedrockModelInvocationLogs/"
|
|
452
484
|
f"{region}/{day.year:04d}/{day.month:02d}/{day.day:02d}/"
|
|
453
485
|
)
|
|
486
|
+
keys: list[str] = []
|
|
454
487
|
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
|
|
455
488
|
for obj in page.get("Contents", []) or []:
|
|
456
489
|
key = obj["Key"]
|
|
@@ -461,17 +494,18 @@ def walk_logs(
|
|
|
461
494
|
if key in seen_keys:
|
|
462
495
|
continue
|
|
463
496
|
seen_keys.add(key)
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
497
|
+
keys.append(key)
|
|
498
|
+
if not keys:
|
|
499
|
+
continue
|
|
500
|
+
# One pool per day bounds concurrent in-flight GETs and caps
|
|
501
|
+
# peak memory (at most ~max_workers decompressed objects held
|
|
502
|
+
# at once). ex.map preserves submission order, so the day's
|
|
503
|
+
# records stream out in a stable — though not chronological —
|
|
504
|
+
# order.
|
|
505
|
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
|
506
|
+
for records in ex.map(_fetch_and_parse, keys):
|
|
507
|
+
for rec in records:
|
|
508
|
+
yield rec
|
|
475
509
|
|
|
476
510
|
|
|
477
511
|
def reconcile_with_cost_explorer(
|
|
@@ -91,9 +91,7 @@ class TestCliWinsWhenExplicitlyPassed:
|
|
|
91
91
|
"cli_flag,kwarg,cli_default,yaml_alt,_",
|
|
92
92
|
MERGE_FIELDS,
|
|
93
93
|
)
|
|
94
|
-
def test_cli_at_default_beats_yaml(
|
|
95
|
-
self, cli_runner, tmp_path, cli_flag, kwarg, cli_default, yaml_alt, _
|
|
96
|
-
):
|
|
94
|
+
def test_cli_at_default_beats_yaml(self, cli_runner, tmp_path, cli_flag, kwarg, cli_default, yaml_alt, _):
|
|
97
95
|
"""CLI flag set to the Click default still wins over a different YAML value."""
|
|
98
96
|
yaml_key = cli_flag.lstrip("-")
|
|
99
97
|
config_path = tmp_path / "job.yaml"
|
|
@@ -112,9 +110,7 @@ class TestCliWinsWhenExplicitlyPassed:
|
|
|
112
110
|
"cli_flag,kwarg,_,yaml_alt,cli_alt",
|
|
113
111
|
MERGE_FIELDS,
|
|
114
112
|
)
|
|
115
|
-
def test_cli_at_non_default_beats_yaml(
|
|
116
|
-
self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, cli_alt
|
|
117
|
-
):
|
|
113
|
+
def test_cli_at_non_default_beats_yaml(self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, cli_alt):
|
|
118
114
|
"""CLI flag at a non-default value also wins over YAML (regression check)."""
|
|
119
115
|
yaml_key = cli_flag.lstrip("-")
|
|
120
116
|
config_path = tmp_path / "job.yaml"
|
|
@@ -137,9 +133,7 @@ class TestYamlWinsWhenCliOmitted:
|
|
|
137
133
|
"cli_flag,kwarg,_,yaml_alt,__",
|
|
138
134
|
MERGE_FIELDS,
|
|
139
135
|
)
|
|
140
|
-
def test_yaml_wins_when_cli_not_passed(
|
|
141
|
-
self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, __
|
|
142
|
-
):
|
|
136
|
+
def test_yaml_wins_when_cli_not_passed(self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, __):
|
|
143
137
|
yaml_key = cli_flag.lstrip("-")
|
|
144
138
|
config_path = tmp_path / "job.yaml"
|
|
145
139
|
config_path.write_text(yaml.dump({"command": "echo hi", yaml_key: yaml_alt}))
|
|
@@ -157,9 +151,7 @@ class TestDefaultWhenNeitherSet:
|
|
|
157
151
|
"_,kwarg,cli_default,__,___",
|
|
158
152
|
MERGE_FIELDS,
|
|
159
153
|
)
|
|
160
|
-
def test_click_default_applies(
|
|
161
|
-
self, cli_runner, tmp_path, _, kwarg, cli_default, __, ___
|
|
162
|
-
):
|
|
154
|
+
def test_click_default_applies(self, cli_runner, tmp_path, _, kwarg, cli_default, __, ___):
|
|
163
155
|
result, mock_client = _invoke(cli_runner, ["--command", "echo hi"], tmp_path)
|
|
164
156
|
assert result.exit_code == 0, result.output
|
|
165
157
|
call_kwargs = mock_client.submit_job.call_args[1]
|
|
@@ -218,9 +210,7 @@ class TestGpuRegressionFromReport:
|
|
|
218
210
|
|
|
219
211
|
def test_explicit_gpus_one_beats_yaml_gpus_zero(self, cli_runner, tmp_path):
|
|
220
212
|
config_path = tmp_path / "job.yaml"
|
|
221
|
-
config_path.write_text(
|
|
222
|
-
yaml.dump({"command": "echo hi", "queue": "t4-1x-spot", "gpus": 0})
|
|
223
|
-
)
|
|
213
|
+
config_path.write_text(yaml.dump({"command": "echo hi", "queue": "t4-1x-spot", "gpus": 0}))
|
|
224
214
|
result, mock_client = _invoke(
|
|
225
215
|
cli_runner,
|
|
226
216
|
["-f", str(config_path), "--gpus", "1"],
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Tests for the Boltz S3 tar-aware finalize download path.
|
|
2
|
+
|
|
3
|
+
When a Boltz job was run in S3 mode, workers upload one
|
|
4
|
+
`boltz_results_<name>.tar` per prediction (plus per-worker done markers).
|
|
5
|
+
Finalize must download those tars, extract them into `boltz_results_*/`
|
|
6
|
+
directories matching the legacy on-disk layout, and leave the existing
|
|
7
|
+
`_finalize_boltz` logic untouched.
|
|
8
|
+
|
|
9
|
+
See plan: nutshell/plans/dma/05_2026/0512_boltz_s3_fanout_and_cross_az_dig.md.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import tarfile
|
|
13
|
+
import tempfile
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from unittest.mock import MagicMock, patch
|
|
16
|
+
|
|
17
|
+
import pytest
|
|
18
|
+
|
|
19
|
+
from dh_cli.batch.commands.finalize import _download_boltz_s3_output
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.fixture
def temp_dir():
    """Provide a fresh scratch directory as a Path, deleted once the test ends."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _build_essential_tar(tar_path: Path, complex_name: str) -> None:
|
|
29
|
+
"""Build a boltz_results_<name>.tar like the worker produces."""
|
|
30
|
+
with tempfile.TemporaryDirectory() as src:
|
|
31
|
+
src_p = Path(src)
|
|
32
|
+
pred_subdir = src_p / f"boltz_results_{complex_name}" / "predictions" / complex_name
|
|
33
|
+
pred_subdir.mkdir(parents=True)
|
|
34
|
+
(pred_subdir / f"{complex_name}_model_0.cif").write_text(f"CIF {complex_name}\n")
|
|
35
|
+
(pred_subdir / f"confidence_{complex_name}_model_0.json").write_text(
|
|
36
|
+
f'{{"cx":"{complex_name}"}}'
|
|
37
|
+
)
|
|
38
|
+
with tarfile.open(tar_path, mode="w") as tf:
|
|
39
|
+
root = src_p / f"boltz_results_{complex_name}"
|
|
40
|
+
for f in sorted(root.rglob("*")):
|
|
41
|
+
if f.is_file():
|
|
42
|
+
tf.add(f, arcname=f"boltz_results_{complex_name}/{f.relative_to(root)}")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _make_mock_s3_client(objects: dict[str, bytes]) -> MagicMock:
|
|
46
|
+
"""Build a boto3 S3 client mock backed by an in-memory bucket.
|
|
47
|
+
|
|
48
|
+
`objects` is a mapping of key -> bytes. The mock implements:
|
|
49
|
+
- list_objects_v2 (via a paginator)
|
|
50
|
+
- download_file (writes object bytes to the local path)
|
|
51
|
+
"""
|
|
52
|
+
client = MagicMock()
|
|
53
|
+
|
|
54
|
+
def _paginate(Bucket, Prefix, **kwargs):
|
|
55
|
+
matching = [k for k in objects if k.startswith(Prefix)]
|
|
56
|
+
yield {"Contents": [{"Key": k} for k in sorted(matching)]}
|
|
57
|
+
|
|
58
|
+
paginator = MagicMock()
|
|
59
|
+
paginator.paginate.side_effect = _paginate
|
|
60
|
+
client.get_paginator.return_value = paginator
|
|
61
|
+
|
|
62
|
+
def _download_file(bucket, key, local_path, *args, **kwargs):
|
|
63
|
+
Path(local_path).parent.mkdir(parents=True, exist_ok=True)
|
|
64
|
+
Path(local_path).write_bytes(objects[key])
|
|
65
|
+
|
|
66
|
+
client.download_file.side_effect = _download_file
|
|
67
|
+
return client
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TestDownloadsAndExtracts:
    """Download-and-extract behaviour of `_download_boltz_s3_output` against a mocked bucket."""

    def _tar_bytes(self, temp_dir, name):
        """Build ``<name>.tar`` in ``temp_dir`` with the worker-shaped builder; return its bytes."""
        tar_file = temp_dir / f"{name}.tar"
        _build_essential_tar(tar_file, name)
        return tar_file.read_bytes()

    def _download_into_fresh_dir(self, objects, temp_dir):
        """Create ``temp_dir/extracted``, run the helper against ``objects``, return that dir."""
        client = _make_mock_s3_client(objects)
        dest = temp_dir / "extracted"
        dest.mkdir()
        with patch("dh_cli.batch.s3_transport._get_client", return_value=client):
            _download_boltz_s3_output("s3://bucket/jobs/j/output/", dest)
        return dest

    def test_downloads_tars_and_extracts(self, temp_dir):
        objects = {
            "jobs/j/output/boltz_results_A.tar": self._tar_bytes(temp_dir, "A"),
            "jobs/j/output/boltz_results_B.tar": self._tar_bytes(temp_dir, "B"),
            "jobs/j/output/boltz_0.done": b"",
        }
        dest = self._download_into_fresh_dir(objects, temp_dir)

        for name in ("A", "B"):
            cif = dest / f"boltz_results_{name}" / "predictions" / name / f"{name}_model_0.cif"
            assert cif.read_text() == f"CIF {name}\n"
        assert (dest / "boltz_0.done").exists()

    def test_local_tar_cleaned_up_after_extract(self, temp_dir):
        objects = {"jobs/j/output/boltz_results_A.tar": self._tar_bytes(temp_dir, "A")}
        dest = self._download_into_fresh_dir(objects, temp_dir)

        assert not any(dest.rglob("*.tar"))

    def test_ignores_non_tar_non_done_keys(self, temp_dir):
        """Objects under the prefix that are neither result tars nor done markers are skipped."""
        objects = {
            "jobs/j/output/boltz_results_A.tar": self._tar_bytes(temp_dir, "A"),
            "jobs/j/output/stray.txt": b"ignore me",
            "jobs/j/output/notes/readme.md": b"also ignored",
        }
        dest = self._download_into_fresh_dir(objects, temp_dir)

        assert (dest / "boltz_results_A" / "predictions" / "A" / "A_model_0.cif").exists()
        for unexpected in ("stray.txt", "notes"):
            assert not (dest / unexpected).exists()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class TestRoundTrip:
    """Worker upload followed by finalize download reproduces the original files."""

    def test_worker_to_finalize_bit_identical(self, temp_dir):
        """Worker's sync_boltz_essential_to_s3 output -> finalize download -> same bytes."""
        from dh_batch.s3_sync import sync_boltz_essential_to_s3

        job_dir = temp_dir / "job"
        output_dir = job_dir / "output"
        prediction_dir = output_dir / "boltz_results_A" / "predictions" / "A"
        prediction_dir.mkdir(parents=True)
        seeded = {
            prediction_dir / "A_model_0.cif": "CIF ROUNDTRIP\n",
            prediction_dir / "confidence_A_model_0.json": '{"r":1}',
            output_dir / "boltz_0.done": "ok",
        }
        for path, text in seeded.items():
            path.write_text(text)

        # Capture whatever the worker uploads into an in-memory "bucket".
        bucket_contents: dict[str, bytes] = {}

        def _capture_upload(local_path, bucket, key, *args, **kwargs):
            bucket_contents[key] = Path(local_path).read_bytes()

        worker_client = MagicMock()
        worker_client.upload_file.side_effect = _capture_upload
        with patch("dh_batch.s3_transport._get_client", return_value=worker_client):
            sync_boltz_essential_to_s3(job_dir, "s3://bucket/jobs/j/")

        # Replay the captured bucket through the finalize-side download.
        finalize_client = _make_mock_s3_client(bucket_contents)
        dest = temp_dir / "extracted"
        dest.mkdir()
        with patch("dh_cli.batch.s3_transport._get_client", return_value=finalize_client):
            _download_boltz_s3_output("s3://bucket/jobs/j/output/", dest)

        extracted = dest / "boltz_results_A" / "predictions" / "A"
        assert (extracted / "A_model_0.cif").read_text() == "CIF ROUNDTRIP\n"
        assert (extracted / "confidence_A_model_0.json").read_text() == '{"r":1}'
        assert (dest / "boltz_0.done").read_text() == "ok"
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class TestFinalizeDispatches:
    """`dh batch finalize` in S3 mode picks its download path by pipeline."""

    @staticmethod
    def _succeeded_s3_manifest(job_id, pipeline, s3_prefix):
        """Build a SUCCEEDED, S3-mode JobManifest for the given job and pipeline."""
        from dh_cli.batch.manifest import JobManifest, JobStatus

        return JobManifest(
            job_id=job_id,
            user="tester",
            pipeline=pipeline,
            storage_mode="s3",
            status=JobStatus.SUCCEEDED,
            s3_prefix=s3_prefix,
        )

    def test_boltz_s3_uses_tar_path(self, temp_dir):
        """`dh batch finalize` for Boltz in S3 mode calls the tar-aware helper, not download_directory."""
        from dh_cli.batch.commands.finalize import finalize as finalize_cmd

        manifest = self._succeeded_s3_manifest("test-boltz", "boltz", "s3://bucket/jobs/test-boltz/")

        from click.testing import CliRunner

        with (
            patch("dh_cli.batch.commands.finalize.load_manifest", return_value=manifest),
            patch("dh_cli.batch.commands.finalize._download_boltz_s3_output") as mock_tar_download,
            patch("dh_cli.batch.commands.finalize._check_completion", return_value=[]),
            patch("dh_cli.batch.commands.finalize._finalize_boltz") as mock_fb,
            patch("dh_cli.batch.commands.finalize.save_manifest_s3"),
        ):
            result = CliRunner().invoke(
                finalize_cmd,
                ["test-boltz", "--output", str(temp_dir / "final"), "--keep-intermediates"],
                input="y\n",
            )

        assert result.exit_code == 0, result.output
        assert mock_tar_download.called, "tar-aware download helper must be called for Boltz S3"
        assert mock_fb.called

    def test_non_boltz_s3_uses_download_directory(self, temp_dir):
        """Non-Boltz pipelines (e.g. embed-t5) in S3 mode keep using download_directory."""
        from dh_cli.batch.commands.finalize import finalize as finalize_cmd

        manifest = self._succeeded_s3_manifest("test-embed", "embed-t5", "s3://bucket/jobs/test-embed/")

        from click.testing import CliRunner

        with (
            patch("dh_cli.batch.commands.finalize.load_manifest", return_value=manifest),
            patch("dh_cli.batch.commands.finalize._download_boltz_s3_output") as mock_tar_download,
            patch("dh_cli.batch.s3_transport.download_directory") as mock_dd,
            patch("dh_cli.batch.commands.finalize._check_completion", return_value=[]),
            patch("dh_cli.batch.commands.finalize._finalize_embeddings") as mock_fe,
            patch("dh_cli.batch.commands.finalize.save_manifest_s3"),
        ):
            result = CliRunner().invoke(
                finalize_cmd,
                ["test-embed", "--output", str(temp_dir / "out.h5"), "--keep-intermediates"],
            )

        assert result.exit_code == 0, result.output
        assert mock_dd.called, "non-Boltz pipelines should keep using download_directory"
        assert not mock_tar_download.called, "tar-aware helper must not fire for embed-t5"
        assert mock_fe.called
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|