dayhoff-tools 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. dayhoff_tools/batch/__init__.py +8 -0
  2. dayhoff_tools/batch/workers/__init__.py +12 -0
  3. dayhoff_tools/batch/workers/base.py +150 -0
  4. dayhoff_tools/batch/workers/boltz.py +407 -0
  5. dayhoff_tools/batch/workers/embed_t5.py +92 -0
  6. dayhoff_tools/cli/batch/__init__.py +85 -0
  7. dayhoff_tools/cli/batch/aws_batch.py +401 -0
  8. dayhoff_tools/cli/batch/commands/__init__.py +25 -0
  9. dayhoff_tools/cli/batch/commands/boltz.py +362 -0
  10. dayhoff_tools/cli/batch/commands/cancel.py +82 -0
  11. dayhoff_tools/cli/batch/commands/embed_t5.py +303 -0
  12. dayhoff_tools/cli/batch/commands/finalize.py +206 -0
  13. dayhoff_tools/cli/batch/commands/list_jobs.py +78 -0
  14. dayhoff_tools/cli/batch/commands/local.py +95 -0
  15. dayhoff_tools/cli/batch/commands/logs.py +142 -0
  16. dayhoff_tools/cli/batch/commands/retry.py +142 -0
  17. dayhoff_tools/cli/batch/commands/status.py +214 -0
  18. dayhoff_tools/cli/batch/commands/submit.py +215 -0
  19. dayhoff_tools/cli/batch/job_id.py +151 -0
  20. dayhoff_tools/cli/batch/manifest.py +293 -0
  21. dayhoff_tools/cli/engines_studios/engine-studio-cli.md +26 -21
  22. dayhoff_tools/cli/engines_studios/engine_commands.py +16 -89
  23. dayhoff_tools/cli/engines_studios/ssh_config.py +96 -0
  24. dayhoff_tools/cli/engines_studios/studio_commands.py +13 -2
  25. dayhoff_tools/cli/main.py +51 -10
  26. {dayhoff_tools-1.14.0.dist-info → dayhoff_tools-1.14.2.dist-info}/METADATA +6 -1
  27. {dayhoff_tools-1.14.0.dist-info → dayhoff_tools-1.14.2.dist-info}/RECORD +29 -8
  28. {dayhoff_tools-1.14.0.dist-info → dayhoff_tools-1.14.2.dist-info}/WHEEL +0 -0
  29. {dayhoff_tools-1.14.0.dist-info → dayhoff_tools-1.14.2.dist-info}/entry_points.txt +0 -0
dayhoff_tools/cli/batch/__init__.py
@@ -0,0 +1,85 @@
+ """Batch job management CLI for AWS Batch.
+
+ This module provides a Click-based CLI for submitting and managing batch jobs
+ on AWS Batch, with support for:
+ - High-level pipelines (embed-t5, boltz)
+ - Generic job submission
+ - Job lifecycle management (status, logs, cancel, retry, finalize)
+ - Local debugging and shell access
+ """
+
+ import click
+
+ from .commands.boltz import boltz
+ from .commands.cancel import cancel
+ from .commands.embed_t5 import embed_t5
+ from .commands.finalize import finalize
+ from .commands.list_jobs import list_jobs
+ from .commands.local import local
+ from .commands.logs import logs
+ from .commands.retry import retry
+ from .commands.status import status
+ from .commands.submit import submit
+
+
+ @click.group()
+ def batch_cli():
+     """Manage batch jobs on AWS Batch.
+
+     \b
+     Job Management:
+       submit     Submit a custom job from config file
+       status     Show job status
+       cancel     Cancel a running job
+       logs       View job logs
+       retry      Retry failed chunks
+       finalize   Combine results and clean up
+       local      Run a chunk locally for debugging
+       list       List recent jobs
+
+     \b
+     Embedding Pipelines:
+       embed-t5   Generate T5 protein embeddings
+
+     \b
+     Structure Prediction:
+       boltz      Predict protein structures with Boltz
+
+     \b
+     Examples:
+       # Submit an embedding job
+       dh batch embed-t5 /primordial/proteins.fasta --workers 50
+
+       # Submit a structure prediction job
+       dh batch boltz /primordial/complexes/ --workers 100
+
+       # Check job status
+       dh batch status dma-embed-20260109-a3f2
+
+       # View logs for a failed chunk
+       dh batch logs dma-embed-20260109-a3f2 --index 27
+
+       # Retry failed chunks
+       dh batch retry dma-embed-20260109-a3f2
+
+       # Finalize and combine results
+       dh batch finalize dma-embed-20260109-a3f2 --output /primordial/embeddings.h5
+     """
+     pass
+
+
+ # Register job management commands
+ batch_cli.add_command(submit)
+ batch_cli.add_command(status)
+ batch_cli.add_command(cancel)
+ batch_cli.add_command(logs)
+ batch_cli.add_command(retry)
+ batch_cli.add_command(finalize)
+ batch_cli.add_command(local)
+ batch_cli.add_command(list_jobs, name="list")
+
+ # Register pipeline commands
+ batch_cli.add_command(embed_t5, name="embed-t5")
+ batch_cli.add_command(boltz)
+
+ __all__ = ["batch_cli"]
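Because every subcommand is attached to the single `batch_cli` Click group, the group can be exercised in isolation, for example with Click's `CliRunner`. A minimal sketch (illustrative only, not part of the packaged code):

```python
from click.testing import CliRunner

from dayhoff_tools.cli.batch import batch_cli

runner = CliRunner()

# The group help lists the registered subcommands, including the
# renamed entries "list" and "embed-t5".
result = runner.invoke(batch_cli, ["--help"])
print(result.output)

# Individual commands are invoked the same way, e.g.:
# runner.invoke(batch_cli, ["status", "dma-embed-20260109-a3f2"])
```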
dayhoff_tools/cli/batch/aws_batch.py
@@ -0,0 +1,401 @@
+ """AWS Batch client wrapper for job submission and management."""
+
+ import logging
+ import time
+ from dataclasses import dataclass
+ from typing import Any
+
+ import boto3
+ from botocore.exceptions import ClientError
+
+ logger = logging.getLogger(__name__)
+
+
+ class BatchError(Exception):
+     """Error interacting with AWS Batch."""
+
+     pass
+
+
+ @dataclass
+ class ArrayJobStatus:
+     """Aggregated status for an array job."""
+
+     total: int
+     pending: int
+     runnable: int
+     starting: int
+     running: int
+     succeeded: int
+     failed: int
+
+     @property
+     def completed(self) -> int:
+         return self.succeeded + self.failed
+
+     @property
+     def in_progress(self) -> int:
+         return self.pending + self.runnable + self.starting + self.running
+
+     @property
+     def is_complete(self) -> bool:
+         return self.completed == self.total
+
+     @property
+     def success_rate(self) -> float:
+         if self.completed == 0:
+             return 0.0
+         return self.succeeded / self.completed
+
+
+ class BatchClient:
+     """Client for interacting with AWS Batch."""
+
+     def __init__(self, region: str = "us-east-1"):
+         """Initialize the Batch client.
+
+         Args:
+             region: AWS region
+         """
+         self.batch = boto3.client("batch", region_name=region)
+         self.logs = boto3.client("logs", region_name=region)
+         self.region = region
+
+     def submit_job(
+         self,
+         job_name: str,
+         job_definition: str,
+         job_queue: str,
+         array_size: int | None = None,
+         environment: dict[str, str] | None = None,
+         parameters: dict[str, str] | None = None,
+         timeout_seconds: int | None = None,
+         retry_attempts: int = 3,
+         depends_on: list[dict] | None = None,
+     ) -> str:
+         """Submit a job to AWS Batch.
+
+         Args:
+             job_name: Name for the job
+             job_definition: Job definition name or ARN
+             job_queue: Queue to submit to
+             array_size: Size of array job (None for single job)
+             environment: Environment variables
+             parameters: Job parameters
+             timeout_seconds: Job timeout in seconds
+             retry_attempts: Number of retry attempts
+             depends_on: Job dependencies
+
+         Returns:
+             AWS Batch job ID
+
+         Raises:
+             BatchError: If submission fails
+         """
+         try:
+             submit_args: dict[str, Any] = {
+                 "jobName": job_name,
+                 "jobDefinition": job_definition,
+                 "jobQueue": job_queue,
+                 "retryStrategy": {"attempts": retry_attempts},
+             }
+
+             if array_size and array_size > 1:
+                 submit_args["arrayProperties"] = {"size": array_size}
+
+             if environment:
+                 submit_args["containerOverrides"] = {
+                     "environment": [
+                         {"name": k, "value": v} for k, v in environment.items()
+                     ]
+                 }
+
+             if parameters:
+                 submit_args["parameters"] = parameters
+
+             if timeout_seconds:
+                 submit_args["timeout"] = {"attemptDurationSeconds": timeout_seconds}
+
+             if depends_on:
+                 submit_args["dependsOn"] = depends_on
+
+             response = self.batch.submit_job(**submit_args)
+             job_id = response["jobId"]
+             logger.info(f"Submitted job {job_name} with ID {job_id}")
+             return job_id
+
+         except ClientError as e:
+             raise BatchError(f"Failed to submit job: {e}")
+
+     def submit_array_job_with_indices(
+         self,
+         job_name: str,
+         job_definition: str,
+         job_queue: str,
+         indices: list[int],
+         environment: dict[str, str] | None = None,
+         timeout_seconds: int | None = None,
+         retry_attempts: int = 3,
+     ) -> str:
+         """Submit an array job for specific indices only.
+
+         Used for retrying specific failed chunks.
+
+         Args:
+             job_name: Name for the job
+             job_definition: Job definition name or ARN
+             job_queue: Queue to submit to
+             indices: Specific array indices to run
+             environment: Environment variables
+             timeout_seconds: Job timeout in seconds
+             retry_attempts: Number of retry attempts
+
+         Returns:
+             AWS Batch job ID
+         """
+         # For small number of indices, submit individual jobs
+         # For larger numbers, we could use array job with index selection
+         # AWS Batch doesn't natively support sparse array indices, so we use a workaround
+
+         if len(indices) == 1:
+             # Single job, set index via environment
+             env = environment.copy() if environment else {}
+             env["AWS_BATCH_JOB_ARRAY_INDEX"] = str(indices[0])
+             return self.submit_job(
+                 job_name=job_name,
+                 job_definition=job_definition,
+                 job_queue=job_queue,
+                 array_size=None,
+                 environment=env,
+                 timeout_seconds=timeout_seconds,
+                 retry_attempts=retry_attempts,
+             )
+         else:
+             # For multiple indices, we pass them as a comma-separated list
+             # The worker will pick its index from this list based on array index
+             env = environment.copy() if environment else {}
+             env["BATCH_RETRY_INDICES"] = ",".join(str(i) for i in indices)
+             return self.submit_job(
+                 job_name=job_name,
+                 job_definition=job_definition,
+                 job_queue=job_queue,
+                 array_size=len(indices),
+                 environment=env,
+                 timeout_seconds=timeout_seconds,
+                 retry_attempts=retry_attempts,
+             )
+
+     def describe_job(self, job_id: str) -> dict:
+         """Get details for a specific job.
+
+         Args:
+             job_id: AWS Batch job ID
+
+         Returns:
+             Job details dictionary
+
+         Raises:
+             BatchError: If job not found
+         """
+         try:
+             response = self.batch.describe_jobs(jobs=[job_id])
+             if not response.get("jobs"):
+                 raise BatchError(f"Job not found: {job_id}")
+             return response["jobs"][0]
+         except ClientError as e:
+             raise BatchError(f"Failed to describe job: {e}")
+
+     def get_array_job_status(self, job_id: str) -> ArrayJobStatus:
+         """Get aggregated status for an array job.
+
+         Args:
+             job_id: AWS Batch job ID (parent array job)
+
+         Returns:
+             ArrayJobStatus with counts for each status
+         """
+         job = self.describe_job(job_id)
+
+         if "arrayProperties" not in job:
+             # Single job, not an array
+             status = job.get("status", "UNKNOWN")
+             return ArrayJobStatus(
+                 total=1,
+                 pending=1 if status == "PENDING" else 0,
+                 runnable=1 if status == "RUNNABLE" else 0,
+                 starting=1 if status == "STARTING" else 0,
+                 running=1 if status == "RUNNING" else 0,
+                 succeeded=1 if status == "SUCCEEDED" else 0,
+                 failed=1 if status == "FAILED" else 0,
+             )
+
+         # Get status summary from array properties
+         status_summary = job.get("arrayProperties", {}).get("statusSummary", {})
+
+         return ArrayJobStatus(
+             total=job.get("arrayProperties", {}).get("size", 0),
+             pending=status_summary.get("PENDING", 0),
+             runnable=status_summary.get("RUNNABLE", 0),
+             starting=status_summary.get("STARTING", 0),
+             running=status_summary.get("RUNNING", 0),
+             succeeded=status_summary.get("SUCCEEDED", 0),
+             failed=status_summary.get("FAILED", 0),
+         )
+
+     def get_failed_indices(self, job_id: str) -> list[int]:
+         """Get the array indices that failed for an array job.
+
+         Args:
+             job_id: AWS Batch job ID (parent array job)
+
+         Returns:
+             List of failed array indices
+         """
+         failed_indices = []
+
+         # List child jobs with FAILED status
+         try:
+             paginator = self.batch.get_paginator("list_jobs")
+             for page in paginator.paginate(
+                 arrayJobId=job_id, jobStatus="FAILED"
+             ):
+                 for job_summary in page.get("jobSummaryList", []):
+                     # Extract array index from job ID (format: jobId:index)
+                     child_id = job_summary.get("jobId", "")
+                     if ":" in child_id:
+                         index = int(child_id.split(":")[-1])
+                         failed_indices.append(index)
+         except ClientError as e:
+             logger.warning(f"Failed to list child jobs: {e}")
+
+         return sorted(failed_indices)
+
+     def cancel_job(self, job_id: str, reason: str = "Cancelled by user") -> None:
+         """Cancel a job.
+
+         Args:
+             job_id: AWS Batch job ID
+             reason: Cancellation reason
+
+         Raises:
+             BatchError: If cancellation fails
+         """
+         try:
+             self.batch.cancel_job(jobId=job_id, reason=reason)
+             logger.info(f"Cancelled job {job_id}")
+         except ClientError as e:
+             raise BatchError(f"Failed to cancel job: {e}")
+
+     def terminate_job(self, job_id: str, reason: str = "Terminated by user") -> None:
+         """Terminate a running job.
+
+         Args:
+             job_id: AWS Batch job ID
+             reason: Termination reason
+
+         Raises:
+             BatchError: If termination fails
+         """
+         try:
+             self.batch.terminate_job(jobId=job_id, reason=reason)
+             logger.info(f"Terminated job {job_id}")
+         except ClientError as e:
+             raise BatchError(f"Failed to terminate job: {e}")
+
+     def get_log_stream_name(self, job_id: str) -> str | None:
+         """Get the CloudWatch log stream name for a job.
+
+         Args:
+             job_id: AWS Batch job ID
+
+         Returns:
+             Log stream name, or None if not available
+         """
+         try:
+             job = self.describe_job(job_id)
+             container = job.get("container", {})
+             return container.get("logStreamName")
+         except BatchError:
+             return None
+
+     def get_logs(
+         self,
+         job_id: str,
+         log_group: str = "/aws/batch/job",
+         tail: int = 100,
+         start_time: int | None = None,
+         follow: bool = False,
+     ) -> list[str]:
+         """Get CloudWatch logs for a job.
+
+         Args:
+             job_id: AWS Batch job ID
+             log_group: CloudWatch log group name
+             tail: Number of lines to return (from end)
+             start_time: Start time in milliseconds since epoch
+             follow: If True, continue polling for new logs
+
+         Returns:
+             List of log messages
+         """
+         log_stream = self.get_log_stream_name(job_id)
+         if not log_stream:
+             return [f"No logs available for job {job_id}"]
+
+         messages = []
+
+         try:
+             kwargs: dict[str, Any] = {
+                 "logGroupName": log_group,
+                 "logStreamName": log_stream,
+                 "limit": tail,
+                 "startFromHead": False,
+             }
+
+             if start_time:
+                 kwargs["startTime"] = start_time
+
+             response = self.logs.get_log_events(**kwargs)
+
+             for event in response.get("events", []):
+                 timestamp = event.get("timestamp", 0)
+                 message = event.get("message", "")
+                 # Format timestamp
+                 dt = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp / 1000))
+                 messages.append(f"[{dt}] {message}")
+
+         except ClientError as e:
+             messages.append(f"Error fetching logs: {e}")
+
+         return messages
+
+     def wait_for_job(
+         self, job_id: str, poll_interval: int = 30, timeout: int = 86400
+     ) -> str:
+         """Wait for a job to complete.
+
+         Args:
+             job_id: AWS Batch job ID
+             poll_interval: Seconds between status checks
+             timeout: Maximum seconds to wait
+
+         Returns:
+             Final job status
+
+         Raises:
+             BatchError: If timeout exceeded
+         """
+         start_time = time.time()
+
+         while True:
+             if time.time() - start_time > timeout:
+                 raise BatchError(f"Timeout waiting for job {job_id}")
+
+             job = self.describe_job(job_id)
+             status = job.get("status")
+
+             if status in ("SUCCEEDED", "FAILED"):
+                 return status
+
+             logger.info(f"Job {job_id} status: {status}")
+             time.sleep(poll_interval)
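The methods above compose into the flow the CLI commands build on: submit an array job, wait for it, inspect the per-chunk outcome, and resubmit only the failed indices. A hedged sketch of that flow (the job definition and queue names here are placeholders, not values from the package):

```python
from dayhoff_tools.cli.batch.aws_batch import BatchClient

client = BatchClient(region="us-east-1")

# "embed-t5-job-def" and "gpu-queue" are placeholder names; the real
# definitions and queues are configured by the pipeline commands.
job_id = client.submit_job(
    job_name="dma-embed-20260109-a3f2",
    job_definition="embed-t5-job-def",
    job_queue="gpu-queue",
    array_size=50,
    environment={"JOB_ID": "dma-embed-20260109-a3f2"},
    timeout_seconds=6 * 3600,
)

client.wait_for_job(job_id, poll_interval=60)

summary = client.get_array_job_status(job_id)
print(f"{summary.succeeded}/{summary.total} chunks succeeded")

# Resubmit only the failed chunks, as `dh batch retry` would.
failed = client.get_failed_indices(job_id)
if failed:
    client.submit_array_job_with_indices(
        job_name="dma-embed-20260109-a3f2-retry",
        job_definition="embed-t5-job-def",
        job_queue="gpu-queue",
        indices=failed,
    )
```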
dayhoff_tools/cli/batch/commands/__init__.py
@@ -0,0 +1,25 @@
+ """Batch CLI commands."""
+
+ from .boltz import boltz
+ from .cancel import cancel
+ from .embed_t5 import embed_t5
+ from .finalize import finalize
+ from .list_jobs import list_jobs
+ from .local import local
+ from .logs import logs
+ from .retry import retry
+ from .status import status
+ from .submit import submit
+
+ __all__ = [
+     "boltz",
+     "cancel",
+     "embed_t5",
+     "finalize",
+     "list_jobs",
+     "local",
+     "logs",
+     "retry",
+     "status",
+     "submit",
+ ]
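The sparse-retry path in `aws_batch.py` depends on each retried task translating its dense array index back to the original chunk index via `BATCH_RETRY_INDICES` (single-index retries set `AWS_BATCH_JOB_ARRAY_INDEX` directly). The worker side of that contract lives in `dayhoff_tools/batch/workers/base.py`, which is not shown in this section; the sketch below is a hypothetical illustration of that translation, not the shipped implementation:

```python
import os


def resolve_chunk_index() -> int:
    """Map this task's AWS Batch array index to the chunk it should process.

    For a normal array job the array index *is* the chunk index. For a
    sparse retry, BATCH_RETRY_INDICES holds the original chunk indices and
    the dense array index selects a position in that list.
    (Hypothetical helper; the shipped worker may differ.)
    """
    array_index = int(os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX", "0"))
    retry_indices = os.environ.get("BATCH_RETRY_INDICES")
    if retry_indices:
        original = [int(i) for i in retry_indices.split(",")]
        return original[array_index]
    return array_index
```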