dayhoff-tools 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,244 @@
+ """CLI commands common to all repos."""
+
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ import yaml
+
+
+ def test_github_actions_locally():
+     """Run the script test_pytest_in_github_actions_container.sh."""
+     script_path = ".devcontainer/scripts/test_pytest_in_github_actions_container.sh"
+
+     try:
+         subprocess.check_call(["bash", script_path])
+         print("Script ran successfully!")
+     except subprocess.CalledProcessError as e:
+         print(f"Error occurred while running the script: {e}")
+
+
+ def rebuild_devcontainer_file():
+     """Run the script prepare_for_build.py."""
+     script_path = ".devcontainer/scripts/prepare_for_build.py"
+
+     try:
+         subprocess.check_call([sys.executable, script_path])
+         print("Script ran successfully!")
+     except subprocess.CalledProcessError as e:
+         print(f"Error occurred while running the script: {e}")
+
+
+ def get_ancestry(filepath: str) -> None:
+     """Take a .dvc file created by `dvc import`, and generate an ancestry entry
+     that can be manually copied into other .dvc files."""
+     assert filepath.endswith(".dvc"), "ERROR: Not a .dvc file"
+     with open(filepath, "r") as file:
+         ancestor_content = yaml.safe_load(file)
+
+     error_msg = "Unexpected file structure. Are you sure this is a .dvc file generated from `dvc import`?"
+     assert "deps" in ancestor_content, error_msg
+
+     error_msg = "Please only reference data imported from main branches."
+     assert "rev" not in ancestor_content["deps"][0]["repo"], error_msg
+
+     ancestor_info = {
+         "name": os.path.basename(ancestor_content["outs"][0]["path"]),
+         "file_md5_hash": ancestor_content["outs"][0]["md5"],
+         "size": ancestor_content["outs"][0]["size"],
+         "repo_url": ancestor_content["deps"][0]["repo"]["url"],
+         "repo_path": ancestor_content["deps"][0]["path"],
+         "commit_hash": ancestor_content["deps"][0]["repo"]["rev_lock"],
+     }
+     print()
+     yaml.safe_dump(
+         [ancestor_info], sys.stdout, default_flow_style=False, sort_keys=False
+     )
+
+
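+ # For illustration, a hypothetical ancestry entry as printed by get_ancestry
+ # (field values are made up; the keys match the dict built above):
+ #
+ #   - name: sequences.fasta
+ #     file_md5_hash: 3d8e2f5c9a7b1d6e4f0a2c8b5d9e1f37
+ #     size: 1048576
+ #     repo_url: git@github.com:example-org/warehouse.git
+ #     repo_path: data/sequences.fasta
+ #     commit_hash: 0123456789abcdef0123456789abcdef01234567
+
+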
+ def import_from_warehouse_typer() -> None:
+     """Import a file from warehouse.
+     This is a thin wrapper around `dayhoff_tools.warehouse.import_from_warehouse`,
+     with interactive prompts using questionary.
+     """
+     # Import only when the function is called
+     import questionary
+     from dayhoff_tools.warehouse import import_from_warehouse
+
+     # Ensure execution from the repo root
+     cwd = Path(os.getcwd())
+     if cwd.parent.name != "workspaces" or str(cwd.parent.parent) != cwd.root:
+         raise Exception(
+             f"This command must be executed from the repo's root directory (/workspaces/reponame). Current directory: {cwd}"
+         )
+
+     # Use questionary for prompts instead of typer
+     warehouse_path = questionary.text("Warehouse path:").ask()
+
+     # Provide multiple-choice options for the output folder
+     output_folder_choice = questionary.select(
+         "Output folder:",
+         choices=["data/imports", "same_as_warehouse", "Custom path..."],
+     ).ask()
+
+     # If a custom path is selected, ask for the path
+     if output_folder_choice == "Custom path...":
+         output_folder = questionary.text("Enter custom output folder:").ask()
+     else:
+         output_folder = output_folder_choice
+
+     branch = questionary.text("Branch (default: main):", default="main").ask()
+
+     final_path = import_from_warehouse(
+         warehouse_path=warehouse_path,
+         output_folder=output_folder,
+         branch=branch,
+     )
+
+
+ def add_to_warehouse_typer() -> None:
+     """Add a new data file to warehouse, and expand its .dvc file with
+     metadata, including ancestor files."""
+     # Import only when the function is called
+     import questionary
+     from dayhoff_tools.warehouse import add_to_warehouse
+
+     # Ensure execution from the repo root
+     cwd = Path(os.getcwd())
+     if cwd.parent.name != "workspaces" or str(cwd.parent.parent) != cwd.root:
+         raise Exception(
+             f"This command must be executed from the repo's root directory (/workspaces/reponame). Current directory: {cwd}"
+         )
+
+     # Prompt for the data file path
+     warehouse_path = questionary.text("Data file to be registered:").ask()
+
+     # Prompt for the ancestor .dvc file paths
+     ancestor_dvc_paths = []
+     print("\nEnter the paths of all ancestor .dvc files (or hit Enter to finish).")
+     print("These files must be generated by `dvc import` or `dh wimport`.")
+     while True:
+         ancestor_path = questionary.text("Ancestor path: ").ask()
+         if ancestor_path:
+             ancestor_dvc_paths.append(ancestor_path)
+         else:
+             print()
+             break
+
+     dvc_path = add_to_warehouse(
+         warehouse_path=warehouse_path,
+         ancestor_dvc_paths=ancestor_dvc_paths,
+     )
+
+
+ def delete_local_branch(branch_name: str, folder_path: str):
+     """Delete a local Git branch after fetching with pruning.
+
+     Args:
+         branch_name: Name of the branch to delete
+         folder_path: Path to the git repository folder
+     """
+     # Store the current working directory before entering the try block,
+     # so the finally clause can always restore it
+     original_dir = os.getcwd()
+
+     try:
+         # Change to the specified directory
+         os.chdir(folder_path)
+         print(f"Changed to directory: {folder_path}")
+
+         # Delete the specified branch
+         delete_branch_cmd = ["git", "branch", "-D", branch_name]
+         subprocess.run(delete_branch_cmd, check=True)
+         print(f"Deleted branch: {branch_name}")
+
+         # Fetch changes from the remote repository and prune obsolete branches
+         fetch_prune_cmd = ["git", "fetch", "-p"]
+         subprocess.run(fetch_prune_cmd, check=True)
+         print("Fetched changes and pruned obsolete branches")
+
+     except subprocess.CalledProcessError as e:
+         print(f"Error occurred while running Git commands: {e}")
+     finally:
+         # Always return to the original directory
+         os.chdir(original_dir)
+
+
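+ # Hypothetical usage (branch name and path are made up):
+ #
+ #   delete_local_branch("feature/old-experiment", "/workspaces/reponame")
+
+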
+ def build_and_upload_wheel():
+     """Build a Python wheel and upload to PyPI.
+
+     Automatically increments the patch version number in pyproject.toml before building.
+     For example: 1.2.3 -> 1.2.4
+
+     Expects the PyPI API token to be available in the PYPI_API_TOKEN environment variable.
+     """
+     pypi_token = os.environ.get("PYPI_API_TOKEN")
+     if not pypi_token:
+         print("Error: PYPI_API_TOKEN environment variable not set.")
+         print("Please set it with your PyPI API token before running this command.")
+         return
+
+     try:
+         # Read current version from pyproject.toml
+         with open("pyproject.toml", "r") as f:
+             content = f.read()
+
+         # Find the version line using simple string search
+         version_line = [line for line in content.split("\n") if "version = " in line][0]
+         current_version = version_line.split('"')[1]  # Extract version between quotes
+
+         # Increment patch version
+         major, minor, patch = current_version.split(".")
+         new_version = f"{major}.{minor}.{int(patch) + 1}"
+
+         # Update all pyproject files with the new version
+         for pyproject_file in [
+             "pyproject.toml",
+             "pyproject_gcp.toml",
+             "pyproject_mac.toml",
+         ]:
+             try:
+                 with open(pyproject_file, "r") as f:
+                     content = f.read()
+                 new_content = content.replace(
+                     f'version = "{current_version}"', f'version = "{new_version}"'
+                 )
+                 with open(pyproject_file, "w") as f:
+                     f.write(new_content)
+                 print(
+                     f"Version bumped from {current_version} to {new_version} in {pyproject_file}"
+                 )
+             except FileNotFoundError:
+                 print(f"Skipping {pyproject_file} - file not found")
+
+         # Disable keyring to avoid issues in containers/CI
+         print("Disabling Poetry keyring...")
+         subprocess.run(
+             ["poetry", "config", "keyring.enabled", "false"],
+             check=True,
+             capture_output=True,
+         )
+
+         # Configure Poetry with the API token
+         print("Configuring Poetry with PyPI token...")
+         subprocess.run(
+             ["poetry", "config", "pypi-token.pypi", pypi_token],
+             check=True,
+             capture_output=True,  # Hide token from output
+         )
+
+         # Build and upload
+         print("Building and uploading wheel to PyPI...")
+         subprocess.run(
+             ["poetry", "publish", "--build"],
+             check=True,
+         )
+
+         print(f"Successfully built and uploaded version {new_version} to PyPI")
+
+     except subprocess.CalledProcessError as e:
+         print(f"Error during build/upload: {e}")
@@ -0,0 +1,434 @@
+ """Base functionality for container deployment across cloud providers.
+
+ This module provides the core functionality for building and running containers,
+ which can then be used locally or deployed to various cloud providers.
+ """
+
+ import datetime
+ import hashlib
+ import os
+ import subprocess
+ from typing import List
+
+ import torch
+ import typer
+ import yaml
+ from dayhoff_tools.deployment.deploy_aws import push_image_to_ecr, submit_aws_batch_job
+ from dayhoff_tools.deployment.deploy_gcp import submit_gcp_batch_job
+ from dayhoff_tools.deployment.deploy_utils import (
+     get_container_env_vars,
+     move_to_repo_root,
+ )
+
+
+ def _generate_image_tag(versioning: str) -> str:
+     """Generate a Docker image tag based on versioning strategy.
+
+     The tag is generated based on the specified versioning strategy:
+     - For 'latest': Simply returns 'latest'
+     - For 'unique': Generates a tag combining timestamp and content hash
+       Format: YYYYMMDD_HHMMSS_<8_char_hash>
+       The hash is computed from all Python and shell files in src/
+       to detect code changes.
+
+     Args:
+         versioning: Either 'unique' or 'latest'
+
+     Returns:
+         Generated tag string
+     """
+     if versioning == "latest":
+         return "latest"
+
+     # Generate unique tag based on timestamp and content hash
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     # Create hash of relevant files to detect code changes
+     hasher = hashlib.sha256()
+     for root, dirs, files in os.walk("src"):
+         dirs.sort()  # Sort in place so os.walk visits directories deterministically
+         for file in sorted(files):  # Sort for reproducibility
+             if file.endswith((".py", ".sh")):  # Add other relevant extensions
+                 filepath = os.path.join(root, file)
+                 with open(filepath, "rb") as f:
+                     hasher.update(f.read())
+
+     content_hash = hasher.hexdigest()[:8]  # First 8 chars are sufficient
+     return f"{timestamp}_{content_hash}"
+
+
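+ # For illustration, with 'unique' versioning the tag might look like
+ # (hypothetical timestamp and hash):
+ #
+ #   _generate_image_tag("unique")  # -> "20240115_103000_1a2b3c4d"
+ #   _generate_image_tag("latest")  # -> "latest"
+
+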
+ def _build_image_uri(config: dict) -> str:
+     """Build the full image URI from config.
+
+     The URI is constructed in one of two ways:
+     1. If image_uri is provided in config:
+        - For 'unique' versioning: Use the provided URI as is
+        - For 'latest' versioning: Ensure the URI ends with :latest
+     2. If image_uri is empty or not provided:
+        - Use the registry_uri and repository from cloud-specific config
+        - Combine with base_name and generated tag
+
+     Args:
+         config: Dictionary containing docker configuration
+
+     Returns:
+         Complete image URI
+
+     Raises:
+         ValueError: If cloud is not specified or invalid
+         ValueError: If registry_uri or repository is missing for the selected cloud
+     """
+     docker_config = config["docker"]
+
+     # Handle provided image URI
+     if docker_config.get("image_uri"):
+         uri = docker_config["image_uri"]
+         if docker_config["image_versioning"] == "latest" and not uri.endswith(
+             ":latest"
+         ):
+             uri = uri.split(":")[0] + ":latest"
+         return uri
+
+     # Get cloud provider from config
+     cloud = config.get("cloud")
+     if not cloud:
+         raise ValueError("cloud must be specified when image_uri is not provided")
+     if cloud not in ["aws", "gcp"]:
+         raise ValueError(f"Invalid cloud provider: {cloud}. Must be one of: aws, gcp")
+
+     # Get cloud-specific configuration
+     cloud_config = config.get(cloud, {})
+     registry_uri = cloud_config.get("registry_uri")
+     if not registry_uri:
+         raise ValueError(
+             f"{cloud}.registry_uri must be specified when image_uri is not provided"
+         )
+
+     repository = cloud_config.get("repository")
+     if not repository:
+         raise ValueError(
+             f"{cloud}.repository must be specified when image_uri is not provided"
+         )
+
+     # Ensure registry_uri doesn't end with a slash
+     registry_uri = registry_uri.rstrip("/")
+
+     base_name = docker_config["base_name"]
+     tag = _generate_image_tag(docker_config["image_versioning"])
+
+     # For AWS ECR, the base_name becomes part of the tag
+     if cloud == "aws":
+         return f"{registry_uri}/{repository}:{base_name}-{tag}"
+
+     # For GCP, include base_name in the path
+     return f"{registry_uri}/{repository}/{base_name}:{tag}"
+
+
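+ # For illustration, with hypothetical config values:
+ #
+ #   config = {
+ #       "cloud": "aws",
+ #       "docker": {"base_name": "worker", "image_versioning": "latest"},
+ #       "aws": {
+ #           "registry_uri": "123456789012.dkr.ecr.us-east-1.amazonaws.com",
+ #           "repository": "jobs",
+ #       },
+ #   }
+ #   _build_image_uri(config)
+ #   # -> "123456789012.dkr.ecr.us-east-1.amazonaws.com/jobs:worker-latest"
+
+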
+ def build_job_image(config: dict) -> str:
+     """Build a Docker image based on configuration.
+
+     This function handles the complete image building process:
+     1. Ensures we're in the repo root
+     2. Constructs the image URI based on config
+     3. Builds the image using docker build
+
+     Args:
+         config: Dictionary containing the configuration loaded from YAML.
+             The docker.image_versioning field can be either:
+             - "unique": Generate a unique tag based on timestamp and content hash
+             - "latest": Use the :latest tag for reusability
+
+             If docker.image_uri is provided:
+             - For "unique" versioning: Use the provided URI as is
+             - For "latest" versioning: Ensure the URI ends with :latest
+
+     Returns:
+         str: The complete image URI with appropriate tag
+     """
+     move_to_repo_root()
+
+     # Get image URI
+     image_uri = _build_image_uri(config)
+     docker_config = config["docker"]
+
+     print(f"\nBuilding Docker image: {image_uri}")
+     print(f"Using Dockerfile: {docker_config['dockerfile']}")
+     print(f"Using shared memory: {docker_config['shared_memory']}\n")
+
+     # Build the image
+     build_image_command = [
+         "docker",
+         "build",
+         f"--shm-size={docker_config['shared_memory']}",
+         "-f",
+         docker_config["dockerfile"],
+         "-t",
+         image_uri,
+         ".",  # Use the root of the repo as the image context
+     ]
+     subprocess.run(build_image_command, check=True)
+
+     # Get and print the image size
+     image_info = subprocess.check_output(
+         ["docker", "images", "--format", "{{.Size}}", image_uri], encoding="utf-8"
+     ).strip()
+     print(f"\nBuilt image size: {image_info}")
+
+     return image_uri
+
+
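+ # The resulting invocation is equivalent to running, for hypothetical
+ # config values:
+ #
+ #   docker build --shm-size=8g -f Dockerfile \
+ #       -t 123456789012.dkr.ecr.us-east-1.amazonaws.com/jobs:worker-latest .
+
+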
+ def _build_docker_run_command(
+     config: dict,
+     image_uri: str,
+     container_name: str,
+     env_vars: dict,
+     mode: str,
+ ) -> List[str]:
+     """Build the docker run command with all necessary options.
+
+     This function constructs the complete docker run command, including:
+     - Mode-specific options (--rm, -d, -it)
+     - GPU support if available
+     - Container name
+     - Environment variables
+     - Entrypoint and command
+     - Privileged mode if specified
+     - Volume mounts if specified
+
+     Args:
+         config: Configuration dictionary
+         image_uri: URI of the image to run
+         container_name: Name for the container
+         env_vars: Environment variables to pass to the container
+         mode: Deployment mode (local or shell)
+
+     Returns:
+         List of command parts ready for subprocess.run
+
+     Raises:
+         ValueError: If placeholder strings are found in volume paths during local mode.
+     """
+     command = [
+         "docker",
+         "run",
+         f"--shm-size={config['docker']['shared_memory']}",
+     ]
+
+     # Add mode-specific options
+     if mode == "local":
+         command += ["--rm", "-d"]  # Remove container after exit, run detached
+     elif mode == "shell":
+         command += ["--rm", "-it"]  # Remove container after exit, interactive TTY
+
+     # Add privileged mode if specified
+     if config["docker"].get("privileged", False):
+         print("Container will run in privileged mode")
+         command += ["--privileged"]
+
+     # Add volume mounts if specified
+     if "volumes" in config["docker"]:
+         # Check for placeholder strings in local mode before adding volumes
+         if mode == "local":
+             for volume in config["docker"]["volumes"]:
+                 if "<YOUR_USERNAME>" in volume or "<YOUR_REPO_NAME>" in volume:
+                     raise ValueError(
+                         f"Placeholder string found in volume path: '{volume}'. "
+                         "Please replace <YOUR_USERNAME> and <YOUR_REPO_NAME> in your "
+                         "local YAML configuration file's 'volumes' section."
+                     )
+         # Add validated volumes to the command
+         for volume in config["docker"]["volumes"]:
+             print(f"Adding volume mount: {volume}")
+             command += ["-v", volume]
+
+     # Add GPU support if available
+     if torch.cuda.is_available():
+         print("Container has access to GPU")
+         command += ["--gpus", "all"]
+     else:
+         print("Container has access to CPU only, no GPU")
+
+     # Add container name
+     command += ["--name", container_name]
+
+     # Add environment variables
+     for key, value in env_vars.items():
+         command += ["-e", f"{key}={value}"]
+
+     # Add image and command
+     if mode == "local":
+         # For detached mode, use bash -c to execute the command
+         entrypoint = config["docker"].get(
+             "container_entrypoint", ["python", "swarm/main.py"]
+         )
+         if isinstance(entrypoint, str):
+             entrypoint = [entrypoint]
+         cmd_str = " ".join(entrypoint)
+         command += ["--entrypoint", "/bin/bash", image_uri, "-c", cmd_str]
+     else:
+         # For shell mode, just use bash as the entrypoint
+         command += ["--entrypoint", "/bin/bash", image_uri]
+
+     return command
+
+
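+ # A hypothetical command list produced for shell mode (no GPU, one env var,
+ # all values made up):
+ #
+ #   ["docker", "run", "--shm-size=8g", "--rm", "-it",
+ #    "--name", "alice_job_20240115_103000",
+ #    "-e", "LOCAL_USER=alice",
+ #    "--entrypoint", "/bin/bash",
+ #    "123456789012.dkr.ecr.us-east-1.amazonaws.com/jobs:worker-latest"]
+
+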
+ def run_container(config: dict, image_uri: str, mode: str) -> None:
+     """Run a container based on deployment mode.
+
+     This function handles the complete container lifecycle:
+     1. Generates a unique container name
+     2. Collects all necessary environment variables
+     3. Builds and executes the docker run command
+     4. Handles container logs for detached mode
+
+     The container name is generated using:
+     - Username (from LOCAL_USER env var)
+     - Timestamp (YYYYMMDD_HHMMSS format)
+
+     Args:
+         config: Dictionary containing the configuration loaded from YAML
+         image_uri: URI of the Docker image to run
+         mode: Deployment mode (local or shell)
+
+     Raises:
+         ValueError: If deployment mode is invalid
+         subprocess.CalledProcessError: If container fails to start or run
+     """
+     if mode not in ["local", "shell"]:
+         raise ValueError(
+             f"Invalid deployment mode: {mode}. Must be one of: local, shell"
+         )
+
+     # Generate unique container name
+     username = os.getenv("LOCAL_USER", "unknown_user")
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+     container_name = f"{username}_job_{timestamp}"
+
+     # Get environment variables
+     env_vars = get_container_env_vars(config)
+
+     # Build the docker run command
+     command = _build_docker_run_command(
+         config, image_uri, container_name, env_vars, mode
+     )
+
+     # Handle container execution based on mode
+     print(f"Running container in {mode} mode: {container_name}")
+
+     if mode == "shell":
+         # Simple case: run interactively with a TTY; the call blocks until the container exits
+         subprocess.run(command, check=True)
+     elif mode == "local":
+         # Complex case: run in the background and handle logs
+         try:
+             # Start the container in detached mode
+             subprocess.run(command, check=True)
+
+             # Once the container is started, immediately follow its logs.
+             # This helps users see what's happening without having to run 'docker logs' manually.
+             print("\nDetached container logs:")
+             log_command = ["docker", "logs", "-f", container_name]
+             subprocess.run(log_command, check=True)
+         except subprocess.CalledProcessError as e:
+             # If anything goes wrong (either during startup or while running),
+             # try to get any available logs to help with debugging
+             print("\nContainer failed. Attempting to retrieve logs:")
+             try:
+                 # Don't use -f here, as the container has likely already stopped
+                 subprocess.run(["docker", "logs", container_name], check=True)
+             except subprocess.CalledProcessError:
+                 # If we can't get logs, the container probably failed to start
+                 print(
+                     "No logs available. Container didn't start or failed immediately."
+                 )
+                 print("Try running it in shell mode to get more information.")
+             raise e  # Re-raise the original error
+
+
+ def deploy(
+     mode: str = typer.Argument(help="Deployment mode. Options: local, shell, batch"),
+     config_path: str = typer.Argument(help="Path to the YAML configuration file"),
+ ) -> None:
+     """Deploy a job based on configuration from a YAML file.
+
+     This is the main entry point for all deployments. It handles:
+     1. Validating the deployment mode
+     2. Loading and validating the configuration
+     3. Building or using an existing image
+     4. Running the container locally or delegating to cloud-specific batch deployment
+
+     Args:
+         mode: Deployment mode to use. Options: local, shell, batch
+         config_path: Path to the YAML configuration file
+
+     Raises:
+         ValueError: If deployment mode is invalid
+         ValueError: If cloud field is not specified or invalid for batch mode
+     """
+     # Validate mode
+     valid_modes = ["local", "shell", "batch"]
+     if mode not in valid_modes:
+         raise ValueError(
+             f"Invalid mode: {mode}. Must be one of: {', '.join(valid_modes)}"
+         )
+
+     # Load YAML configuration
+     with open(config_path, "r") as f:
+         config = yaml.safe_load(f)
+
+     # For batch mode, check cloud provider and credentials early
+     if mode == "batch":
+         cloud = config.get("cloud")
+         if not cloud:
+             raise ValueError(
+                 "cloud field must be specified in configuration for batch mode"
+             )
+         if cloud not in ["aws", "gcp"]:
+             raise ValueError(f"Invalid cloud: {cloud}. Must be one of: aws, gcp")
+
+         # Check AWS credentials early if using AWS
+         if cloud == "aws":
+             print("\nVerifying AWS credentials...")
+             from dayhoff_tools.deployment.deploy_aws import get_boto_session
+
+             # This will validate credentials and throw an appropriate error if they're invalid
+             get_boto_session(config)
+             print("AWS credentials verified.")
+
+     # Track whether an image URI was provided, i.e. whether we build a new image
+     had_image_uri = bool(config["docker"].get("image_uri"))
+
+     # Build or use existing image
+     image_uri = build_job_image(config)
+
+     if mode in ["local", "shell"]:
+         run_container(config, image_uri, mode)
+         return
+
+     # Handle batch mode
+     cloud = config.get("cloud")
+     # We already validated cloud above, so no need to check again
+
+     # Push the image if we built it
+     if not had_image_uri:
+         if cloud == "aws":
+             push_image_to_ecr(image_uri, config)
+         else:  # cloud == "gcp"
+             print("\nPushing image to Artifact Registry")
+             registry = config["gcp"]["registry_uri"].split("/")[0]  # e.g. "us-central1-docker.pkg.dev"
+             print(f"Configuring Docker authentication for {registry}")
+             subprocess.run(
+                 ["gcloud", "auth", "configure-docker", registry, "--quiet"],
+                 check=True,
+             )
+             subprocess.run(["docker", "push", image_uri], check=True)
+             print(f"Pushed image to Artifact Registry: {image_uri}")
+
+     # Submit batch job
+     if cloud == "aws":
+         job_id, job_name = submit_aws_batch_job(image_uri, config)
+         print(f"\nSubmitted AWS Batch job '{job_name}' with ID: {job_id}")
+     else:  # cloud == "gcp"
+         submit_gcp_batch_job(config, image_uri)
+         print(f"\nSubmitted GCP Batch job: {config['gcp']['job_name']}")