dh-cli 0.8.1__tar.gz → 0.8.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. {dh_cli-0.8.1 → dh_cli-0.8.3}/PKG-INFO +1 -1
  2. {dh_cli-0.8.1 → dh_cli-0.8.3}/pyproject.toml +1 -1
  3. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/finalize.py +51 -1
  4. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/commands.py +4 -1
  5. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/cost_report.py +45 -11
  6. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/test_submit_merge.py +5 -15
  7. dh_cli-0.8.3/tests/test_finalize_boltz_tar.py +257 -0
  8. {dh_cli-0.8.1 → dh_cli-0.8.3}/.gitignore +0 -0
  9. {dh_cli-0.8.1 → dh_cli-0.8.3}/LICENSE +0 -0
  10. {dh_cli-0.8.1 → dh_cli-0.8.3}/README.md +0 -0
  11. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/__init__.py +0 -0
  12. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/_identity.py +0 -0
  13. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/__init__.py +0 -0
  14. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/aws_batch.py +0 -0
  15. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/__init__.py +0 -0
  16. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/boltz.py +0 -0
  17. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/cancel.py +0 -0
  18. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/clean.py +0 -0
  19. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/embed_t5.py +0 -0
  20. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/list_jobs.py +0 -0
  21. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/local.py +0 -0
  22. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/logs.py +0 -0
  23. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/orca.py +0 -0
  24. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/protmpnn.py +0 -0
  25. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/protmpnn_to_boltz.py +0 -0
  26. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/retry.py +0 -0
  27. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/status.py +0 -0
  28. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/submit.py +0 -0
  29. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/train.py +0 -0
  30. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/commands/wait_for.py +0 -0
  31. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/fasta_utils.py +0 -0
  32. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/h5_utils.py +0 -0
  33. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/job_id.py +0 -0
  34. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/manifest.py +0 -0
  35. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/batch/s3_transport.py +0 -0
  36. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/__init__.py +0 -0
  37. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/bedrock/pricing.yaml +0 -0
  38. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/cloud_commands.py +0 -0
  39. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/codeartifact.py +0 -0
  40. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/__init__.py +0 -0
  41. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/api_client.py +0 -0
  42. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/auth.py +0 -0
  43. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/engine_commands.py +0 -0
  44. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/progress.py +0 -0
  45. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/ssh_config.py +0 -0
  46. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/engines_studios/studio_commands.py +0 -0
  47. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/github_commands.py +0 -0
  48. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/__init__.py +0 -0
  49. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/deploy.py +0 -0
  50. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/local.py +0 -0
  51. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/test.py +0 -0
  52. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/tf.py +0 -0
  53. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/hz/users.py +0 -0
  54. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/main.py +0 -0
  55. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/utility_commands.py +0 -0
  56. {dh_cli-0.8.1 → dh_cli-0.8.3}/src/dh_cli/warehouse.py +0 -0
  57. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/__init__.py +0 -0
  58. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/test_aws_batch_resources.py +0 -0
  59. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/batch/test_submit_cpu_only.py +0 -0
  60. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/conftest.py +0 -0
  61. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/A_cache_write.json +0 -0
  62. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/B_cache_read.json +0 -0
  63. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/C_plain.json +0 -0
  64. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/D_cursor_user.json +0 -0
  65. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/E_service_role.json +0 -0
  66. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/F_legacy_shared.json +0 -0
  67. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/fixtures/G_unknown_model.json +0 -0
  68. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_build_report.py +0 -0
  69. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_classify_arn.py +0 -0
  70. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cli_exit_codes.py +0 -0
  71. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cost_calc.py +0 -0
  72. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cost_command.py +0 -0
  73. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_cur_reconciliation.py +0 -0
  74. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_key_command.py +0 -0
  75. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_render_formats.py +0 -0
  76. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_resolve_base_model.py +0 -0
  77. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/bedrock/test_s3_walker.py +0 -0
  78. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/__init__.py +0 -0
  79. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/conftest.py +0 -0
  80. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_engine_role_cannot_read_github_pat.py +0 -0
  81. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_identity.py +0 -0
  82. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_login.py +0 -0
  83. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_login_error_paths.py +0 -0
  84. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_login_security.py +0 -0
  85. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_logout.py +0 -0
  86. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_rotate.py +0 -0
  87. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/github/test_status.py +0 -0
  88. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/hz/test_init.py +0 -0
  89. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/hz/test_suites.py +0 -0
  90. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/hz/test_users.py +0 -0
  91. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/test_cloud_gcp.py +0 -0
  92. {dh_cli-0.8.1 → dh_cli-0.8.3}/tests/test_finalize_protmpnn.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dh-cli
3
- Version: 0.8.1
3
+ Version: 0.8.3
4
4
  Summary: Dayhoff Labs developer CLI
5
5
  Author-email: Dayhoff Labs <dev@dayhofflabs.com>
6
6
  License: # PolyForm Noncommercial License 1.0.0
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "dh-cli"
7
- version = "0.8.1"
7
+ version = "0.8.3"
8
8
  description = "Dayhoff Labs developer CLI"
9
9
  requires-python = ">=3.11"
10
10
  readme = "README.md"
@@ -1,6 +1,7 @@
1
1
  """Finalize command for combining results and cleaning up."""
2
2
 
3
3
  import shutil
4
+ import tarfile
4
5
  import tempfile
5
6
  from pathlib import Path
6
7
 
@@ -94,7 +95,10 @@ def finalize(job_id, output, force, keep_intermediates, full_output, skip_dedup,
94
95
  s3_temp_dir = Path(tempfile.mkdtemp())
95
96
  s3_output_prefix = f"{manifest.s3_prefix}output/"
96
97
  click.echo("Downloading outputs from S3...")
97
- download_directory(s3_output_prefix, s3_temp_dir)
98
+ if manifest.pipeline == "boltz":
99
+ _download_boltz_s3_output(s3_output_prefix, s3_temp_dir)
100
+ else:
101
+ download_directory(s3_output_prefix, s3_temp_dir)
98
102
  output_dir = s3_temp_dir
99
103
  else:
100
104
  output_dir = job_dir / "output"
@@ -175,6 +179,52 @@ def finalize(job_id, output, force, keep_intermediates, full_output, skip_dedup,
175
179
  click.echo(f"Job directory preserved: {job_dir}")
176
180
 
177
181
 
182
def _download_boltz_s3_output(s3_output_prefix: str, local_dir: Path) -> None:
    """Download Boltz tar outputs from S3 and extract them into ``local_dir``.

    Workers produce one ``boltz_results_<name>.tar`` per prediction plus
    ``boltz_*.done`` marker objects. This helper downloads only those keys
    (ignoring anything else under the prefix), extracts each tar into
    ``local_dir``, and removes the local tar afterwards so the resulting
    layout matches what ``_finalize_boltz`` expects on Primordial.

    Keys are matched on their basename, so matching objects nested under
    sub-"directories" of the prefix land directly in ``local_dir``.

    Args:
        s3_output_prefix: S3 URI prefix like s3://bucket/jobs/<id>/output/
        local_dir: Local directory to extract into (created if missing).

    Raises:
        tarfile.TarError: if a downloaded tar is corrupt, or one of its
            members would be written outside ``local_dir``.
    """
    from ..s3_transport import _get_client, parse_s3_uri

    bucket, prefix_key = parse_s3_uri(s3_output_prefix)
    client = _get_client()
    local_dir.mkdir(parents=True, exist_ok=True)

    paginator = client.get_paginator("list_objects_v2")
    tar_count = 0
    done_count = 0
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix_key):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            relative = key[len(prefix_key) :]
            if not relative:
                # The prefix "directory" object itself.
                continue
            basename = Path(relative).name
            if basename.startswith("boltz_results_") and basename.endswith(".tar"):
                tar_path = local_dir / basename
                # Download inside the try so a failed/partial download is
                # cleaned up as well as a successfully extracted tar.
                try:
                    client.download_file(bucket, key, str(tar_path))
                    _extract_tar_safely(tar_path, local_dir)
                finally:
                    tar_path.unlink(missing_ok=True)
                tar_count += 1
            elif basename.startswith("boltz_") and basename.endswith(".done"):
                client.download_file(bucket, key, str(local_dir / basename))
                done_count += 1

    click.echo(f" Downloaded {tar_count} prediction tars, {done_count} done markers")


def _extract_tar_safely(tar_path: Path, dest: Path) -> None:
    """Extract ``tar_path`` into ``dest``, refusing members that escape ``dest``.

    When the interpreter supports extraction filters (Python 3.11.4+ / 3.12+),
    the stdlib ``"data"`` filter is used. It strips leading slashes, rejects
    members and links that resolve outside ``dest``, and refuses device files.
    On older 3.11 patch releases, each member is checked explicitly before a
    plain ``extractall``; that path also refuses any link.

    Raises:
        tarfile.TarError: for a corrupt archive or an unsafe member
            (``tarfile.FilterError`` is a ``TarError`` subclass).
    """
    with tarfile.open(tar_path, mode="r") as tf:
        if hasattr(tarfile, "data_filter"):
            tf.extractall(dest, filter="data")
            return
        dest_root = dest.resolve()
        for member in tf.getmembers():
            target = (dest_root / member.name).resolve()
            unsafe_type = member.issym() or member.islnk() or member.isdev()
            if unsafe_type or not target.is_relative_to(dest_root):
                raise tarfile.TarError(f"Refusing unsafe tar member {member.name!r} in {tar_path}")
        tf.extractall(dest)
226
+
227
+
178
228
  def _check_completion(job_id: str, base_path: str, output_dir: Path | None = None) -> list[int]:
179
229
  """Check which chunks are incomplete (no .done marker).
180
230
 
@@ -304,8 +304,11 @@ def bedrock_cost(
304
304
  sys.exit(1)
305
305
 
306
306
  import boto3
307
+ from botocore.config import Config
307
308
 
308
- s3 = boto3.client("s3")
309
+ # Match the thread pool used by walk_logs so urllib3 doesn't block
310
+ # or warn when many parallel GETs are in flight.
311
+ s3 = boto3.client("s3", config=Config(max_pool_connections=32))
309
312
 
310
313
  my_handle: Optional[str] = None
311
314
  if me:
@@ -27,6 +27,7 @@ from __future__ import annotations
27
27
  import datetime as dt
28
28
  import gzip
29
29
  import json
30
+ from concurrent.futures import ThreadPoolExecutor
30
31
  from dataclasses import dataclass, field
31
32
  from pathlib import Path
32
33
  from typing import Any, Iterable, Iterator
@@ -443,14 +444,46 @@ def walk_logs(
443
444
  region: str,
444
445
  start: dt.date,
445
446
  end: dt.date,
447
+ max_workers: int = 32,
446
448
  ) -> Iterator[dict]:
449
+ """Yield every invocation record in `[start, end]` (inclusive, UTC days).
450
+
451
+ Object GETs are parallelised with a thread pool because each day's
452
+ prefix holds hundreds of tiny (~400-byte) gzipped objects and
453
+ per-request latency dominates wall time. Records within a single
454
+ object are yielded in their original NDJSON order; records *across*
455
+ objects may be reordered — downstream aggregation (`build_report`)
456
+ is order-insensitive.
457
+
458
+ `max_workers` caps in-flight S3 GETs per day. The caller's
459
+ `s3_client` should be configured with `max_pool_connections` >=
460
+ `max_workers` (see `botocore.config.Config`) to avoid urllib3
461
+ connection-pool contention.
462
+ """
447
463
  paginator = s3_client.get_paginator("list_objects_v2")
448
464
  seen_keys: set[str] = set()
465
+
466
+ def _fetch_and_parse(key: str) -> list[dict]:
467
+ body = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
468
+ decompressed = gzip.decompress(body)
469
+ out: list[dict] = []
470
+ # Each object is one or more JSON records separated by
471
+ # newlines (NDJSON). Older Bedrock traffic produced
472
+ # one-record objects; multi-record objects appeared in
473
+ # our bucket on 2026-04-20. Parse line-by-line so both
474
+ # shapes work, and tolerate a trailing newline.
475
+ for line in decompressed.splitlines():
476
+ if not line.strip():
477
+ continue
478
+ out.append(json.loads(line))
479
+ return out
480
+
449
481
  for day in _iter_days(start, end):
450
482
  prefix = (
451
483
  f"invocation-logs/AWSLogs/{account}/BedrockModelInvocationLogs/"
452
484
  f"{region}/{day.year:04d}/{day.month:02d}/{day.day:02d}/"
453
485
  )
486
+ keys: list[str] = []
454
487
  for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
455
488
  for obj in page.get("Contents", []) or []:
456
489
  key = obj["Key"]
@@ -461,17 +494,18 @@ def walk_logs(
461
494
  if key in seen_keys:
462
495
  continue
463
496
  seen_keys.add(key)
464
- body = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
465
- decompressed = gzip.decompress(body)
466
- # Each object is one or more JSON records separated by
467
- # newlines (NDJSON). Older Bedrock traffic produced
468
- # one-record objects; multi-record objects appeared in
469
- # our bucket on 2026-04-20. Parse line-by-line so both
470
- # shapes work, and tolerate a trailing newline.
471
- for line in decompressed.splitlines():
472
- if not line.strip():
473
- continue
474
- yield json.loads(line)
497
+ keys.append(key)
498
+ if not keys:
499
+ continue
500
+ # One pool per day bounds concurrent in-flight GETs and caps
501
+ # peak memory (at most ~max_workers decompressed objects held
502
+ # at once). ex.map preserves submission order, so the day's
503
+ # records stream out in a stable — though not chronological —
504
+ # order.
505
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
506
+ for records in ex.map(_fetch_and_parse, keys):
507
+ for rec in records:
508
+ yield rec
475
509
 
476
510
 
477
511
  def reconcile_with_cost_explorer(
@@ -91,9 +91,7 @@ class TestCliWinsWhenExplicitlyPassed:
91
91
  "cli_flag,kwarg,cli_default,yaml_alt,_",
92
92
  MERGE_FIELDS,
93
93
  )
94
- def test_cli_at_default_beats_yaml(
95
- self, cli_runner, tmp_path, cli_flag, kwarg, cli_default, yaml_alt, _
96
- ):
94
+ def test_cli_at_default_beats_yaml(self, cli_runner, tmp_path, cli_flag, kwarg, cli_default, yaml_alt, _):
97
95
  """CLI flag set to the Click default still wins over a different YAML value."""
98
96
  yaml_key = cli_flag.lstrip("-")
99
97
  config_path = tmp_path / "job.yaml"
@@ -112,9 +110,7 @@ class TestCliWinsWhenExplicitlyPassed:
112
110
  "cli_flag,kwarg,_,yaml_alt,cli_alt",
113
111
  MERGE_FIELDS,
114
112
  )
115
- def test_cli_at_non_default_beats_yaml(
116
- self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, cli_alt
117
- ):
113
+ def test_cli_at_non_default_beats_yaml(self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, cli_alt):
118
114
  """CLI flag at a non-default value also wins over YAML (regression check)."""
119
115
  yaml_key = cli_flag.lstrip("-")
120
116
  config_path = tmp_path / "job.yaml"
@@ -137,9 +133,7 @@ class TestYamlWinsWhenCliOmitted:
137
133
  "cli_flag,kwarg,_,yaml_alt,__",
138
134
  MERGE_FIELDS,
139
135
  )
140
- def test_yaml_wins_when_cli_not_passed(
141
- self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, __
142
- ):
136
+ def test_yaml_wins_when_cli_not_passed(self, cli_runner, tmp_path, cli_flag, kwarg, _, yaml_alt, __):
143
137
  yaml_key = cli_flag.lstrip("-")
144
138
  config_path = tmp_path / "job.yaml"
145
139
  config_path.write_text(yaml.dump({"command": "echo hi", yaml_key: yaml_alt}))
@@ -157,9 +151,7 @@ class TestDefaultWhenNeitherSet:
157
151
  "_,kwarg,cli_default,__,___",
158
152
  MERGE_FIELDS,
159
153
  )
160
- def test_click_default_applies(
161
- self, cli_runner, tmp_path, _, kwarg, cli_default, __, ___
162
- ):
154
+ def test_click_default_applies(self, cli_runner, tmp_path, _, kwarg, cli_default, __, ___):
163
155
  result, mock_client = _invoke(cli_runner, ["--command", "echo hi"], tmp_path)
164
156
  assert result.exit_code == 0, result.output
165
157
  call_kwargs = mock_client.submit_job.call_args[1]
@@ -218,9 +210,7 @@ class TestGpuRegressionFromReport:
218
210
 
219
211
  def test_explicit_gpus_one_beats_yaml_gpus_zero(self, cli_runner, tmp_path):
220
212
  config_path = tmp_path / "job.yaml"
221
- config_path.write_text(
222
- yaml.dump({"command": "echo hi", "queue": "t4-1x-spot", "gpus": 0})
223
- )
213
+ config_path.write_text(yaml.dump({"command": "echo hi", "queue": "t4-1x-spot", "gpus": 0}))
224
214
  result, mock_client = _invoke(
225
215
  cli_runner,
226
216
  ["-f", str(config_path), "--gpus", "1"],
@@ -0,0 +1,257 @@
1
+ """Tests for the Boltz S3 tar-aware finalize download path.
2
+
3
+ When a Boltz job was run in S3 mode, workers upload one
4
+ `boltz_results_<name>.tar` per prediction (plus per-worker done markers).
5
+ Finalize must download those tars, extract them into `boltz_results_*/`
6
+ directories matching the legacy on-disk layout, and leave the existing
7
+ `_finalize_boltz` logic untouched.
8
+
9
+ See plan: nutshell/plans/dma/05_2026/0512_boltz_s3_fanout_and_cross_az_dig.md.
10
+ """
11
+
12
+ import tarfile
13
+ import tempfile
14
+ from pathlib import Path
15
+ from unittest.mock import MagicMock, patch
16
+
17
+ import pytest
18
+
19
+ from dh_cli.batch.commands.finalize import _download_boltz_s3_output
20
+
21
+
22
@pytest.fixture
def temp_dir():
    """Provide a scratch directory as a ``Path``; it is deleted after the test."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
26
+
27
+
28
def _build_essential_tar(tar_path: Path, complex_name: str) -> None:
    """Write a ``boltz_results_<complex_name>`` tar shaped like a worker upload.

    The archive holds one CIF model and one confidence JSON under
    ``boltz_results_<name>/predictions/<name>/``. Only regular files are
    added; no directory entries are written.
    """
    top = f"boltz_results_{complex_name}"
    with tempfile.TemporaryDirectory() as scratch:
        results_root = Path(scratch) / top
        predictions = results_root / "predictions" / complex_name
        predictions.mkdir(parents=True)

        cif_file = predictions / f"{complex_name}_model_0.cif"
        cif_file.write_text(f"CIF {complex_name}\n")
        confidence_file = predictions / f"confidence_{complex_name}_model_0.json"
        confidence_file.write_text(f'{{"cx":"{complex_name}"}}')

        members = [p for p in sorted(results_root.rglob("*")) if p.is_file()]
        with tarfile.open(tar_path, mode="w") as archive:
            for member in members:
                archive.add(member, arcname=f"{top}/{member.relative_to(results_root)}")
43
+
44
+
45
def _make_mock_s3_client(objects: dict[str, bytes]) -> MagicMock:
    """Return a MagicMock S3 client that serves ``objects`` from memory.

    ``objects`` maps S3 key -> object bytes. Supported calls:

    - ``get_paginator(...).paginate(Bucket=..., Prefix=...)`` yields a single
      page listing every key that starts with ``Prefix``, in sorted order.
    - ``download_file(bucket, key, local_path)`` writes the key's bytes to
      ``local_path``, creating parent directories as needed.
    """

    def fake_paginate(Bucket, Prefix, **kwargs):
        listing = sorted(key for key in objects if key.startswith(Prefix))
        yield {"Contents": [{"Key": key} for key in listing]}

    def fake_download_file(bucket, key, local_path, *args, **kwargs):
        destination = Path(local_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_bytes(objects[key])

    paginator = MagicMock()
    paginator.paginate.side_effect = fake_paginate

    client = MagicMock()
    client.get_paginator.return_value = paginator
    client.download_file.side_effect = fake_download_file
    return client
68
+
69
+
70
class TestDownloadsAndExtracts:
    """`_download_boltz_s3_output` fetches tars and done markers, then unpacks them."""

    @staticmethod
    def _run(objects: dict[str, bytes], temp_dir: Path) -> Path:
        """Run the download helper against an in-memory bucket; return the extraction dir."""
        dest = temp_dir / "extracted"
        dest.mkdir()
        client = _make_mock_s3_client(objects)
        with patch("dh_cli.batch.s3_transport._get_client", return_value=client):
            _download_boltz_s3_output("s3://bucket/jobs/j/output/", dest)
        return dest

    def test_downloads_tars_and_extracts(self, temp_dir):
        objects: dict[str, bytes] = {}
        for name in ("A", "B"):
            local_tar = temp_dir / f"{name}.tar"
            _build_essential_tar(local_tar, name)
            objects[f"jobs/j/output/boltz_results_{name}.tar"] = local_tar.read_bytes()
        objects["jobs/j/output/boltz_0.done"] = b""

        dest = self._run(objects, temp_dir)

        for name in ("A", "B"):
            cif = dest / f"boltz_results_{name}" / "predictions" / name / f"{name}_model_0.cif"
            assert cif.read_text() == f"CIF {name}\n"
        assert (dest / "boltz_0.done").exists()

    def test_local_tar_cleaned_up_after_extract(self, temp_dir):
        """No ``*.tar`` file is left in the destination once extraction finishes."""
        local_tar = temp_dir / "A.tar"
        _build_essential_tar(local_tar, "A")

        dest = self._run({"jobs/j/output/boltz_results_A.tar": local_tar.read_bytes()}, temp_dir)

        assert not list(dest.rglob("*.tar"))

    def test_ignores_non_tar_non_done_keys(self, temp_dir):
        """Extra objects in the output prefix that aren't tars or done markers are ignored."""
        local_tar = temp_dir / "A.tar"
        _build_essential_tar(local_tar, "A")

        dest = self._run(
            {
                "jobs/j/output/boltz_results_A.tar": local_tar.read_bytes(),
                "jobs/j/output/stray.txt": b"ignore me",
                "jobs/j/output/notes/readme.md": b"also ignored",
            },
            temp_dir,
        )

        assert (dest / "boltz_results_A" / "predictions" / "A" / "A_model_0.cif").exists()
        assert not (dest / "stray.txt").exists()
        assert not (dest / "notes").exists()
129
+
130
+
131
class TestRoundTrip:
    """End-to-end check: the worker-side tar upload feeds the finalize-side download."""

    def test_worker_to_finalize_bit_identical(self, temp_dir):
        """Worker's sync_boltz_essential_to_s3 output -> finalize download -> same bytes."""
        # `dh_batch` is the worker-side package and is only needed for this
        # round trip. Skip, rather than error at import, when it is not
        # installed alongside dh_cli in the test environment.
        s3_sync = pytest.importorskip("dh_batch.s3_sync")

        job_dir = temp_dir / "job"
        output_dir = job_dir / "output"
        pred_a = output_dir / "boltz_results_A" / "predictions" / "A"
        pred_a.mkdir(parents=True)
        (pred_a / "A_model_0.cif").write_text("CIF ROUNDTRIP\n")
        (pred_a / "confidence_A_model_0.json").write_text('{"r":1}')
        (output_dir / "boltz_0.done").write_text("ok")

        # Capture what the worker uploads: key -> bytes.
        uploaded: dict[str, bytes] = {}
        worker_client = MagicMock()

        def _upload_file(local_path, bucket, key, *args, **kwargs):
            uploaded[key] = Path(local_path).read_bytes()

        worker_client.upload_file.side_effect = _upload_file

        with patch("dh_batch.s3_transport._get_client", return_value=worker_client):
            s3_sync.sync_boltz_essential_to_s3(job_dir, "s3://bucket/jobs/j/")

        # If the worker uploaded nothing, fail here with a clear message
        # instead of later with a confusing missing-file error on the
        # finalize side.
        assert uploaded, "sync_boltz_essential_to_s3 uploaded no objects via upload_file"

        finalize_client = _make_mock_s3_client(uploaded)
        dest = temp_dir / "extracted"
        dest.mkdir()

        with patch("dh_cli.batch.s3_transport._get_client", return_value=finalize_client):
            _download_boltz_s3_output("s3://bucket/jobs/j/output/", dest)

        cif = dest / "boltz_results_A" / "predictions" / "A" / "A_model_0.cif"
        conf = dest / "boltz_results_A" / "predictions" / "A" / "confidence_A_model_0.json"
        done = dest / "boltz_0.done"

        assert cif.read_text() == "CIF ROUNDTRIP\n"
        assert conf.read_text() == '{"r":1}'
        assert done.read_text() == "ok"
169
+
170
+
171
class TestFinalizeDispatches:
    """`finalize` picks its S3 download strategy from the manifest's pipeline."""

    @staticmethod
    def _s3_manifest(job_id: str, pipeline: str):
        """Build a SUCCEEDED, S3-mode manifest for ``job_id`` running ``pipeline``."""
        from dh_cli.batch.manifest import JobManifest, JobStatus

        return JobManifest(
            job_id=job_id,
            user="tester",
            pipeline=pipeline,
            storage_mode="s3",
            status=JobStatus.SUCCEEDED,
            s3_prefix=f"s3://bucket/jobs/{job_id}/",
        )

    def test_boltz_s3_uses_tar_path(self, temp_dir):
        """`dh batch finalize` for Boltz in S3 mode calls the tar-aware helper, not download_directory."""
        from click.testing import CliRunner

        from dh_cli.batch.commands.finalize import finalize as finalize_cmd

        manifest = self._s3_manifest("test-boltz", "boltz")
        target = "dh_cli.batch.commands.finalize"

        with (
            patch(f"{target}.load_manifest", return_value=manifest),
            patch(f"{target}._download_boltz_s3_output") as mock_tar_download,
            patch(f"{target}._check_completion", return_value=[]),
            patch(f"{target}._finalize_boltz") as mock_fb,
            patch(f"{target}.save_manifest_s3"),
        ):
            result = CliRunner().invoke(
                finalize_cmd,
                ["test-boltz", "--output", str(temp_dir / "final"), "--keep-intermediates"],
                input="y\n",
            )

        assert result.exit_code == 0, result.output
        assert mock_tar_download.called, "tar-aware download helper must be called for Boltz S3"
        assert mock_fb.called

    def test_non_boltz_s3_uses_download_directory(self, temp_dir):
        """Non-Boltz pipelines (e.g. embed-t5) in S3 mode keep using download_directory."""
        from click.testing import CliRunner

        from dh_cli.batch.commands.finalize import finalize as finalize_cmd

        manifest = self._s3_manifest("test-embed", "embed-t5")
        target = "dh_cli.batch.commands.finalize"

        with (
            patch(f"{target}.load_manifest", return_value=manifest),
            patch(f"{target}._download_boltz_s3_output") as mock_tar_download,
            patch("dh_cli.batch.s3_transport.download_directory") as mock_dd,
            patch(f"{target}._check_completion", return_value=[]),
            patch(f"{target}._finalize_embeddings") as mock_fe,
            patch(f"{target}.save_manifest_s3"),
        ):
            result = CliRunner().invoke(
                finalize_cmd,
                ["test-embed", "--output", str(temp_dir / "out.h5"), "--keep-intermediates"],
            )

        assert result.exit_code == 0, result.output
        assert mock_dd.called, "non-Boltz pipelines should keep using download_directory"
        assert not mock_tar_download.called, "tar-aware helper must not fire for embed-t5"
        assert mock_fe.called
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes