dayhoff-tools 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,458 @@
+ """AWS deployment functionality for running jobs on AWS Batch."""
+
+ import base64
+ import datetime
+ import os
+ import re
+ import subprocess
+ from pathlib import Path
+ from typing import Any, Optional
+
+ import boto3
+ import yaml
+ from botocore.exceptions import (
+     ClientError,
+     NoCredentialsError,
+     ProfileNotFound,
+     SSOTokenLoadError,
+ )
+ from dayhoff_tools.deployment.deploy_utils import docker_login, get_container_env_vars
+
+
+ def get_boto_session(config: dict[str, Any]) -> boto3.Session:
+     """Creates a Boto3 session using the profile specified in the config.
+
+     Args:
+         config: Dictionary containing the configuration loaded from YAML
+
+     Returns:
+         A Boto3 session object
+
+     Raises:
+         RuntimeError: If the specified profile cannot be found or its credentials cannot be loaded
+     """
+     aws_config = config.get("aws", {})
+     profile_name = aws_config.get("aws_profile")
+     region = aws_config.get("region", "us-east-1")
+
+     if not profile_name:
+         print(
+             "Warning: aws.aws_profile not specified in config. Using default AWS credential chain in ~/.aws/config."
+         )
+         return boto3.Session(region_name=region)
+
+     try:
+         print(f"Using AWS profile: {profile_name}")
+         session = boto3.Session(profile_name=profile_name, region_name=region)
+         sts = session.client("sts")
+         sts.get_caller_identity()
+         return session
+     except ProfileNotFound:
+         raise RuntimeError(
+             f"AWS profile '{profile_name}' not found in `~/.aws/config`."
+         )
+     except (NoCredentialsError, ClientError, SSOTokenLoadError) as e:
+         raise RuntimeError(
+             f"Could not load credentials for AWS profile '{profile_name}'. "
+             f"Ensure you are logged in via SSO ('aws sso login --profile {profile_name}') "
+             f"or have valid credentials configured. Original error: {e}"
+         ) from e
+
+
+ def _extract_run_name_from_config(
+     config: dict[str, Any], _test_file_content: Optional[str] = None
+ ) -> Optional[str]:
+     """Extract run name from the config file referenced in job_command.
+
+     Args:
+         config: Dictionary containing the configuration loaded from YAML
+         _test_file_content: Optional parameter for testing to override file content
+
+     Returns:
+         Run name if found, None otherwise
+     """
+     # Check if features and job_command exist in config
+     if "features" not in config:
+         return None
+
+     # Find job_command in features
+     job_command = None
+     for feature in config["features"]:
+         if isinstance(feature, dict) and "job_command" in feature:
+             job_command = feature["job_command"]
+             break
+         elif isinstance(feature, str) and feature.startswith("job_command:"):
+             job_command = feature.split(":", 1)[1].strip()
+             break
+
+     if not job_command:
+         return None
+
+     # Extract config file path using regex
+     config_match = re.search(r'--config=([^\s"]+)', job_command)
+     if not config_match:
+         return None
+
+     config_path = config_match.group(1)
+
+     # For testing, we can bypass file operations
+     if _test_file_content is not None:
+         run_config = yaml.safe_load(_test_file_content)
+     else:
+         # Resolve path relative to repo root
+         # Assuming we're in the repo root when this function is called
+         full_config_path = Path(config_path)
+
+         # Check if file exists
+         if not full_config_path.exists():
+             print(f"Warning: Config file {full_config_path} not found")
+             return None
+
+         try:
+             # Load the config file
+             with open(full_config_path, "r") as f:
+                 run_config = yaml.safe_load(f)
+         except Exception as e:
+             print(f"Warning: Failed to extract run_name from {full_config_path}: {e}")
+             return None
+
+     # Extract run_name from wandb section
+     if (
+         run_config
+         and "init" in run_config
+         and "wandb" in run_config["init"]
+         and "run_name" in run_config["init"]["wandb"]
+     ):
+         return run_config["init"]["wandb"]["run_name"]
+
+     return None
+
+
+ def _extract_job_name_from_uri(image_uri: str, config: dict[str, Any]) -> str:
+     """Construct a unique job name from the config and environment.
+
+     Args:
+         image_uri: Full URI of the container image
+         config: Dictionary containing the configuration loaded from YAML
+
+     Returns:
+         Job name in format: username__jobname__uniquehex
+
+     The job name components are:
+     - username: The LOCAL_USER environment variable or "unknown_user"
+     - jobname: Either from aws.job_name config, wandb run_name, or omitted
+     - uniquehex: A hexadecimal string (typically 7 characters)
+       representing seconds since January 1, 2020
+
+     If aws.job_name is specified in the config:
+     - If aws.job_name is "use_wandb_run_name", the run_name from the wandb config will be used
+     - Otherwise, the specified job_name will be used
+
+     If no job_name is specified or if "use_wandb_run_name" is specified but no run_name is found,
+     the format will be username__uniquehex (without the middle component).
+     """
+     # Get username from environment
+     username = os.getenv("LOCAL_USER", "unknown_user")
+
+     # Generate a unique hex string based on seconds since 2020-01-01
+     epoch_2020 = datetime.datetime(2020, 1, 1).timestamp()
+     current_time = datetime.datetime.now().timestamp()
+     seconds_since_2020 = int(current_time - epoch_2020)
+     unique_hex = format(
+         seconds_since_2020, "x"
+     )  # Simple hex representation, typically 7 chars
+
+     # Check if job_name is specified in AWS config
+     if "aws" in config and "job_name" in config["aws"]:
+         job_name = config["aws"]["job_name"]
+
+         # Special handling for "use_wandb_run_name"
+         if job_name == "use_wandb_run_name":
+             # Get run name from config
+             run_name = _extract_run_name_from_config(config)
+             if run_name:
+                 # Return username__runname__uniquehex
+                 return f"{username}__{run_name}__{unique_hex}"
+         else:
+             # Use the specified job_name but format it as username__jobname__uniquehex
+             return f"{username}__{job_name}__{unique_hex}"
+
+     # Default behavior if job_name is not specified or "use_wandb_run_name" with no run_name
+     # Get run name from config
+     run_name = _extract_run_name_from_config(config)
+
+     # Build job name based on available components
+     job_name_parts = []
+
+     # Add username (use default if None)
+     job_name_parts.append(username or "unknown_user")
+
+     # Add run name if available
+     if run_name:
+         job_name_parts.append(run_name)
+
+     # Add the unique hex string
+     job_name_parts.append(unique_hex)
+
+     # Join parts with double underscores
+     return "__".join(job_name_parts)
+
+
+ def push_image_to_ecr(image_uri: str, config: dict[str, Any]) -> None:
+     """Push a Docker image to Amazon ECR.
+
+     Args:
+         image_uri: Full URI of the image to push
+         config: Dictionary containing the configuration loaded from YAML
+
+     Raises:
+         subprocess.CalledProcessError: If ECR login or push fails
+         ClientError: If ECR authentication fails
+     """
+     print("\nPushing image to ECR")
+
+     session = get_boto_session(config)
+     aws_config = config["aws"]
+     registry = aws_config["registry_uri"]
+
+     # Get ECR login token using the specific session client
+     ecr_client = session.client("ecr")
+     try:
+         token = ecr_client.get_authorization_token()
+         username, password = (
+             base64.b64decode(token["authorizationData"][0]["authorizationToken"])
+             .decode()
+             .split(":")
+         )
+
+         # Login to Docker registry
+         docker_login(registry, username, password)
+
+         # Push the image
+         print(f"\nPushing image: {image_uri}")
+         subprocess.run(["docker", "push", image_uri], check=True)
+
+         print(f"\nSuccessfully pushed image to ECR: {image_uri}")
+
+     except ClientError as e:
+         print(f"AWS ECR error using profile '{session.profile_name}': {str(e)}")
+         raise
+     except subprocess.CalledProcessError as e:
+         print(f"Docker command failed: {e.stderr.decode() if e.stderr else str(e)}")
+         raise
+
+
+ def create_or_update_job_definition(
+     image_uri: str,
+     config: dict[str, Any],
+ ) -> str:
+     """Create or update a job definition for the container.
+
+     Args:
+         image_uri: Full URI of the container image
+         config: Dictionary containing the configuration loaded from YAML
+
+     Returns:
+         Name of the created/updated job definition
+
+     Raises:
+         ValueError: If job_role_arn is not specified in AWS configuration
+     """
+     session = get_boto_session(config)
+     aws_config = config["aws"]
+
+     # Verify that job_role_arn is present
+     if "job_role_arn" not in aws_config:
+         raise ValueError(
+             "job_role_arn must be specified in AWS configuration. "
+             "This role is required for your container to access AWS resources after the "
+             "initial 15-minute temporary credential window."
+         )
+
+     batch = session.client("batch")
+     base_name = config["docker"]["base_name"]
+     job_def_name = f"job-def-{base_name}"
+
+     # Get compute specs from config
+     compute_specs = aws_config["batch_job"]["compute_specs"]
+     gpu_requirements = (
+         [{"type": "GPU", "value": str(compute_specs["gpus"])}]
+         if compute_specs.get("gpus", 0) > 0
+         else []
+     )
+
+     entrypoint_command = config["docker"].get("container_entrypoint")
+     if entrypoint_command is None:
+         raise ValueError("docker.container_entrypoint is required in configuration")
+
+     # Create linux parameters with devices
+     linux_params: dict[str, Any] = {
+         "devices": [
+             {
+                 "hostPath": "/dev/nvidia0",
+                 "containerPath": "/dev/nvidia0",
+                 "permissions": ["READ", "WRITE"],
+             },
+         ],
+     }
+
+     # Add shared memory configuration if specified in docker config
+     if "shared_memory" in config.get("docker", {}):
+         shared_mem = config["docker"]["shared_memory"]
+         # Convert to MiB (e.g., "16g" -> 16384 MiB)
+         if isinstance(shared_mem, str):
+             if shared_mem.endswith("g"):
+                 # Convert GB to MiB (1G = 1024 MiB)
+                 shared_memory_mib = int(float(shared_mem[:-1]) * 1024)
+             elif shared_mem.endswith("m"):
+                 # Treat an "m" value as MiB directly (no MB-to-MiB conversion is applied)
+                 shared_memory_mib = int(float(shared_mem[:-1]))
+             else:
+                 # Assume the value is already in MiB
+                 shared_memory_mib = int(float(shared_mem))
+         else:
+             # Assume the value is already in MiB if not a string
+             shared_memory_mib = int(shared_mem)
+
+         # Add shared memory size to linux parameters
+         linux_params["sharedMemorySize"] = shared_memory_mib
+         print(f"Setting shared memory size to {shared_memory_mib} MiB")
+
+     # Check if job definition already exists using the session client
+     try:
+         existing = batch.describe_job_definitions(
+             jobDefinitionName=job_def_name, status="ACTIVE"
+         )["jobDefinitions"]
+
+         if existing:
+             print(f"\nUpdating existing job definition: {job_def_name}")
+         else:
+             print(f"\nCreating new job definition: {job_def_name}")
+
+     except batch.exceptions.ClientError as e:
+         # Boto3 includes an error code in the response; if describe_job_definitions
+         # failed with a generic client error, assume the definition does not exist yet
+         if (
+             e.response.get("Error", {}).get("Code") == "ClientError"
+         ):  # Simple check, might need refinement
+             print(f"\nCreating new job definition: {job_def_name}")
+         else:
+             # Re-raise unexpected client errors
+             raise
+
+     # Prepare job definition properties
+     job_definition_args = {
+         "jobDefinitionName": job_def_name,
+         "type": "container",
+         "containerProperties": {
+             "image": image_uri,
+             "vcpus": compute_specs["vcpus"],
+             "memory": compute_specs["memory"],
+             "resourceRequirements": gpu_requirements,
+             "executionRoleArn": aws_config["execution_role_arn"],
+             "jobRoleArn": aws_config["job_role_arn"],
+             "privileged": compute_specs.get("gpus", 0) > 0,
+             "command": entrypoint_command,
+             **({"linuxParameters": linux_params} if linux_params else {}),
+         },
+         "platformCapabilities": ["EC2"],
+         "timeout": {"attemptDurationSeconds": aws_config.get("timeout_seconds", 86400)},
+     }
+
+     # Register new revision using the session client
+     response = batch.register_job_definition(**job_definition_args)
+
+     return response["jobDefinitionName"]
+
+
+ def submit_aws_batch_job(
+     image_uri: str,
+     config: dict[str, Any],
+ ) -> tuple[str, str]:
+     """Submit a job to AWS Batch.
+
+     Args:
+         image_uri: Full URI of the container image
+         config: Dictionary containing the configuration loaded from YAML
+
+     Returns:
+         Tuple containing (job_id, job_name) of the submitted job
+
+     Raises:
+         ValueError: If job_role_arn is not present in AWS configuration
+     """
+     session = get_boto_session(config)
+     aws_config = config["aws"]
+     region = session.region_name or aws_config["region"]
+     batch = session.client("batch")
+
+     # Generate job name (already includes unique hex string)
+     job_name = _extract_job_name_from_uri(image_uri, config)
+     print(f"\nGenerated job name: {job_name}")
+
+     # Log the job submission details
+     print("\nSubmitting job with configuration:")
+     print(f"Job Name: {job_name}")
+     print(f"Queue: {aws_config['job_queue']}")
+     print("Container Configuration:")
+     print(f"- Image: {image_uri}")
+     print(f"- vCPUs: {aws_config['batch_job']['compute_specs']['vcpus']}")
+     print(f"- Memory: {aws_config['batch_job']['compute_specs']['memory']} MiB")
+     print(f"- GPUs: {aws_config['batch_job']['compute_specs'].get('gpus', 0)}")
+     print(f"- Timeout: {aws_config.get('timeout_seconds', 86400)} seconds")
+     print(f"- Job Role: {aws_config['job_role_arn']}")
+
+     # Get all environment variables, including special ones like WANDB_API_KEY and GCP credentials
+     env_vars = get_container_env_vars(config)
+
+     print("Environment Variables:", list(env_vars.keys()))
+
+     # Create/Update Job Definition using the config (now implicitly uses the correct session)
+     job_definition = create_or_update_job_definition(image_uri, config)
+     print(f"\nUsing job definition: {job_definition}")
+
+     # Prepare job submission arguments
+     job_submit_args = {
+         "jobName": job_name,
+         "jobQueue": aws_config["job_queue"],
+         "jobDefinition": job_definition,
+         "containerOverrides": {
+             "environment": [
+                 {"name": key, "value": str(value)} for key, value in env_vars.items()
+             ],
+         },
+     }
+
+     # Add array job configuration if specified
+     if "array_size" in aws_config:
+         array_size = aws_config["array_size"]
+         if array_size > 1:
+             print(f"\nConfiguring as array job with {array_size} instances")
+             job_submit_args["arrayProperties"] = {"size": array_size}
+
+             # Configure retry strategy for array jobs
+             retry_attempts = aws_config.get("retry_attempts", 2)
+             print(f"Setting retry attempts to {retry_attempts}")
+             job_submit_args["retryStrategy"] = {"attempts": retry_attempts}
+
+     # Submit the job using the session client
+     response = batch.submit_job(**job_submit_args)
+
+     job_id = response["jobId"]
+     print(f"\nJob submitted with ID: {job_id}")
+
+     # Print instructions for monitoring
+     print("\nTo monitor your job:")
+     print(
+         f" 1. AWS Console: https://{region}.console.aws.amazon.com/batch/home?region={region}#jobs/detail/{job_id}"
+     )
+     print(f" 2. CloudWatch Logs: Check logs for job {job_name} (ID: {job_id})")
+
+     # For array jobs, provide additional monitoring info
+     if "array_size" in aws_config and aws_config["array_size"] > 1:
+         print(f" 3. This is an array job with {aws_config['array_size']} child jobs")
+         print(
+             f" Child jobs: https://{region}.console.aws.amazon.com/batch/home?region={region}#jobs/array-jobs/{job_id}"
+         )
+
+     return job_id, job_name
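For orientation, here is a minimal usage sketch of the AWS module above. The config keys are the ones the code actually reads (aws.*, docker.*); every value, the image URI, and the assumption that push_image_to_ecr and submit_aws_batch_job are already imported from this module are illustrative, not part of the package.

# Hypothetical usage sketch; all values below are assumptions for illustration.
import yaml

config = yaml.safe_load("""
aws:
  aws_profile: my-sso-profile          # optional; omit to use the default credential chain
  region: us-east-1
  registry_uri: 123456789012.dkr.ecr.us-east-1.amazonaws.com
  job_queue: my-gpu-queue
  execution_role_arn: arn:aws:iam::123456789012:role/batch-execution-role
  job_role_arn: arn:aws:iam::123456789012:role/batch-job-role
  timeout_seconds: 86400
  batch_job:
    compute_specs:
      vcpus: 8
      memory: 32768        # MiB
      gpus: 1
docker:
  base_name: my-training-image
  shared_memory: 16g
  container_entrypoint: ["python", "train.py", "--config=configs/run.yaml"]
""")

# Assumes the image was already built and tagged locally with this URI.
image_uri = f"{config['aws']['registry_uri']}/my-training-image:latest"
push_image_to_ecr(image_uri, config)                        # ECR login + docker push
job_id, job_name = submit_aws_batch_job(image_uri, config)  # registers a job definition, then submits

The GPU device mapping and privileged mode are derived from compute_specs.gpus, so CPU-only jobs can simply set gpus to 0 or omit it.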
@@ -0,0 +1,176 @@
+ """GCP-specific deployment functionality."""
+
+ import json
+ import os
+ import subprocess
+ import tempfile
+
+ from dayhoff_tools.deployment.deploy_utils import get_container_env_vars
+
+
+ def check_job_exists(job_name: str, region: str) -> bool:
+     """Check if a job with the given name already exists in GCP Batch.
+
+     Args:
+         job_name: Name of the job to check
+         region: GCP region to check in
+
+     Returns:
+         bool: True if the job exists, False otherwise
+
+     Note:
+         This uses gcloud batch jobs describe, which will return a non-zero
+         exit code if the job doesn't exist.
+     """
+     try:
+         subprocess.run(
+             [
+                 "gcloud",
+                 "batch",
+                 "jobs",
+                 "describe",
+                 job_name,
+                 "--location",
+                 region,
+             ],
+             check=True,
+             capture_output=True,  # Suppress output
+         )
+         return True
+     except subprocess.CalledProcessError:
+         return False
+
+
+ def create_batch_job_config(config: dict, image_uri: str) -> dict:
+     """Create a GCP Batch job configuration from YAML config.
+
+     Args:
+         config: Dictionary containing the configuration loaded from YAML
+         image_uri: URI of the Docker image to use
+
+     Returns:
+         Dictionary containing GCP Batch job configuration
+     """
+     gcp_config = config["gcp"]
+
+     # Start with the allocation and logs policies
+     batch_config = {
+         "allocationPolicy": gcp_config["allocation_policy"],
+         "logsPolicy": gcp_config["logs_policy"],
+     }
+
+     entrypoint_command = config["docker"].get("container_entrypoint")
+     if entrypoint_command is None:
+         raise ValueError("docker.container_entrypoint is required in configuration")
+
+     if not isinstance(entrypoint_command, list) or not all(
+         isinstance(x, str) for x in entrypoint_command
+     ):
+         raise ValueError("docker.container_entrypoint must be a list of strings")
+
+     # Build the container configuration with bash entrypoint
+     container_config = {
+         "imageUri": image_uri,
+         "entrypoint": "/bin/bash",
+         "commands": ["-c", " ".join(entrypoint_command)],
+     }
+
+     # Add shared memory option if specified
+     if "shared_memory" in config.get("docker", {}):
+         container_config["options"] = f"--shm-size={config['docker']['shared_memory']}"
+
+     # Build the task group configuration
+     task_group = {
+         "taskCount": gcp_config["batch_job"]["taskCount"],
+         "parallelism": gcp_config["batch_job"]["parallelism"],
+         "taskSpec": {
+             "computeResource": gcp_config["batch_job"]["computeResource"],
+             "runnables": [{"container": container_config}],
+         },
+     }
+
+     # Get all environment variables, including special ones like WANDB_API_KEY and GCP credentials
+     env_vars = get_container_env_vars(config)
+
+     # Add environment variables if any exist
+     if env_vars:
+         task_group["taskSpec"]["runnables"][0]["environment"] = {"variables": env_vars}
+
+     # Add machine type and optional accelerators from instance config
+     instance_config = gcp_config["batch_job"]["instance"]
+     if "machineType" in instance_config:
+         # Add machine type to the allocation policy
+         if "policy" not in batch_config["allocationPolicy"]["instances"]:
+             batch_config["allocationPolicy"]["instances"]["policy"] = {}
+         batch_config["allocationPolicy"]["instances"]["policy"]["machineType"] = (
+             instance_config["machineType"]
+         )
+
+         # Add accelerators if present (optional)
+         if "accelerators" in instance_config:
+             batch_config["allocationPolicy"]["instances"]["policy"]["accelerators"] = (
+                 instance_config["accelerators"]
+             )
+
+     # Add the task group to the configuration
+     batch_config["taskGroups"] = [task_group]
+
+     # Debug logging to verify configuration
+     print("\nGCP Batch Configuration:")
+     print("------------------------")
+     try:
+         policy = batch_config["allocationPolicy"]["instances"]["policy"]
+         print("Machine Type:", policy.get("machineType", "Not specified"))
+         print("Accelerators:", policy.get("accelerators", "Not specified"))
+         print("Environment Variables:", list(env_vars.keys()))
+     except KeyError as e:
+         print(f"Warning: Could not find {e} in configuration")
+
+     return batch_config
+
+
+ def submit_gcp_batch_job(config: dict, image_uri: str) -> None:
+     """Submit a job to GCP Batch.
+
+     Args:
+         config: Dictionary containing the configuration loaded from YAML
+         image_uri: URI of the Docker image to use
+
+     Raises:
+         ValueError: If a job with the same name already exists
+     """
+     job_name = config["gcp"]["job_name"]
+     region = config["gcp"]["region"]
+
+     # Check if job already exists
+     if check_job_exists(job_name, region):
+         raise ValueError(
+             f"Job '{job_name}' already exists in region {region}. "
+             "Please choose a different job name or delete the existing job first."
+         )
+
+     # Create GCP Batch job configuration
+     batch_config = create_batch_job_config(config, image_uri)
+
+     # Write the configuration to a temporary file
+     with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
+         json.dump(batch_config, temp_file, indent=2)
+         temp_config_path = temp_file.name
+
+     try:
+         # Submit the job using gcloud
+         command = [
+             "gcloud",
+             "batch",
+             "jobs",
+             "submit",
+             job_name,
+             "--location",
+             region,
+             "--config",
+             temp_config_path,
+         ]
+         subprocess.run(command, check=True)
+     finally:
+         # Clean up the temporary file
+         os.unlink(temp_config_path)
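A corresponding sketch for the GCP module: the config keys mirror what the code reads (gcp.*, docker.*), while all values are illustrative assumptions and an authenticated gcloud CLI is assumed to be on PATH.

# Hypothetical usage sketch; all values below are assumptions for illustration.
import yaml

config = yaml.safe_load("""
gcp:
  job_name: my-batch-job
  region: us-central1
  allocation_policy:
    instances: {}          # machineType/accelerators are filled in from batch_job.instance
  logs_policy:
    destination: CLOUD_LOGGING
  batch_job:
    taskCount: 10
    parallelism: 5
    computeResource:
      cpuMilli: 4000
      memoryMib: 16384
    instance:
      machineType: n1-standard-4
docker:
  shared_memory: 8g
  container_entrypoint: ["python", "process_shard.py"]
""")

image_uri = "us-central1-docker.pkg.dev/my-project/my-repo/my-image:latest"
# Writes the generated Batch config to a temp JSON file and runs `gcloud batch jobs submit`.
submit_gcp_batch_job(config, image_uri)

Unlike the AWS path, the GCP path shells out to gcloud rather than using a client library, so authentication and project selection come from the local gcloud configuration.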