dayhoff-tools 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
"""Utility functions and classes for deployment and system monitoring.

This module provides utilities for:
1. System monitoring (GPU, CPU, memory)
2. Environment setup and authentication
3. DVC and repository configuration
4. Cloud instance metadata and identification
"""

import base64
import json
import logging
import os
import socket
import subprocess
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import requests

logger = logging.getLogger(__name__)


def get_instance_metadata(timeout: int = 2) -> Dict[str, str]:
    """Get instance metadata from various cloud providers.

    This function attempts to retrieve metadata from GCP, AWS, and other cloud
    environments. It tries each provider's metadata endpoint in turn and returns
    a consolidated dictionary describing the current instance.

    Args:
        timeout: Timeout in seconds for HTTP requests

    Returns:
        Dictionary containing instance metadata with keys:
            - provider: Cloud provider name ('gcp', 'aws', 'gcp-batch', 'aws-batch',
              'container', or 'unknown')
            - instance_id: Unique instance identifier
            - instance_name: Name of the instance (if available)
            - instance_type: Type/size of the instance (e.g., 't2.micro', 'n1-standard-1')
            - region: Region where the instance is running
            - zone: Availability zone
    """
    metadata = {
        "provider": "unknown",
        "instance_id": "unknown",
        "instance_name": "unknown",
        "instance_type": "unknown",
        "region": "unknown",
        "zone": "unknown",
    }

    # Try GCP metadata
    try:
        headers = {"Metadata-Flavor": "Google"}
        # Check if we're in GCP
        response = requests.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/id",
            headers=headers,
            timeout=timeout,
        )
        if response.status_code == 200:
            metadata["provider"] = "gcp"
            metadata["instance_id"] = response.text

            # Get instance name
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/name",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["instance_name"] = response.text

            # Get machine type
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/machine-type",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                machine_type_path = response.text
                metadata["instance_type"] = machine_type_path.split("/")[-1]

            # Get zone
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/zone",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                zone_path = response.text
                metadata["zone"] = zone_path.split("/")[-1]
                # Extract region from zone (e.g., us-central1-a -> us-central1)
                if "-" in metadata["zone"]:
                    metadata["region"] = "-".join(metadata["zone"].split("-")[:-1])

            return metadata
    except Exception as e:
        logger.debug(f"Not a GCP instance or metadata server not available: {e}")

    # Try AWS metadata (IMDSv2: fetch a session token first)
    try:
        token_response = requests.put(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=timeout,
        )
        if token_response.status_code == 200:
            token = token_response.text
            headers = {"X-aws-ec2-metadata-token": token}

            metadata["provider"] = "aws"

            # Get instance ID
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/instance-id",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["instance_id"] = response.text

            # Get instance type
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/instance-type",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["instance_type"] = response.text

            # Get availability zone
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/placement/availability-zone",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["zone"] = response.text
                # Extract region from availability zone (e.g., us-east-1a -> us-east-1)
                metadata["region"] = metadata["zone"][:-1]

            # AWS doesn't provide an instance name directly,
            # but we can use the hostname or instance-id as a fallback
            try:
                response = requests.get(
                    "http://169.254.169.254/latest/meta-data/hostname",
                    headers=headers,
                    timeout=timeout,
                )
                if response.status_code == 200:
                    metadata["instance_name"] = response.text
                else:
                    metadata["instance_name"] = metadata["instance_id"]
            except Exception:
                metadata["instance_name"] = metadata["instance_id"]

            return metadata
    except Exception as e:
        logger.debug(f"Not an AWS EC2 instance or metadata server not available: {e}")

    # Try Azure metadata (if needed in the future)
    # ...

    # As a fallback, try to get some basic info from the host
    try:
        metadata["instance_name"] = socket.gethostname()
        # Check if we're running in a container environment
        if os.path.exists("/.dockerenv") or os.path.exists("/run/.containerenv"):
            metadata["provider"] = "container"
        # Check batch environment variables
        if os.getenv("BATCH_TASK_INDEX") is not None:
            metadata["provider"] = "gcp-batch"
            metadata["instance_name"] = f"batch-task-{os.getenv('BATCH_TASK_INDEX')}"
        elif os.getenv("AWS_BATCH_JOB_ID") is not None:
            metadata["provider"] = "aws-batch"
            metadata["instance_name"] = f"aws-batch-{os.getenv('AWS_BATCH_JOB_ID')}"
    except Exception as e:
        logger.debug(f"Error getting hostname: {e}")

    return metadata
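
# Hypothetical usage sketch (illustration only, not part of the module): the
# returned dict always contains the same keys, so callers can branch on
# "provider" without extra guards.
#
#     meta = get_instance_metadata(timeout=1)
#     if meta["provider"] == "gcp":
#         logger.info(f"GCP {meta['instance_type']} in {meta['zone']}")
#     elif meta["provider"] == "unknown":
#         logger.warning("No cloud metadata endpoint reachable")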


def get_instance_name() -> str:
    """Get the name of the current cloud instance or VM.

    This is a cross-platform replacement for the old get_vm_name() function.
    Works with GCP, AWS, and other environments.

    Returns:
        A string containing the instance name, hostname, or ID
    """
    metadata = get_instance_metadata()
    return metadata["instance_name"]


def get_instance_type() -> str:
    """Get the machine type/size of the current cloud instance.

    This is a cross-platform replacement for the old get_vm_type() function.
    Works with GCP (e.g., n1-standard-1), AWS (e.g., t2.micro), and other environments.

    Returns:
        A string representing the instance type/size, or 'unknown' if not available
    """
    metadata = get_instance_metadata()
    return metadata["instance_type"]


def get_cloud_provider() -> str:
    """Get the cloud provider where this code is running.

    Returns:
        A string identifying the cloud provider ('gcp', 'aws', 'azure', 'container', 'unknown')
    """
    metadata = get_instance_metadata()
    return metadata["provider"]
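
# Note: each accessor above performs a fresh round of metadata requests. A
# hypothetical caller that needs several fields should prefer a single
# get_instance_metadata() call, e.g.:
#
#     meta = get_instance_metadata()
#     run_label = f"{meta['provider']}:{meta['instance_name']} ({meta['instance_type']})"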


def move_to_repo_root() -> None:
    """Move to the repository root directory.

    Determines the repository root through multiple methods (in order):
    1. Direct specification via REPO_ROOT environment variable
    2. Standard marker files (.git, setup.py, pyproject.toml)
    3. Container standard paths (/app if it exists and contains expected files)
    4. NAME_OF_THIS_REPO environment variable (for VM environments)

    Raises:
        OSError: If repository root cannot be determined
    """
    try:
        # Check if REPO_ROOT is directly specified
        if "REPO_ROOT" in os.environ:
            root_path = Path(os.environ["REPO_ROOT"])
            if root_path.exists():
                logger.info(f"Using environment-specified REPO_ROOT: {root_path}")
                os.chdir(root_path)
                return

        # Try to find the repo root by looking for standard marker files
        current = Path.cwd()
        while current != current.parent:
            if any(
                (current / marker).exists()
                for marker in [".git", "setup.py", "pyproject.toml"]
            ):
                logger.info(f"Found repository root at: {current}")
                os.chdir(current)
                return
            current = current.parent

        # Check for container standard paths
        container_paths = ["/app", "/workspace", "/code"]
        for path in container_paths:
            container_root = Path(path)
            if container_root.exists() and any(
                (container_root / subdir).exists()
                for subdir in ["src", "swarm", "dayhoff_tools"]
            ):
                logger.info(f"Using container standard path: {container_root}")
                os.chdir(container_root)
                return

        # Fall back to the NAME_OF_THIS_REPO environment variable if available
        try:
            name_of_this_repo = os.environ["NAME_OF_THIS_REPO"]
            root_path = Path(f"/workspaces/{name_of_this_repo}")
            if root_path.exists():
                logger.info(f"Using workspace repository path: {root_path}")
                os.chdir(root_path)
                return
        except KeyError as e:
            logger.warning(f"NAME_OF_THIS_REPO environment variable not set: {e}")

        # If we're already at what looks like a valid root, just stay here
        if any(
            Path.cwd().joinpath(marker).exists()
            for marker in ["setup.py", "pyproject.toml", "src", "dayhoff_tools"]
        ):
            logger.info(f"Current directory appears to be a valid root: {Path.cwd()}")
            return

        raise OSError("Could not determine repository root")
    except Exception as e:
        logger.error(f"ERROR: Could not move to repository root: {e}")
        raise
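
# Hypothetical usage sketch (illustration only): a deployment entry point
# would typically normalize its working directory first, so that relative
# paths resolve consistently regardless of where the process was launched.
#
#     move_to_repo_root()
#     config_path = Path("configs/deploy.yaml")  # assumed path, for illustration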


def upload_folder_to_gcs(local_folder: str, bucket, gcs_folder: str):
    """Upload all files from a local folder to a GCS folder.

    Args:
        local_folder: Path to the local folder
        bucket: GCS bucket object
        gcs_folder: Destination folder path in GCS
    """
    local_path = Path(local_folder)

    for local_file in local_path.glob("**/*"):
        if local_file.is_file():
            # Construct the GCS path
            relative_path = local_file.relative_to(local_path)
            gcs_path = f"{gcs_folder.rstrip('/')}/{relative_path}"

            # Upload the file
            blob = bucket.blob(gcs_path)
            blob.upload_from_filename(str(local_file))
            logger.info(f"Uploaded {local_file} to {gcs_path}")


@dataclass
class SystemStats:
    """Container for system statistics."""

    timestamp: str
    vm_id: str
    cpu_usage: float
    mem_usage: float
    gpu_usage: Optional[float]
    disk_usage: float

    def __str__(self) -> str:
        """Format system stats for logging."""
        return (
            f"VM:{self.vm_id} "
            f"CPU:{self.cpu_usage:.1f}% "
            f"MEM:{self.mem_usage:.1f}% "
            f"DISK:{self.disk_usage:.1f}% "
            f"GPU:{self.gpu_usage if self.gpu_usage is not None else 'N/A'}%"
        )
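
# Example of the log line produced by __str__ (values are illustrative; note
# that timestamp is stored but intentionally omitted from the formatted line):
#
#     stats = SystemStats("2024-01-01 12:00:00", "12345", 37.5, 62.1, 88.0, 41.0)
#     str(stats)  # "VM:12345 CPU:37.5% MEM:62.1% DISK:41.0% GPU:88.0%"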


class SystemMonitor:
    """Monitor system resources and GPU availability."""

    def __init__(self, fail_without_gpu: bool = False):
        """Initialize system monitor.

        Args:
            fail_without_gpu: Whether to terminate if GPU becomes unavailable.
        """
        self.fail_without_gpu = fail_without_gpu
        self.should_run = True
        self._thread: Optional[threading.Thread] = None

    def start(self) -> None:
        """Start system monitoring in a background thread."""
        if self._thread is not None:
            return

        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Stop system monitoring."""
        self.should_run = False
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def _monitor_loop(self) -> None:
        """Main monitoring loop."""
        while self.should_run:
            try:
                if self.fail_without_gpu and not is_gpu_available():
                    logger.error(
                        f"[{self._get_vm_id()}] GPU became unavailable. Terminating process."
                    )
                    self._kill_wandb()
                    # Force exit the entire process
                    os._exit(1)

                stats = self._get_system_stats()
                logger.info(str(stats))

                time.sleep(300)  # Check every 5 minutes
            except Exception as e:
                logger.error(f"Error in monitoring loop: {e}")
                time.sleep(300)  # Continue monitoring even if there's an error

    def _kill_wandb(self) -> None:
        """Kill the wandb agent process, if one is running."""
        try:
            subprocess.run(["pkill", "-f", "wandb agent"], check=True)
        except subprocess.CalledProcessError:
            pass  # Process might not exist

    def _get_vm_id(self) -> str:
        """Get the VM identifier from the GCP metadata server."""
        try:
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/id",
                headers={"Metadata-Flavor": "Google"},
                timeout=5,
            )
            return response.text
        except Exception:
            return "unknown"

    def _get_system_stats(self) -> SystemStats:
        """Collect current system statistics."""
        # Get timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Get VM ID
        vm_id = self._get_vm_id()

        # Get CPU usage
        cpu_cmd = "top -bn1 | grep 'Cpu(s)' | awk '{print $2 + $4}'"
        cpu_usage = float(subprocess.check_output(cpu_cmd, shell=True).decode().strip())

        # Get memory usage
        mem_cmd = "free | grep Mem | awk '{print $3/$2 * 100.0}'"
        mem_usage = float(subprocess.check_output(mem_cmd, shell=True).decode().strip())

        # Get GPU usage if available
        gpu_usage = None
        if is_gpu_available():
            try:
                gpu_cmd = "nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits"
                gpu_usage = float(
                    subprocess.check_output(gpu_cmd, shell=True).decode().strip()
                )
            except (subprocess.CalledProcessError, ValueError):
                pass

        # Get disk usage for the root filesystem
        disk_cmd = "df -h / | awk 'NR==2 {print $5}' | sed 's/%//'"
        disk_usage = float(
            subprocess.check_output(disk_cmd, shell=True).decode().strip()
        )

        return SystemStats(
            timestamp, vm_id, cpu_usage, mem_usage, gpu_usage, disk_usage
        )
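
# Hypothetical usage sketch (illustration only): run the monitor alongside a
# workload. Note that stop() joins the background thread, which can block for
# up to one 300-second sleep interval before the loop re-checks should_run.
#
#     monitor = SystemMonitor(fail_without_gpu=True)
#     monitor.start()
#     try:
#         run_training()  # placeholder for the actual workload
#     finally:
#         monitor.stop()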


def is_gpu_available() -> bool:
    """Check if an NVIDIA GPU is available.

    Returns:
        bool: True if a GPU is available and functioning
    """
    try:
        subprocess.run(["nvidia-smi"], check=True, capture_output=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def get_required_env_var(name: str) -> str:
    """Get an environment variable or raise an error if it's not set."""
    value = os.getenv(name)
    if value is None:
        raise ValueError(f"Required environment variable {name} is not set")
    return value


def authenticate_gcp() -> None:
    """Authenticate with Google Cloud Platform.

    Uses GOOGLE_APPLICATION_CREDENTIALS_BASE64 from the environment.
    Skips authentication if no credentials are available.
    """
    credentials_base64 = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_BASE64")
    if not credentials_base64:
        logger.info("No GCP credentials provided, skipping authentication")
        return

    logger.info("Authenticating with Google Cloud")

    # Decode and save credentials
    credentials = base64.b64decode(credentials_base64).decode("utf-8")
    with open("workerbee.json", "w") as f:
        f.write(credentials)

    # Activate service account (suppress survey output)
    subprocess.run(
        ["gcloud", "auth", "activate-service-account", "--key-file=workerbee.json"],
        check=True,
        capture_output=True,
    )

    # Configure project
    subprocess.run(
        ["gcloud", "config", "set", "project", "enzyme-discovery"],
        check=True,
        capture_output=True,
    )
    logger.info("Set project to enzyme-discovery")

    # Verify configuration
    subprocess.run(
        ["gcloud", "config", "get-value", "project"],
        check=True,
        capture_output=True,
    )

    # Get and log the active service account
    result = subprocess.run(
        ["gcloud", "auth", "list", "--filter=status:ACTIVE", "--format=value(account)"],
        check=True,
        capture_output=True,
        text=True,
    )
    logger.info(f"Activated service account credentials for: {result.stdout.strip()}")

    # Set an explicit credentials path if the key file exists
    if os.path.exists("workerbee.json"):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath("workerbee.json")
        logger.info(
            f"Set GOOGLE_APPLICATION_CREDENTIALS to {os.environ['GOOGLE_APPLICATION_CREDENTIALS']}"
        )
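
# Hypothetical sketch (illustration only) of how a launcher might produce the
# environment variable this function expects, given a service-account key file:
#
#     with open("workerbee.json", "rb") as f:
#         encoded = base64.b64encode(f.read()).decode("ascii")
#     os.environ["GOOGLE_APPLICATION_CREDENTIALS_BASE64"] = encoded
#     authenticate_gcp()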


def setup_dvc() -> None:
    """Initialize and configure DVC with a GitHub remote.

    Only runs if USE_DVC="true" in the environment.
    Requires GCP authentication to be already set up.
    """
    if os.getenv("USE_DVC", "").lower() != "true":
        logger.info("DVC not enabled, skipping setup")
        return

    if not os.path.exists("workerbee.json"):
        logger.info("GCP credentials not found, skipping DVC setup")
        return

    logger.info("Initializing DVC")

    # Initialize DVC without git
    subprocess.run(["dvc", "init", "--no-scm"], check=True, capture_output=True)

    # Get the GitHub PAT from GCP Secret Manager
    warehouse_pat = subprocess.run(
        [
            "gcloud",
            "secrets",
            "versions",
            "access",
            "latest",
            "--secret=warehouse-read-pat",
        ],
        check=True,
        capture_output=True,
        text=True,
    ).stdout.strip()

    # Configure the DVC remote
    subprocess.run(
        [
            "dvc",
            "remote",
            "add",
            "--default",
            "warehouse",
            "https://github.com/dayhofflabs/warehouse.git",
        ],
        check=True,
    )
    subprocess.run(
        ["dvc", "remote", "modify", "warehouse", "auth", "basic"], check=True
    )
    subprocess.run(
        [
            "dvc",
            "remote",
            "modify",
            "warehouse",
            "--local",
            "user",
            "DanielMartinAlarcon",
        ],
        check=True,
    )
    subprocess.run(
        ["dvc", "remote", "modify", "warehouse", "--local", "password", warehouse_pat],
        check=True,
    )

    # Set up GitHub HTTPS access
    git_config_path = Path.home() / ".git-credentials"
    git_config_path.write_text(
        f"https://DanielMartinAlarcon:{warehouse_pat}@github.com\n"
    )
    subprocess.run(
        ["git", "config", "--global", "credential.helper", "store"], check=True
    )
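
# Hypothetical container-startup sequence (illustration only); the order
# matters because setup_dvc() reads the workerbee.json file that
# authenticate_gcp() writes, and uses gcloud to fetch the remote's PAT.
#
#     os.environ["USE_DVC"] = "true"
#     authenticate_gcp()
#     setup_dvc()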


def setup_rxnfp() -> None:
    """Clone the rxnfp library.

    Only runs if USE_RXNFP="true" in the environment.
    """
    if os.getenv("USE_RXNFP", "").lower() != "true":
        logger.info("RXNFP not enabled, skipping setup")
        return

    logger.info("Cloning rxnfp library...")
    subprocess.run(
        ["git", "clone", "https://github.com/rxn4chemistry/rxnfp.git"], check=True
    )


def docker_login(registry: str, username: str, password: str) -> None:
    """Log in to a Docker registry using the provided credentials.

    Args:
        registry: Registry URI to log in to
        username: Username for registry authentication
        password: Password for registry authentication

    Raises:
        subprocess.CalledProcessError: If Docker login fails
    """
    # Create the .docker directory if it doesn't exist
    docker_config_dir = os.path.expanduser("~/.docker")
    os.makedirs(docker_config_dir, exist_ok=True)

    # Write a minimal config file that disables credential helpers
    with open(os.path.join(docker_config_dir, "config.json"), "w") as f:
        json.dump({"auths": {}, "credsStore": ""}, f)

    # Log in to Docker, passing the password on stdin
    login_process = subprocess.run(
        [
            "docker",
            "login",
            "--username",
            username,
            "--password-stdin",
            registry,
        ],
        input=password.encode(),
        capture_output=True,
        check=True,
    )

    if login_process.stderr:
        logger.warning(f"Docker login warning: {login_process.stderr.decode()}")
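
# Hypothetical usage sketch (illustration only): credentials would come from
# the environment or a secret store rather than source code; the registry URI
# and variable names below are placeholders.
#
#     docker_login(
#         registry="us-central1-docker.pkg.dev",  # placeholder registry
#         username=get_required_env_var("REGISTRY_USERNAME"),
#         password=get_required_env_var("REGISTRY_PASSWORD"),
#     )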


def get_container_env_vars(config: dict) -> dict:
    """Get all environment variables for the container.

    This function collects environment variables from multiple sources:
    1. Feature flags from config.features
    2. Feature-specific variables (e.g. WANDB_API_KEY, GCP credentials)
    3. Additional variables specified in config.env_vars

    Args:
        config: Configuration dictionary

    Returns:
        Dictionary of environment variables

    Raises:
        FileNotFoundError: If the GCP credentials file is not found when use_gcp_auth is True
    """
    env_vars = {}

    # Process the features section to set feature flags
    features = config.get("features")
    # Handle the case where the features key exists but has no value (None)
    if features is None:
        features = []

    # Boolean features that don't have associated credentials
    bool_features = {
        "use_dvc": "USE_DVC",
        "use_rxnfp": "USE_RXNFP",
        "fail_without_gpu": "FAIL_WITHOUT_GPU",
    }

    # Initialize all features to "false"
    for env_var in bool_features.values():
        env_vars[env_var] = "false"

    # Set enabled features to "true"
    for feature in features:
        if isinstance(feature, str) and feature in bool_features:
            env_vars[bool_features[feature]] = "true"
        elif isinstance(feature, dict):
            # Handle job_command if present
            if "job_command" in feature:
                env_vars["JOB_COMMAND"] = feature["job_command"]

    # Add feature-specific variables when their features are enabled
    features_set = {f if isinstance(f, str) else next(iter(f)) for f in features}

    if "use_wandb" in features_set:
        wandb_key = os.getenv("WANDB_API_KEY", "WANDB_API_KEY_NOT_SET")
        env_vars["WANDB_API_KEY"] = wandb_key
        # Don't log the key itself; API keys shouldn't appear in logs
        logger.info("Loading WANDB_API_KEY into container")

    if "use_gcp_auth" in features_set:
        key_file = ".config/workerbee.json"
        if not os.path.exists(key_file):
            raise FileNotFoundError(
                f"GCP credentials file not found: {key_file}. Required when use_gcp_auth=True"
            )

        # base64-encode the workerbee service account key file
        creds = subprocess.run(
            f"base64 {key_file} | tr -d '\n'",
            shell=True,
            check=True,
            capture_output=True,
            text=True,
        ).stdout
        env_vars["GOOGLE_APPLICATION_CREDENTIALS_BASE64"] = creds
        logger.info(f"Loaded GCP credentials into container: {key_file}")

    # Get environment variables from the config
    config_env_vars = config.get("env_vars")
    # Handle the case where the env_vars key exists but has no value (None)
    if config_env_vars is None:
        config_env_vars = {}

    # Merge them into the variables assembled here
    env_vars.update(config_env_vars)

    return env_vars
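
# Hypothetical config illustrating the expected shape (keys follow the
# docstring above; values are placeholders):
#
#     config = {
#         "features": ["use_dvc", "use_wandb", {"job_command": "python run.py"}],
#         "env_vars": {"LOG_LEVEL": "INFO"},
#     }
#     env_vars = get_container_env_vars(config)
#     # -> USE_DVC="true", USE_RXNFP="false", FAIL_WITHOUT_GPU="false",
#     #    JOB_COMMAND="python run.py", WANDB_API_KEY=..., LOG_LEVEL="INFO"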


def determine_worker_count(logger=None) -> int:
    """Determine the optimal worker count based on CPU cores and environment.

    Uses different strategies depending on the detected environment:
    - For batch environments (GCP Batch, AWS Batch), uses (CPU_COUNT - 1)
    - For development environments, uses (CPU_COUNT // 2)
    - Always ensures at least 1 worker is returned

    Args:
        logger: Optional logger to output the decision (defaults to None)

    Returns:
        int: Recommended number of worker processes
    """
    import multiprocessing

    # Detect CPU cores
    cpu_count = multiprocessing.cpu_count()

    # Detect if running in a cloud batch environment:
    # - GCP Batch sets BATCH_TASK_INDEX
    # - AWS Batch sets AWS_BATCH_JOB_ID
    is_batch_env = (
        os.getenv("BATCH_TASK_INDEX") is not None  # GCP Batch
        or os.getenv("AWS_BATCH_JOB_ID") is not None  # AWS Batch
    )

    if is_batch_env:
        # In a batch environment, use most of the cores
        num_workers = max(1, cpu_count - 1)
        if logger:
            logger.info(
                f"Batch environment detected. Using {num_workers} of {cpu_count} cores"
            )
    else:
        # In a dev environment, be more conservative
        num_workers = max(1, cpu_count // 2)
        if logger:
            logger.info(
                f"Development environment detected. Using {num_workers} of {cpu_count} cores"
            )

    return num_workers
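
# Hypothetical usage sketch (illustration only), e.g. sizing a process pool:
#
#     import multiprocessing
#
#     n = determine_worker_count(logger)
#     with multiprocessing.Pool(processes=n) as pool:
#         results = pool.map(process_item, items)  # placeholder workload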