dayhoff_tools-1.0.0-py3-none-any.whl
- dayhoff_tools/__init__.py +0 -0
- dayhoff_tools/chemistry/standardizer.py +297 -0
- dayhoff_tools/chemistry/utils.py +63 -0
- dayhoff_tools/cli/__init__.py +0 -0
- dayhoff_tools/cli/main.py +90 -0
- dayhoff_tools/cli/swarm_commands.py +156 -0
- dayhoff_tools/cli/utility_commands.py +244 -0
- dayhoff_tools/deployment/base.py +434 -0
- dayhoff_tools/deployment/deploy_aws.py +458 -0
- dayhoff_tools/deployment/deploy_gcp.py +176 -0
- dayhoff_tools/deployment/deploy_utils.py +781 -0
- dayhoff_tools/deployment/job_runner.py +153 -0
- dayhoff_tools/deployment/processors.py +125 -0
- dayhoff_tools/deployment/swarm.py +591 -0
- dayhoff_tools/embedders.py +893 -0
- dayhoff_tools/fasta.py +1082 -0
- dayhoff_tools/file_ops.py +261 -0
- dayhoff_tools/gcp.py +85 -0
- dayhoff_tools/h5.py +542 -0
- dayhoff_tools/kegg.py +37 -0
- dayhoff_tools/logs.py +27 -0
- dayhoff_tools/mmseqs.py +164 -0
- dayhoff_tools/sqlite.py +516 -0
- dayhoff_tools/structure.py +751 -0
- dayhoff_tools/uniprot.py +434 -0
- dayhoff_tools/warehouse.py +418 -0
- dayhoff_tools-1.0.0.dist-info/METADATA +122 -0
- dayhoff_tools-1.0.0.dist-info/RECORD +30 -0
- dayhoff_tools-1.0.0.dist-info/WHEEL +4 -0
- dayhoff_tools-1.0.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,458 @@
+"""AWS deployment functionality for running jobs on AWS Batch."""
+
+import base64
+import datetime
+import os
+import re
+import subprocess
+from pathlib import Path
+from typing import Any, Optional
+
+import boto3
+import yaml
+from botocore.exceptions import (
+    ClientError,
+    NoCredentialsError,
+    ProfileNotFound,
+    SSOTokenLoadError,
+)
+from dayhoff_tools.deployment.deploy_utils import docker_login, get_container_env_vars
+
+
+def get_boto_session(config: dict[str, Any]) -> boto3.Session:
+    """Creates a Boto3 session using the profile specified in the config.
+
+    Args:
+        config: Dictionary containing the configuration loaded from YAML
+
+    Returns:
+        A Boto3 session object
+
+    Raises:
+        RuntimeError: If the profile is not specified in the config or if credentials cannot be loaded
+    """
+    aws_config = config.get("aws", {})
+    profile_name = aws_config.get("aws_profile")
+    region = aws_config.get("region", "us-east-1")
+
+    if not profile_name:
+        print(
+            "Warning: aws.aws_profile not specified in config. Using default AWS credential chain in ~/.aws/config."
+        )
+        return boto3.Session(region_name=region)
+
+    try:
+        print(f"Using AWS profile: {profile_name}")
+        session = boto3.Session(profile_name=profile_name, region_name=region)
+        sts = session.client("sts")
+        sts.get_caller_identity()
+        return session
+    except ProfileNotFound:
+        raise RuntimeError(
+            f"AWS profile '{profile_name}' not found in `~/.aws/config`."
+        )
+    except (NoCredentialsError, ClientError, SSOTokenLoadError) as e:
+        raise RuntimeError(
+            f"Could not load credentials for AWS profile '{profile_name}'. "
+            f"Ensure you are logged in via SSO ('aws sso login --profile {profile_name}') "
+            f"or have valid credentials configured. Original error: {e}"
+        ) from e
+
+
+def _extract_run_name_from_config(
+    config: dict[str, Any], _test_file_content: Optional[str] = None
+) -> Optional[str]:
+    """Extract run name from the config file referenced in job_command.
+
+    Args:
+        config: Dictionary containing the configuration loaded from YAML
+        _test_file_content: Optional parameter for testing to override file content
+
+    Returns:
+        Run name if found, None otherwise
+    """
+    # Check if features and job_command exist in config
+    if "features" not in config:
+        return None
+
+    # Find job_command in features
+    job_command = None
+    for feature in config["features"]:
+        if isinstance(feature, dict) and "job_command" in feature:
+            job_command = feature["job_command"]
+            break
+        elif isinstance(feature, str) and feature.startswith("job_command:"):
+            job_command = feature.split(":", 1)[1].strip()
+            break
+
+    if not job_command:
+        return None
+
+    # Extract config file path using regex
+    config_match = re.search(r'--config=([^\s"]+)', job_command)
+    if not config_match:
+        return None
+
+    config_path = config_match.group(1)
+
+    # For testing, we can bypass file operations
+    if _test_file_content is not None:
+        run_config = yaml.safe_load(_test_file_content)
+    else:
+        # Resolve path relative to repo root
+        # Assuming we're in the repo root when this function is called
+        full_config_path = Path(config_path)
+
+        # Check if file exists
+        if not full_config_path.exists():
+            print(f"Warning: Config file {full_config_path} not found")
+            return None
+
+        try:
+            # Load the config file
+            with open(full_config_path, "r") as f:
+                run_config = yaml.safe_load(f)
+        except Exception as e:
+            print(f"Warning: Failed to extract run_name from {full_config_path}: {e}")
+            return None
+
+    # Extract run_name from wandb section
+    if (
+        run_config
+        and "init" in run_config
+        and "wandb" in run_config["init"]
+        and "run_name" in run_config["init"]["wandb"]
+    ):
+        return run_config["init"]["wandb"]["run_name"]
+
+    return None
+
+
+def _extract_job_name_from_uri(image_uri: str, config: dict[str, Any]) -> str:
+    """Extract job name from image URI and config.
+
+    Args:
+        image_uri: Full URI of the container image
+        config: Dictionary containing the configuration loaded from YAML
+
+    Returns:
+        Job name in format: username__jobname__uniquehex
+
+    The job name components are:
+    - username: The LOCAL_USER environment variable or "unknown_user"
+    - jobname: Either from aws.job_name config, wandb run_name, or omitted
+    - uniquehex: A hexadecimal string (typically 7 characters)
+      representing seconds since January 1, 2020
+
+    If aws.job_name is specified in the config:
+    - If aws.job_name is "use_wandb_run_name", the run_name from the wandb config will be used
+    - Otherwise, the specified job_name will be used
+
+    If no job_name is specified or if "use_wandb_run_name" is specified but no run_name is found,
+    the format will be username__uniquehex (without the middle component).
+    """
+    # Get username from environment
+    username = os.getenv("LOCAL_USER", "unknown_user")
+
+    # Generate a unique hex string based on seconds since 2020-01-01
+    epoch_2020 = datetime.datetime(2020, 1, 1).timestamp()
+    current_time = datetime.datetime.now().timestamp()
+    seconds_since_2020 = int(current_time - epoch_2020)
+    unique_hex = format(
+        seconds_since_2020, "x"
+    )  # Simple hex representation, typically 7 chars
+
+    # Check if job_name is specified in AWS config
+    if "aws" in config and "job_name" in config["aws"]:
+        job_name = config["aws"]["job_name"]
+
+        # Special handling for "use_wandb_run_name"
+        if job_name == "use_wandb_run_name":
+            # Get run name from config
+            run_name = _extract_run_name_from_config(config)
+            if run_name:
+                # Return username__runname__uniquehex
+                return f"{username}__{run_name}__{unique_hex}"
+        else:
+            # Use the specified job_name but format it as username__jobname__uniquehex
+            return f"{username}__{job_name}__{unique_hex}"
+
+    # Default behavior if job_name is not specified or "use_wandb_run_name" with no run_name
+    # Get run name from config
+    run_name = _extract_run_name_from_config(config)
+
+    # Build job name based on available components
+    job_name_parts = []
+
+    # Add username (use default if None)
+    job_name_parts.append(username or "unknown_user")
+
+    # Add run name if available
+    if run_name:
+        job_name_parts.append(run_name)
+
+    # Add the unique hex string
+    job_name_parts.append(unique_hex)
+
+    # Join parts with double underscores
+    return "__".join(job_name_parts)
+
+
+def push_image_to_ecr(image_uri: str, config: dict[str, Any]) -> None:
+    """Push a Docker image to Amazon ECR.
+
+    Args:
+        image_uri: Full URI of the image to push
+        config: Dictionary containing the configuration loaded from YAML
+
+    Raises:
+        subprocess.CalledProcessError: If ECR login or push fails
+        ClientError: If ECR authentication fails
+    """
+    print("\nPushing image to ECR")
+
+    session = get_boto_session(config)
+    aws_config = config["aws"]
+    registry = aws_config["registry_uri"]
+
+    # Get ECR login token using the specific session client
+    ecr_client = session.client("ecr")
+    try:
+        token = ecr_client.get_authorization_token()
+        username, password = (
+            base64.b64decode(token["authorizationData"][0]["authorizationToken"])
+            .decode()
+            .split(":")
+        )
+
+        # Login to Docker registry
+        docker_login(registry, username, password)
+
+        # Push the image
+        print(f"\nPushing image: {image_uri}")
+        subprocess.run(["docker", "push", image_uri], check=True)
+
+        print(f"\nSuccessfully pushed image to ECR: {image_uri}")
+
+    except ClientError as e:
+        print(f"AWS ECR error using profile '{session.profile_name}': {str(e)}")
+        raise
+    except subprocess.CalledProcessError as e:
+        print(f"Docker command failed: {e.stderr.decode() if e.stderr else str(e)}")
+        raise
+
+
+def create_or_update_job_definition(
+    image_uri: str,
+    config: dict[str, Any],
+) -> str:
+    """Create or update a job definition for the container.
+
+    Args:
+        image_uri: Full URI of the container image
+        config: Dictionary containing the configuration loaded from YAML
+
+    Returns:
+        Name of the created/updated job definition
+
+    Raises:
+        ValueError: If job_role_arn is not specified in AWS configuration
+    """
+    session = get_boto_session(config)
+    aws_config = config["aws"]
+
+    # Verify that job_role_arn is present
+    if "job_role_arn" not in aws_config:
+        raise ValueError(
+            "job_role_arn must be specified in AWS configuration. "
+            "This role is required for your container to access AWS resources after the "
+            "initial 15-minute temporary credential window."
+        )
+
+    batch = session.client("batch")
+    base_name = config["docker"]["base_name"]
+    job_def_name = f"job-def-{base_name}"
+
+    # Get compute specs from config
+    compute_specs = aws_config["batch_job"]["compute_specs"]
+    gpu_requirements = (
+        [{"type": "GPU", "value": str(compute_specs["gpus"])}]
+        if compute_specs.get("gpus", 0) > 0
+        else []
+    )
+
+    entrypoint_command = config["docker"].get("container_entrypoint")
+    if entrypoint_command is None:
+        raise ValueError("docker.container_entrypoint is required in configuration")
+
+    # Create linux parameters with devices
+    linux_params: dict[str, Any] = {
+        "devices": [
+            {
+                "hostPath": "/dev/nvidia0",
+                "containerPath": "/dev/nvidia0",
+                "permissions": ["READ", "WRITE"],
+            },
+        ],
+    }
+
+    # Add shared memory configuration if specified in docker config
+    if "shared_memory" in config.get("docker", {}):
+        shared_mem = config["docker"]["shared_memory"]
+        # Convert to MiB (e.g., "16g" -> 16384 MiB)
+        if isinstance(shared_mem, str):
+            if shared_mem.endswith("g"):
+                # Convert GB to MiB (1G = 1024 MiB)
+                shared_memory_mib = int(float(shared_mem[:-1]) * 1024)
+            elif shared_mem.endswith("m"):
+                # Convert MB to MiB (approximate conversion)
+                shared_memory_mib = int(float(shared_mem[:-1]))
+            else:
+                # Assume the value is already in MiB
+                shared_memory_mib = int(float(shared_mem))
+        else:
+            # Assume the value is already in MiB if not a string
+            shared_memory_mib = int(shared_mem)
+
+        # Add shared memory size to linux parameters
+        linux_params["sharedMemorySize"] = shared_memory_mib
+        print(f"Setting shared memory size to {shared_memory_mib} MiB")
+
+    # Check if job definition already exists using the session client
+    try:
+        existing = batch.describe_job_definitions(
+            jobDefinitionName=job_def_name, status="ACTIVE"
+        )["jobDefinitions"]
+
+        if existing:
+            print(f"\nUpdating existing job definition: {job_def_name}")
+        else:
+            print(f"\nCreating new job definition: {job_def_name}")
+
+    except batch.exceptions.ClientError as e:
+        # Handle case where the error is specifically 'JobDefinitionNotFoundException'
+        # Boto3 typically includes error codes in the response
+        if (
+            e.response.get("Error", {}).get("Code") == "ClientError"
+        ):  # Simple check, might need refinement
+            print(f"\nCreating new job definition: {job_def_name}")
+        else:
+            # Re-raise unexpected client errors
+            raise
+
+    # Prepare job definition properties
+    job_definition_args = {
+        "jobDefinitionName": job_def_name,
+        "type": "container",
+        "containerProperties": {
+            "image": image_uri,
+            "vcpus": compute_specs["vcpus"],
+            "memory": compute_specs["memory"],
+            "resourceRequirements": gpu_requirements,
+            "executionRoleArn": aws_config["execution_role_arn"],
+            "jobRoleArn": aws_config["job_role_arn"],
+            "privileged": compute_specs.get("gpus", 0) > 0,
+            "command": entrypoint_command,
+            **({"linuxParameters": linux_params} if linux_params else {}),
+        },
+        "platformCapabilities": ["EC2"],
+        "timeout": {"attemptDurationSeconds": aws_config.get("timeout_seconds", 86400)},
+    }
+
+    # Register new revision using the session client
+    response = batch.register_job_definition(**job_definition_args)
+
+    return response["jobDefinitionName"]
+
+
+def submit_aws_batch_job(
+    image_uri: str,
+    config: dict[str, Any],
+) -> tuple[str, str]:
+    """Submit a job to AWS Batch.
+
+    Args:
+        image_uri: Full URI of the container image
+        config: Dictionary containing the configuration loaded from YAML
+
+    Returns:
+        Tuple containing (job_id, job_name) of the submitted job
+
+    Raises:
+        ValueError: If job_role_arn is not present in AWS configuration
+    """
+    session = get_boto_session(config)
+    aws_config = config["aws"]
+    region = session.region_name or aws_config["region"]
+    batch = session.client("batch")
+
+    # Generate job name (already includes unique hex string)
+    job_name = _extract_job_name_from_uri(image_uri, config)
+    print(f"\nGenerated job name: {job_name}")
+
+    # Log the job submission details
+    print("\nSubmitting job with configuration:")
+    print(f"Job Name: {job_name}")
+    print(f"Queue: {aws_config['job_queue']}")
+    print("Container Configuration:")
+    print(f"- Image: {image_uri}")
+    print(f"- vCPUs: {aws_config['batch_job']['compute_specs']['vcpus']}")
+    print(f"- Memory: {aws_config['batch_job']['compute_specs']['memory']} MiB")
+    print(f"- GPUs: {aws_config['batch_job']['compute_specs'].get('gpus', 0)}")
+    print(f"- Timeout: {aws_config.get('timeout_seconds', 86400)} seconds")
+    print(f"- Job Role: {aws_config['job_role_arn']}")
+
+    # Get all environment variables, including special ones like WANDB_API_KEY and GCP credentials
+    env_vars = get_container_env_vars(config)
+
+    print("Environment Variables:", list(env_vars.keys()))
+
+    # Create/Update Job Definition using the config (now implicitly uses the correct session)
+    job_definition = create_or_update_job_definition(image_uri, config)
+    print(f"\nUsing job definition: {job_definition}")
+
+    # Prepare job submission arguments
+    job_submit_args = {
+        "jobName": job_name,
+        "jobQueue": aws_config["job_queue"],
+        "jobDefinition": job_definition,
+        "containerOverrides": {
+            "environment": [
+                {"name": key, "value": str(value)} for key, value in env_vars.items()
+            ],
+        },
+    }
+
+    # Add array job configuration if specified
+    if "array_size" in aws_config:
+        array_size = aws_config["array_size"]
+        if array_size > 1:
+            print(f"\nConfiguring as array job with {array_size} instances")
+            job_submit_args["arrayProperties"] = {"size": array_size}
+
+            # Configure retry strategy for array jobs
+            retry_attempts = aws_config.get("retry_attempts", 2)
+            print(f"Setting retry attempts to {retry_attempts}")
+            job_submit_args["retryStrategy"] = {"attempts": retry_attempts}
+
+    # Submit the job using the session client
+    response = batch.submit_job(**job_submit_args)
+
+    job_id = response["jobId"]
+    print(f"\nJob submitted with ID: {job_id}")
+
+    # Print instructions for monitoring
+    print("\nTo monitor your job:")
+    print(
+        f"  1. AWS Console: https://{region}.console.aws.amazon.com/batch/home?region={region}#jobs/detail/{job_id}"
+    )
+    print(f"  2. CloudWatch Logs: Check logs for job {job_name} (ID: {job_id})")
+
+    # For array jobs, provide additional monitoring info
+    if "array_size" in aws_config and aws_config["array_size"] > 1:
+        print(f"  3. This is an array job with {aws_config['array_size']} child jobs")
+        print(
+            f"     Child jobs: https://{region}.console.aws.amazon.com/batch/home?region={region}#jobs/array-jobs/{job_id}"
+        )
+
+    return job_id, job_name
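
Below is a minimal usage sketch for this module (not part of the released wheel). It assumes an image has already been built and tagged locally; every account ID, ARN, queue name, and entrypoint is a placeholder, and the config keys mirror the dictionary accesses in the code above.

from dayhoff_tools.deployment.deploy_aws import push_image_to_ecr, submit_aws_batch_job

# Hypothetical config; all IDs, ARNs, and names below are placeholders.
config = {
    "aws": {
        "aws_profile": "my-sso-profile",
        "region": "us-east-1",
        "registry_uri": "123456789012.dkr.ecr.us-east-1.amazonaws.com",
        "job_queue": "my-gpu-queue",
        "job_role_arn": "arn:aws:iam::123456789012:role/my-batch-job-role",
        "execution_role_arn": "arn:aws:iam::123456789012:role/my-batch-execution-role",
        "batch_job": {"compute_specs": {"vcpus": 4, "memory": 16384, "gpus": 1}},
    },
    "docker": {
        "base_name": "my-image",
        "container_entrypoint": ["python", "run_job.py"],
        "shared_memory": "16g",  # becomes linuxParameters.sharedMemorySize = 16384 MiB
    },
}

image_uri = f"{config['aws']['registry_uri']}/my-image:latest"
push_image_to_ecr(image_uri, config)  # ECR auth token, docker login, docker push
job_id, job_name = submit_aws_batch_job(image_uri, config)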
@@ -0,0 +1,176 @@
+"""GCP-specific deployment functionality."""
+
+import json
+import os
+import subprocess
+import tempfile
+
+from dayhoff_tools.deployment.deploy_utils import get_container_env_vars
+
+
+def check_job_exists(job_name: str, region: str) -> bool:
+    """Check if a job with the given name already exists in GCP Batch.
+
+    Args:
+        job_name: Name of the job to check
+        region: GCP region to check in
+
+    Returns:
+        bool: True if the job exists, False otherwise
+
+    Note:
+        This uses gcloud batch jobs describe, which will return a non-zero
+        exit code if the job doesn't exist.
+    """
+    try:
+        subprocess.run(
+            [
+                "gcloud",
+                "batch",
+                "jobs",
+                "describe",
+                job_name,
+                "--location",
+                region,
+            ],
+            check=True,
+            capture_output=True,  # Suppress output
+        )
+        return True
+    except subprocess.CalledProcessError:
+        return False
+
+
+def create_batch_job_config(config: dict, image_uri: str) -> dict:
+    """Create a GCP Batch job configuration from YAML config.
+
+    Args:
+        config: Dictionary containing the configuration loaded from YAML
+        image_uri: URI of the Docker image to use
+
+    Returns:
+        Dictionary containing GCP Batch job configuration
+    """
+    gcp_config = config["gcp"]
+
+    # Start with the allocation and logs policies
+    batch_config = {
+        "allocationPolicy": gcp_config["allocation_policy"],
+        "logsPolicy": gcp_config["logs_policy"],
+    }
+
+    entrypoint_command = config["docker"].get("container_entrypoint")
+    if entrypoint_command is None:
+        raise ValueError("docker.container_entrypoint is required in configuration")
+
+    if not isinstance(entrypoint_command, list) or not all(
+        isinstance(x, str) for x in entrypoint_command
+    ):
+        raise ValueError("docker.container_entrypoint must be a list of strings")
+
+    # Build the container configuration with bash entrypoint
+    container_config = {
+        "imageUri": image_uri,
+        "entrypoint": "/bin/bash",
+        "commands": ["-c", " ".join(entrypoint_command)],
+    }
+
+    # Add shared memory option if specified
+    if "shared_memory" in config.get("docker", {}):
+        container_config["options"] = f"--shm-size={config['docker']['shared_memory']}"
+
+    # Build the task group configuration
+    task_group = {
+        "taskCount": gcp_config["batch_job"]["taskCount"],
+        "parallelism": gcp_config["batch_job"]["parallelism"],
+        "taskSpec": {
+            "computeResource": gcp_config["batch_job"]["computeResource"],
+            "runnables": [{"container": container_config}],
+        },
+    }
+
+    # Get all environment variables, including special ones like WANDB_API_KEY and GCP credentials
+    env_vars = get_container_env_vars(config)
+
+    # Add environment variables if any exist
+    if env_vars:
+        task_group["taskSpec"]["runnables"][0]["environment"] = {"variables": env_vars}
+
+    # Add machine type and optional accelerators from instance config
+    instance_config = gcp_config["batch_job"]["instance"]
+    if "machineType" in instance_config:
+        # Add machine type to the allocation policy
+        if "policy" not in batch_config["allocationPolicy"]["instances"]:
+            batch_config["allocationPolicy"]["instances"]["policy"] = {}
+        batch_config["allocationPolicy"]["instances"]["policy"]["machineType"] = (
+            instance_config["machineType"]
+        )
+
+    # Add accelerators if present (optional)
+    if "accelerators" in instance_config:
+        batch_config["allocationPolicy"]["instances"]["policy"]["accelerators"] = (
+            instance_config["accelerators"]
+        )
+
+    # Add the task group to the configuration
+    batch_config["taskGroups"] = [task_group]
+
+    # Debug logging to verify configuration
+    print("\nGCP Batch Configuration:")
+    print("------------------------")
+    try:
+        policy = batch_config["allocationPolicy"]["instances"]["policy"]
+        print("Machine Type:", policy.get("machineType", "Not specified"))
+        print("Accelerators:", policy.get("accelerators", "Not specified"))
+        print("Environment Variables:", list(env_vars.keys()))
+    except KeyError as e:
+        print(f"Warning: Could not find {e} in configuration")
+
+    return batch_config
+
+
+def submit_gcp_batch_job(config: dict, image_uri: str) -> None:
+    """Submit a job to GCP Batch.
+
+    Args:
+        config: Dictionary containing the configuration loaded from YAML
+        image_uri: URI of the Docker image to use
+
+    Raises:
+        ValueError: If a job with the same name already exists
+    """
+    job_name = config["gcp"]["job_name"]
+    region = config["gcp"]["region"]
+
+    # Check if job already exists
+    if check_job_exists(job_name, region):
+        raise ValueError(
+            f"Job '{job_name}' already exists in region {region}. "
+            "Please choose a different job name or delete the existing job first."
+        )
+
+    # Create GCP Batch job configuration
+    batch_config = create_batch_job_config(config, image_uri)
+
+    # Write the configuration to a temporary file
+    with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
+        json.dump(batch_config, temp_file, indent=2)
+        temp_config_path = temp_file.name
+
+    try:
+        # Submit the job using gcloud
+        command = [
+            "gcloud",
+            "batch",
+            "jobs",
+            "submit",
+            job_name,
+            "--location",
+            region,
+            "--config",
+            temp_config_path,
+        ]
+        subprocess.run(command, check=True)
+    finally:
+        # Clean up the temporary file
+        os.unlink(temp_config_path)
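
And a comparable sketch for the GCP path (again not part of the wheel); the project, region, machine type, and policy values are placeholders inferred from the keys the code reads. Note that gcp.allocation_policy must already contain an "instances" entry, since create_batch_job_config writes the machine type into allocationPolicy.instances.policy.

from dayhoff_tools.deployment.deploy_gcp import submit_gcp_batch_job

# Hypothetical config; values are placeholders shaped by create_batch_job_config.
config = {
    "gcp": {
        "job_name": "my-batch-job",
        "region": "us-central1",
        "allocation_policy": {"instances": {}},  # "policy" is filled in from batch_job.instance
        "logs_policy": {"destination": "CLOUD_LOGGING"},
        "batch_job": {
            "taskCount": 10,
            "parallelism": 5,
            "computeResource": {"cpuMilli": 4000, "memoryMib": 16384},
            "instance": {"machineType": "n1-standard-4"},
        },
    },
    "docker": {"container_entrypoint": ["python", "run_job.py"]},
}

# Renders the config to a temp JSON file and shells out to `gcloud batch jobs submit`.
submit_gcp_batch_job(config, image_uri="us-central1-docker.pkg.dev/my-project/my-repo/my-image:latest")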