dayhoff-tools 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dayhoff_tools/__init__.py +0 -0
- dayhoff_tools/chemistry/standardizer.py +297 -0
- dayhoff_tools/chemistry/utils.py +63 -0
- dayhoff_tools/cli/__init__.py +0 -0
- dayhoff_tools/cli/main.py +90 -0
- dayhoff_tools/cli/swarm_commands.py +156 -0
- dayhoff_tools/cli/utility_commands.py +244 -0
- dayhoff_tools/deployment/base.py +434 -0
- dayhoff_tools/deployment/deploy_aws.py +458 -0
- dayhoff_tools/deployment/deploy_gcp.py +176 -0
- dayhoff_tools/deployment/deploy_utils.py +781 -0
- dayhoff_tools/deployment/job_runner.py +153 -0
- dayhoff_tools/deployment/processors.py +125 -0
- dayhoff_tools/deployment/swarm.py +591 -0
- dayhoff_tools/embedders.py +893 -0
- dayhoff_tools/fasta.py +1082 -0
- dayhoff_tools/file_ops.py +261 -0
- dayhoff_tools/gcp.py +85 -0
- dayhoff_tools/h5.py +542 -0
- dayhoff_tools/kegg.py +37 -0
- dayhoff_tools/logs.py +27 -0
- dayhoff_tools/mmseqs.py +164 -0
- dayhoff_tools/sqlite.py +516 -0
- dayhoff_tools/structure.py +751 -0
- dayhoff_tools/uniprot.py +434 -0
- dayhoff_tools/warehouse.py +418 -0
- dayhoff_tools-1.0.0.dist-info/METADATA +122 -0
- dayhoff_tools-1.0.0.dist-info/RECORD +30 -0
- dayhoff_tools-1.0.0.dist-info/WHEEL +4 -0
- dayhoff_tools-1.0.0.dist-info/entry_points.txt +3 -0
dayhoff_tools/deployment/deploy_utils.py (new file)
@@ -0,0 +1,781 @@
"""Utility functions and classes for deployment and system monitoring.

This module provides utilities for:
1. System monitoring (GPU, CPU, memory)
2. Environment setup and authentication
3. DVC and repository configuration
4. Cloud instance metadata and identification
"""

import base64
import json
import logging
import os
import socket
import subprocess
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

import requests

logger = logging.getLogger(__name__)


def get_instance_metadata(timeout: int = 2) -> Dict[str, str]:
    """
    Get instance metadata from various cloud providers.

    This function attempts to retrieve metadata from GCP, AWS, and other cloud environments.
    It tries different metadata endpoints and returns a consolidated dictionary
    with information about the current cloud instance.

    Args:
        timeout: Timeout in seconds for HTTP requests

    Returns:
        Dictionary containing instance metadata with keys:
        - provider: Cloud provider name ('gcp', 'aws', 'azure', 'unknown')
        - instance_id: Unique instance identifier
        - instance_name: Name of the instance (if available)
        - instance_type: Type/size of the instance (e.g., 't2.micro', 'n1-standard-1')
        - region: Region where the instance is running
        - zone: Availability zone
    """
    metadata = {
        "provider": "unknown",
        "instance_id": "unknown",
        "instance_name": "unknown",
        "instance_type": "unknown",
        "region": "unknown",
        "zone": "unknown",
    }

    # Try GCP metadata
    try:
        headers = {"Metadata-Flavor": "Google"}
        # Check if we're in GCP
        response = requests.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/id",
            headers=headers,
            timeout=timeout,
        )
        if response.status_code == 200:
            metadata["provider"] = "gcp"
            metadata["instance_id"] = response.text

            # Get instance name
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/name",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["instance_name"] = response.text

            # Get machine type
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/machine-type",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                machine_type_path = response.text
                metadata["instance_type"] = machine_type_path.split("/")[-1]

            # Get zone
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/zone",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                zone_path = response.text
                metadata["zone"] = zone_path.split("/")[-1]
                # Extract region from zone (e.g., us-central1-a -> us-central1)
                if "-" in metadata["zone"]:
                    metadata["region"] = "-".join(metadata["zone"].split("-")[:-1])

            return metadata
    except Exception as e:
        logger.debug(f"Not a GCP instance or metadata server not available: {e}")

    # Try AWS metadata
    try:
        token_response = requests.put(
            "http://169.254.169.254/latest/api/token",
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=timeout,
        )
        if token_response.status_code == 200:
            token = token_response.text
            headers = {"X-aws-ec2-metadata-token": token}

            metadata["provider"] = "aws"

            # Get instance ID
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/instance-id",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["instance_id"] = response.text

            # Get instance type
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/instance-type",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["instance_type"] = response.text

            # Get availability zone
            response = requests.get(
                "http://169.254.169.254/latest/meta-data/placement/availability-zone",
                headers=headers,
                timeout=timeout,
            )
            if response.status_code == 200:
                metadata["zone"] = response.text
                # Extract region from availability zone (e.g., us-east-1a -> us-east-1)
                metadata["region"] = metadata["zone"][:-1]

            # AWS doesn't provide an instance name directly,
            # but we can use the hostname or instance-id as a fallback
            try:
                response = requests.get(
                    "http://169.254.169.254/latest/meta-data/hostname",
                    headers=headers,
                    timeout=timeout,
                )
                if response.status_code == 200:
                    metadata["instance_name"] = response.text
                else:
                    metadata["instance_name"] = metadata["instance_id"]
            except:
                metadata["instance_name"] = metadata["instance_id"]

            return metadata
    except Exception as e:
        logger.debug(f"Not an AWS EC2 instance or metadata server not available: {e}")

    # Try Azure metadata (if needed in the future)
    # ...

    # As a fallback, try to get some basic info from the host
    try:
        metadata["instance_name"] = socket.gethostname()
        # Check if we're running in a container environment
        if os.path.exists("/.dockerenv") or os.path.exists("/run/.containerenv"):
            metadata["provider"] = "container"
        # Check batch environment variables
        if os.getenv("BATCH_TASK_INDEX") is not None:
            metadata["provider"] = "gcp-batch"
            metadata["instance_name"] = f"batch-task-{os.getenv('BATCH_TASK_INDEX')}"
        elif os.getenv("AWS_BATCH_JOB_ID") is not None:
            metadata["provider"] = "aws-batch"
            metadata["instance_name"] = f"aws-batch-{os.getenv('AWS_BATCH_JOB_ID')}"
    except Exception as e:
        logger.debug(f"Error getting hostname: {e}")

    return metadata


def get_instance_name() -> str:
    """
    Get the name of the current cloud instance or VM.

    This is a cross-platform replacement for the old get_vm_name() function.
    Works with GCP, AWS, and other environments.

    Returns:
        A string containing the instance name, hostname, or ID
    """
    metadata = get_instance_metadata()
    return metadata["instance_name"]


def get_instance_type() -> str:
    """
    Get the machine type/size of the current cloud instance.

    This is a cross-platform replacement for the old get_vm_type() function.
    Works with GCP (e.g., n1-standard-1), AWS (e.g., t2.micro), and other environments.

    Returns:
        A string representing the instance type/size, or 'unknown' if not available
    """
    metadata = get_instance_metadata()
    return metadata["instance_type"]


def get_cloud_provider() -> str:
    """
    Get the cloud provider where this code is running.

    Returns:
        A string identifying the cloud provider ('gcp', 'aws', 'azure', 'container', 'unknown')
    """
    metadata = get_instance_metadata()
    return metadata["provider"]
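The three wrappers above each re-query the same metadata lookup. A minimal usage sketch (illustrative only, not part of deploy_utils.py):

# Illustrative usage of the metadata helpers defined above (not part of the packaged file).
from dayhoff_tools.deployment.deploy_utils import (
    get_cloud_provider,
    get_instance_metadata,
    get_instance_name,
)

metadata = get_instance_metadata(timeout=2)
print(f"Running on {metadata['provider']} in region {metadata['region']}")

# Convenience wrappers; each call hits the metadata endpoints again.
print(get_instance_name(), get_cloud_provider())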


def move_to_repo_root() -> None:
    """Move to the repository root directory.

    Determines the repository root through multiple methods (in order):
    1. Direct specification via REPO_ROOT environment variable
    2. Standard marker files (.git, setup.py, pyproject.toml)
    3. Container standard paths (/app if it exists and contains expected files)
    4. NAME_OF_THIS_REPO environment variable (for VM environments)

    Raises:
        OSError: If repository root cannot be determined
    """
    try:
        # Check if REPO_ROOT is directly specified
        if "REPO_ROOT" in os.environ:
            root_path = Path(os.environ["REPO_ROOT"])
            if root_path.exists():
                logger.info(f"Using environment-specified REPO_ROOT: {root_path}")
                os.chdir(root_path)
                return

        # Try to find repo root by looking for standard files
        current = Path.cwd()
        while current != current.parent:
            if any(
                (current / marker).exists()
                for marker in [".git", "setup.py", "pyproject.toml"]
            ):
                logger.info(f"Found repository root at: {current}")
                os.chdir(current)
                return
            current = current.parent

        # Check for container standard paths
        container_paths = ["/app", "/workspace", "/code"]
        for path in container_paths:
            container_root = Path(path)
            if container_root.exists() and any(
                (container_root / subdir).exists()
                for subdir in ["src", "swarm", "dayhoff_tools"]
            ):
                logger.info(f"Using container standard path: {container_root}")
                os.chdir(container_root)
                return

        # Fallback to environment variable if available
        try:
            name_of_this_repo = os.environ["NAME_OF_THIS_REPO"]
            root_path = Path(f"/workspaces/{name_of_this_repo}")
            if root_path.exists():
                logger.info(f"Using workspace repository path: {root_path}")
                os.chdir(root_path)
                return
        except KeyError as e:
            logger.warning(f"NAME_OF_THIS_REPO environment variable not set: {e}")

        # If we're already at what looks like a valid root, just stay here
        if any(
            Path.cwd().joinpath(marker).exists()
            for marker in ["setup.py", "pyproject.toml", "src", "dayhoff_tools"]
        ):
            logger.info(f"Current directory appears to be a valid root: {Path.cwd()}")
            return

        raise OSError("Could not determine repository root")
    except Exception as e:
        logger.error(f"ERROR: Could not move to repository root: {e}")
        raise


def upload_folder_to_gcs(local_folder: str, bucket, gcs_folder: str):
    """
    Upload all files from a local folder to a GCS folder

    Args:
        local_folder: Path to the local folder
        bucket: GCS bucket object
        gcs_folder: Destination folder path in GCS
    """
    local_path = Path(local_folder)

    for local_file in local_path.glob("**/*"):
        if local_file.is_file():
            # Construct the GCS path
            relative_path = local_file.relative_to(local_path)
            gcs_path = f"{gcs_folder.rstrip('/')}/{relative_path}"

            # Upload the file
            blob = bucket.blob(gcs_path)
            blob.upload_from_filename(str(local_file))
            logger.info(f"Uploaded {local_file} to {gcs_path}")


@dataclass
class SystemStats:
    """Container for system statistics."""

    timestamp: str
    vm_id: str
    cpu_usage: float
    mem_usage: float
    gpu_usage: Optional[float]
    disk_usage: float

    def __str__(self) -> str:
        """Format system stats for logging."""
        return (
            f"VM:{self.vm_id} "
            f"CPU:{self.cpu_usage:.1f}% "
            f"MEM:{self.mem_usage:.1f}% "
            f"DISK:{self.disk_usage:.1f}% "
            f"GPU:{self.gpu_usage if self.gpu_usage is not None else 'N/A'}%"
        )


class SystemMonitor:
    """Monitor system resources and GPU availability."""

    def __init__(self, fail_without_gpu: bool = False):
        """Initialize system monitor.

        Args:
            fail_without_gpu: Whether to terminate if GPU becomes unavailable.
        """
        self.fail_without_gpu = fail_without_gpu
        self.should_run = True
        self._thread: Optional[threading.Thread] = None

    def start(self) -> None:
        """Start system monitoring in a background thread."""
        if self._thread is not None:
            return

        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Stop system monitoring."""
        self.should_run = False
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def _monitor_loop(self) -> None:
        """Main monitoring loop."""
        while self.should_run:
            try:
                if self.fail_without_gpu and not is_gpu_available():
                    logger.error(
                        f"[{self._get_vm_id()}] GPU became unavailable. Terminating process."
                    )
                    self._kill_wandb()
                    # Force exit the entire process
                    os._exit(1)

                stats = self._get_system_stats()
                logger.info(str(stats))

                time.sleep(300)  # Check every 5 minutes
            except Exception as e:
                logger.error(f"Error in monitoring loop: {e}")
                time.sleep(300)  # Continue monitoring even if there's an error

    def _kill_wandb(self) -> None:
        """Kill wandb agent process."""
        try:
            subprocess.run(["pkill", "-f", "wandb agent"], check=True)
        except subprocess.CalledProcessError:
            pass  # Process might not exist

    def _get_vm_id(self) -> str:
        """Get VM identifier from GCP metadata server."""
        try:
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/id",
                headers={"Metadata-Flavor": "Google"},
                timeout=5,
            )
            return response.text
        except Exception:
            return "unknown"

    def _get_system_stats(self) -> SystemStats:
        """Collect current system statistics."""
        # Get timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Get VM ID
        vm_id = self._get_vm_id()

        # Get CPU usage
        cpu_cmd = "top -bn1 | grep 'Cpu(s)' | awk '{print $2 + $4}'"
        cpu_usage = float(subprocess.check_output(cpu_cmd, shell=True).decode().strip())

        # Get memory usage
        mem_cmd = "free | grep Mem | awk '{print $3/$2 * 100.0}'"
        mem_usage = float(subprocess.check_output(mem_cmd, shell=True).decode().strip())

        # Get GPU usage if available
        gpu_usage = None
        if is_gpu_available():
            try:
                gpu_cmd = "nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits"
                gpu_usage = float(
                    subprocess.check_output(gpu_cmd, shell=True).decode().strip()
                )
            except (subprocess.CalledProcessError, ValueError):
                pass

        # Add disk usage check
        disk_cmd = "df -h / | awk 'NR==2 {print $5}' | sed 's/%//'"
        disk_usage = float(
            subprocess.check_output(disk_cmd, shell=True).decode().strip()
        )

        return SystemStats(
            timestamp, vm_id, cpu_usage, mem_usage, gpu_usage, disk_usage
        )
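A minimal sketch of how SystemMonitor might wrap a long-running job (illustrative; run_training is a placeholder, not a package function):

# Illustrative only — run_training() stands in for the real workload.
from dayhoff_tools.deployment.deploy_utils import SystemMonitor

monitor = SystemMonitor(fail_without_gpu=True)
monitor.start()  # logs SystemStats every 5 minutes from a daemon thread
try:
    run_training()  # placeholder
finally:
    # stop() joins the monitoring thread, so it may wait out the current
    # 300-second sleep before returning.
    monitor.stop()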


def is_gpu_available() -> bool:
    """Check if NVIDIA GPU is available.

    Returns:
        bool: True if GPU is available and functioning
    """
    try:
        subprocess.run(["nvidia-smi"], check=True, capture_output=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def get_required_env_var(name: str) -> str:
    """Get an environment variable or raise an error if it's not set"""
    value = os.getenv(name)
    if value is None:
        raise ValueError(f"Required environment variable {name} is not set")
    return value


def authenticate_gcp() -> None:
    """Authenticate with Google Cloud Platform.

    Uses GOOGLE_APPLICATION_CREDENTIALS_BASE64 from environment.
    Skips if no credentials are available.
    """
    credentials_base64 = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_BASE64")
    if not credentials_base64:
        logger.info("No GCP credentials provided, skipping authentication")
        return

    logger.info("Authenticating with Google Cloud")

    # Decode and save credentials
    credentials = base64.b64decode(credentials_base64).decode("utf-8")
    with open("workerbee.json", "w") as f:
        f.write(credentials)

    # Activate service account (suppress survey output)
    subprocess.run(
        ["gcloud", "auth", "activate-service-account", "--key-file=workerbee.json"],
        check=True,
        capture_output=True,
    )

    # Configure project
    subprocess.run(
        ["gcloud", "config", "set", "project", "enzyme-discovery"],
        check=True,
        capture_output=True,
    )
    logger.info("Set project to enzyme-discovery")

    # Verify configuration
    subprocess.run(
        ["gcloud", "config", "get-value", "project"],
        check=True,
        capture_output=True,
    )

    # Get and print active service account
    result = subprocess.run(
        ["gcloud", "auth", "list", "--filter=status:ACTIVE", "--format=value(account)"],
        check=True,
        capture_output=True,
        text=True,
    )
    logger.info(f"Activated service account credentials for: {result.stdout.strip()}")

    # Set explicit credentials path if it exists
    if os.path.exists("workerbee.json"):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath("workerbee.json")
        logger.info(
            f"Set GOOGLE_APPLICATION_CREDENTIALS to {os.environ['GOOGLE_APPLICATION_CREDENTIALS']}"
        )


def setup_dvc() -> None:
    """Initialize and configure DVC with GitHub remote.

    Only runs if USE_DVC="true" in environment.
    Requires GCP authentication to be already set up.
    """
    if os.getenv("USE_DVC", "").lower() != "true":
        logger.info("DVC not enabled, skipping setup")
        return

    if not os.path.exists("workerbee.json"):
        logger.info("GCP credentials not found, skipping DVC setup")
        return

    logger.info("Initializing DVC")

    # Initialize DVC without git
    subprocess.run(["dvc", "init", "--no-scm"], check=True, capture_output=True)

    # Get GitHub PAT from GCP secrets
    warehouse_pat = subprocess.run(
        [
            "gcloud",
            "secrets",
            "versions",
            "access",
            "latest",
            "--secret=warehouse-read-pat",
        ],
        check=True,
        capture_output=True,
        text=True,
    ).stdout.strip()

    # Configure DVC remote
    subprocess.run(
        [
            "dvc",
            "remote",
            "add",
            "--default",
            "warehouse",
            "https://github.com/dayhofflabs/warehouse.git",
        ],
        check=True,
    )
    subprocess.run(
        ["dvc", "remote", "modify", "warehouse", "auth", "basic"], check=True
    )
    subprocess.run(
        [
            "dvc",
            "remote",
            "modify",
            "warehouse",
            "--local",
            "user",
            "DanielMartinAlarcon",
        ],
        check=True,
    )
    subprocess.run(
        ["dvc", "remote", "modify", "warehouse", "--local", "password", warehouse_pat],
        check=True,
    )

    # Setup GitHub HTTPS access
    git_config_path = Path.home() / ".git-credentials"
    git_config_path.write_text(
        f"https://DanielMartinAlarcon:{warehouse_pat}@github.com\n"
    )
    subprocess.run(
        ["git", "config", "--global", "credential.helper", "store"], check=True
    )
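Taken together, the docstrings above imply a container-startup sequence roughly like the following (a sketch under that assumption; the package's actual entrypoint ordering is defined elsewhere, e.g. in job_runner.py):

# Illustrative startup sequence — the ordering is an assumption, not the package's entrypoint.
from dayhoff_tools.deployment.deploy_utils import (
    authenticate_gcp,
    move_to_repo_root,
    setup_dvc,
)

move_to_repo_root()   # cd to the repo root inside the container
authenticate_gcp()    # uses GOOGLE_APPLICATION_CREDENTIALS_BASE64 if present
setup_dvc()           # no-op unless USE_DVC="true"; needs the GCP credentials above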


def setup_rxnfp() -> None:
    """Clone rxnfp library.

    Only runs if USE_RXNFP="true" in environment.
    """
    if os.getenv("USE_RXNFP", "").lower() != "true":
        logger.info("RXNFP not enabled, skipping setup")
        return

    logger.info("Cloning rxnfp library...")
    subprocess.run(
        ["git", "clone", "https://github.com/rxn4chemistry/rxnfp.git"], check=True
    )


def docker_login(registry: str, username: str, password: str) -> None:
    """Login to a Docker registry using provided credentials.

    Args:
        registry: Registry URI to login to
        username: Username for registry authentication
        password: Password for registry authentication

    Raises:
        subprocess.CalledProcessError: If Docker login fails
    """
    # Create .docker directory if it doesn't exist
    docker_config_dir = os.path.expanduser("~/.docker")
    os.makedirs(docker_config_dir, exist_ok=True)

    # Write a minimal config file that disables credential helpers
    with open(os.path.join(docker_config_dir, "config.json"), "w") as f:
        json.dump({"auths": {}, "credsStore": ""}, f)

    # Login to Docker using the credentials
    login_process = subprocess.run(
        [
            "docker",
            "login",
            "--username",
            username,
            "--password-stdin",
            registry,
        ],
        input=password.encode(),
        capture_output=True,
        check=True,
    )

    if login_process.stderr:
        logger.warning(f"Docker login warning: {login_process.stderr.decode()}")
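docker_login passes the password over stdin rather than on the command line. A hedged usage sketch (the registry URI and environment variable name are hypothetical):

# Illustrative only — registry URI and ECR_PASSWORD are hypothetical.
from dayhoff_tools.deployment.deploy_utils import docker_login, get_required_env_var

docker_login(
    registry="123456789012.dkr.ecr.us-east-1.amazonaws.com",  # hypothetical ECR registry
    username="AWS",
    password=get_required_env_var("ECR_PASSWORD"),  # hypothetical variable name
)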


def get_container_env_vars(config: dict) -> dict:
    """Get all environment variables for the container.

    This function collects environment variables from multiple sources:
    1. Feature flags from config.features
    2. Feature-specific variables (e.g. WANDB_API_KEY, GCP credentials)
    3. Additional variables specified in config.env_vars

    Args:
        config: Configuration dictionary

    Returns:
        Dictionary of environment variables

    Raises:
        FileNotFoundError: If GCP credentials file is not found when use_gcp_auth is True
    """
    env_vars = {}

    # Process features section to set feature flags
    features = config.get("features")
    # Handle case where features key exists but has no value (None)
    if features is None:
        features = []

    # Handle boolean features that don't have associated credentials
    bool_features = {
        "use_dvc": "USE_DVC",
        "use_rxnfp": "USE_RXNFP",
        "fail_without_gpu": "FAIL_WITHOUT_GPU",
    }

    # Initialize all features to "false"
    for env_var in bool_features.values():
        env_vars[env_var] = "false"

    # Set enabled features to "true"
    for feature in features:
        if isinstance(feature, str) and feature in bool_features:
            env_vars[bool_features[feature]] = "true"
        elif isinstance(feature, dict):
            # Handle job_command if present
            if "job_command" in feature:
                env_vars["JOB_COMMAND"] = feature["job_command"]

    # Add feature-specific variables when their features are enabled
    features_set = {f if isinstance(f, str) else next(iter(f)) for f in features}

    if "use_wandb" in features_set:
        wandb_key = os.getenv("WANDB_API_KEY", "WANDB_API_KEY_NOT_SET")
        env_vars["WANDB_API_KEY"] = wandb_key
        logger.info(f"Loading WANDB_API_KEY into container: {wandb_key}")

    if "use_gcp_auth" in features_set:
        key_file = ".config/workerbee.json"
        if not os.path.exists(key_file):
            raise FileNotFoundError(
                f"GCP credentials file not found: {key_file}. Required when use_gcp_auth=True"
            )

        # base64-encode the workerbee service account key file
        creds = subprocess.run(
            f"base64 {key_file} | tr -d '\n'",
            shell=True,
            check=True,
            capture_output=True,
            text=True,
        ).stdout
        env_vars["GOOGLE_APPLICATION_CREDENTIALS_BASE64"] = creds
        logger.info(f"Loaded GCP credentials into container: {key_file}")

    # Get environment variables from config
    config_env_vars = config.get("env_vars")
    # Handle case where env_vars key exists but has no value (None)
    if config_env_vars is None:
        config_env_vars = {}

    # Add them to the envvars made here
    env_vars.update(config_env_vars)

    return env_vars
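For example, a config with one boolean feature, a job_command entry, and a pass-through variable would be mapped like this (illustrative values; only keys the function above reads are shown):

# Illustrative config — values are hypothetical.
from dayhoff_tools.deployment.deploy_utils import get_container_env_vars

config = {
    "features": [
        "use_dvc",                          # -> USE_DVC="true"
        {"job_command": "python run.py"},   # -> JOB_COMMAND="python run.py"
    ],
    "env_vars": {"RUN_NAME": "demo"},       # copied through verbatim
}
env = get_container_env_vars(config)
# env also contains USE_RXNFP="false" and FAIL_WITHOUT_GPU="false",
# since unlisted boolean features are initialized to "false".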


def determine_worker_count(logger=None) -> int:
    """Determine optimal worker count based on CPU cores and environment.

    Uses different strategies depending on the detected environment:
    - For batch environments (GCP Batch, AWS Batch), uses (CPU_COUNT - 1)
    - For development environments, uses (CPU_COUNT // 2)
    - Always ensures at least 1 worker is returned

    Args:
        logger: Optional logger to output the decision (defaults to None)

    Returns:
        int: Recommended number of worker processes
    """
    import multiprocessing
    import os

    # Detect CPU cores
    cpu_count = multiprocessing.cpu_count()

    # Detect if running in a cloud batch environment:
    # - GCP Batch sets BATCH_TASK_INDEX
    # - AWS Batch sets AWS_BATCH_JOB_ID
    is_batch_env = (
        os.getenv("BATCH_TASK_INDEX") is not None  # GCP Batch
        or os.getenv("AWS_BATCH_JOB_ID") is not None  # AWS Batch
    )

    if is_batch_env:
        # In batch environment, use most cores
        num_workers = max(1, cpu_count - 1)
        if logger:
            logger.info(
                f"Batch environment detected. Using {num_workers} of {cpu_count} cores"
            )
    else:
        # In dev environment, be more conservative
        num_workers = max(1, cpu_count // 2)
        if logger:
            logger.info(
                f"Development environment detected. Using {num_workers} of {cpu_count} cores"
            )

    return num_workers
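A closing usage sketch for determine_worker_count (illustrative; the worker function and item list are placeholders):

# Illustrative only — process_one() and the input range are placeholders.
import logging
from multiprocessing import Pool

from dayhoff_tools.deployment.deploy_utils import determine_worker_count

logging.basicConfig(level=logging.INFO)
n_workers = determine_worker_count(logger=logging.getLogger(__name__))

def process_one(item):
    return item * 2  # placeholder work

if __name__ == "__main__":
    with Pool(processes=n_workers) as pool:
        results = pool.map(process_one, range(100))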