dayhoff-tools 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,591 @@
1
+ """
2
+ Gets the data that the current worker needs.
3
+ Interacts with Firebase (where the data is listed)
4
+ and GCS (where it is hosted).
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ import signal
12
+ import threading
13
+ import time
14
+ from datetime import datetime
15
+ from typing import List, Set, Tuple
16
+ from zoneinfo import ZoneInfo
17
+
18
+ import firebase_admin
19
+ import requests
20
+ from dayhoff_tools.deployment.deploy_utils import get_instance_name, get_instance_type
21
+ from dayhoff_tools.deployment.processors import Processor
22
+ from firebase_admin import firestore
23
+ from google.cloud import storage
24
+ from google.cloud.firestore import transactional
25
+ from google.cloud.firestore_v1.base_query import FieldFilter
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # AWS IMDSv2 endpoints
30
+ AWS_IMDS_TOKEN_URL = "http://169.254.169.254/latest/api/token"
31
+ AWS_INSTANCE_ACTION_URL = "http://169.254.169.254/latest/meta-data/spot/instance-action"
32
+
33
+ # GCP preemption endpoints
34
+ GCP_PREEMPTION_URL = (
35
+ "http://metadata.google.internal/computeMetadata/v1/instance/preempted"
36
+ )
37
+ GCP_PREEMPTION_WARNING_URL = (
38
+ "http://metadata.google.internal/computeMetadata/v1/instance/maintenance-event"
39
+ )
40
+
41
+ # Shutdown signal received flag (shared between threads)
42
+ _shutdown_requested = threading.Event()
43
+
44
+
45
def initialize_firebase():
    """Initialize the default Firebase app if it is not already initialized.

    Safe to call repeatedly: an existing default app is detected and reused.
    """
    try:
        # get_app() raises ValueError when no default app exists yet.
        firebase_admin.get_app()
    except ValueError:
        # First caller in this process: create the default app.
        firebase_admin.initialize_app()
52
+
53
+
54
def assert_sequential_files_in_gcs_folder(
    bucket_name: str, folder_name: str, expected_number: int | None = None
) -> Tuple[List[str], Set[str]]:
    """
    Check a GCS folder for sequential files that have names ending with '_x' where x is an int,
    regardless of file extension.
    Determine if any files are missing or duplicate.
    The number of files is inferred from the name of the last one, unless provided.

    Args:
        bucket_name (str): Name of the GCS bucket
        folder_name (str): Name of the folder within the bucket
        expected_number (int | None): The expected number of files in the folder.

    Returns:
        Tuple[List[str], Set[str]]: Sorted list of missing files, set of duplicate files.
    """
    # Initialize the GCS client
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)

    # Ensure folder_name ends with a slash
    folder_name = folder_name.rstrip("/") + "/"

    # Compile once and hoist out of the loop: matches a trailing "_<digits>"
    # with an optional single extension (e.g. "packet_12" or "packet_12.fasta").
    pattern = re.compile(r"_(\d+)(?:\.[^.]+)?$")

    # Map each trailing number to every blob name that carries it.
    file_numbers: dict[int, List[str]] = {}
    max_number = 0
    # Iterate the blob pager directly; no need to materialize a list first.
    for blob in bucket.list_blobs(prefix=folder_name):
        # blob.name is Optional in the type stubs; skip defensively.
        blob_name = blob.name
        if blob_name is None:
            continue

        match = pattern.search(str(blob_name))
        if match:
            number = int(match.group(1))
            max_number = max(max_number, number)
            file_numbers.setdefault(number, []).append(str(blob_name))

    # Determine the expected number of files from the highest number seen,
    # if the caller did not supply one.
    if expected_number is None:
        expected_number = max_number

    # Find missing and duplicate files. Missing entries are reported with a
    # synthetic "file_<i>" name, since the real extension is unknown.
    missing_files: List[str] = []
    duplicate_files: Set[str] = set()

    for i in range(1, expected_number + 1):
        names = file_numbers.get(i)
        if names is None:
            missing_files.append(f"{folder_name}file_{i}")
        elif len(names) > 1:
            duplicate_files.update(names)

    print(f"Files expected: {expected_number}")
    print(f"Files missing: {missing_files}")
    print(f"Files duplicated: {duplicate_files}")

    return missing_files, duplicate_files
125
+
126
+
127
def publish_cards(
    names: List[str],
    firestore_collection: str,
):
    """Publish cards to Firebase. Expects a list of filenames (not full paths),
    which will each be published as a new document in the collection."""

    initialize_firebase()
    collection = firestore.client().collection(firestore_collection)

    # Timestamps are recorded in Pacific time, matching the rest of this module.
    pacific = ZoneInfo("America/Los_Angeles")

    for packet_name in names:
        card = {
            "status": "available",
            "packet_filename": packet_name,
            "created": datetime.now(pacific),
        }
        # An auto-generated document ID is used for each new card.
        collection.document().set(card)
        print(f"Creating card {packet_name}")

    print(f"Uploaded {len(names)} cards.")
148
+
149
+
150
@transactional
def _assign_card_in_transaction(
    transaction,
    query,
    collection,
    vm_name: str,
    vm_type: str,
    batch_worker: str,
):
    """Draw an `available` card and update it to `assigned`.
    Do so in an atomic transaction.

    This function can't be a class function because the decorator
    expects a transaction (not self) as the first argument.

    Args:
        transaction: The Firestore transaction supplied by @transactional.
        query: A query for a single `available` card.
        collection: The collection the card documents live in.
        vm_name: Name of this worker's instance.
        vm_type: Machine type of this worker's instance.
        batch_worker: BATCH_TASK_INDEX value (or a placeholder string).

    Returns:
        A DocumentReference to the assigned card, or None if no card
        was available.
    """

    # Exit the function if no documents are found
    query_output = list(query.stream(transaction=transaction))
    if not query_output:
        logger.error("No cards are available.")
        return None

    # Read everything we need from the snapshot that was fetched inside the
    # transaction. The previous version issued a second card_reference.get()
    # here, which is a separate NON-transactional read (extra RPC, and a
    # window for stale data); the snapshot already has the field.
    snapshot = query_output[0]
    card_reference = collection.document(snapshot.id)
    packet_filename = snapshot.get("packet_filename")

    # Update the document within the transaction
    now = datetime.now(ZoneInfo("America/Los_Angeles"))
    transaction.update(
        card_reference,
        {
            "status": "assigned",
            "packet_filename": packet_filename,
            "vm_name": vm_name,
            "vm_type": vm_type,
            "batch_index": batch_worker,
            "first_update": now,
            "last_update": now,
        },
    )
    return card_reference
190
+
191
+
192
class FirestoreService:
    """Handles all Firestore database operations.

    Each worker draws a "card" (a Firestore document describing one data
    packet), keeps it alive with a background heartbeat thread, and finally
    marks it processed, failed, or released.
    """

    def __init__(self, collection_name: str, tachycardic: bool = False):
        """Connect to Firestore and record this worker's identity.

        Args:
            collection_name: Name of the Firestore collection holding cards.
            tachycardic: If True, heartbeat every ~2 seconds instead of every
                5 minutes (useful for tests).
        """
        initialize_firebase()
        self.db = firestore.client()
        self.collection = self.db.collection(collection_name)
        self.vm_name = get_instance_name()
        self.vm_type = get_instance_type()
        # Batch task index when running under Batch; otherwise a marker string.
        self.batch_worker = os.getenv("BATCH_TASK_INDEX", "Not processed in Batch")
        # DocumentReference of the card currently assigned to this worker.
        self.current_card = None
        # Event used to tell the heartbeat thread to stop.
        self.heartstopper = threading.Event()
        self.tachycardic = tachycardic

    def draw_card(self) -> str | None:
        """Draw an `available` card from Firestore and update it to `assigned`.
        Return the packet_filename contained in that card."""
        # Start a transaction and use the wrapper function
        transaction = self.db.transaction()
        query = self.collection.where(
            filter=FieldFilter("status", "==", "available")
        ).limit(1)
        self.current_card = _assign_card_in_transaction(
            transaction,
            query,
            self.collection,
            self.vm_name,
            self.vm_type,
            self.batch_worker,
        )
        if self.current_card is None:
            return None
        else:
            # NOTE(review): this is a second, non-transactional read of the
            # card just to recover the filename — confirm it cannot race with
            # another writer between the transaction commit and this get().
            packet_filename = str(self.current_card.get().get("packet_filename"))
            logger.info("Card drawn: %s", packet_filename)
            return packet_filename

    def start_heart(self):
        """Start the heartbeat thread for the current card."""
        logger.info("Heartbeat started.")
        # Fresh Event each time, so a previous stop_heart() doesn't
        # immediately stop the new thread.
        self.heartstopper = threading.Event()
        self.heartbeat_thread = threading.Thread(target=self._send_heartbeat)
        self.heartbeat_thread.start()

    def close_card(self):
        """Update the current card to `processed`."""
        if not self.current_card:
            logger.info("No current card to close")
            return

        self.current_card.update(
            {
                "status": "processed",
                "last_update": datetime.now(ZoneInfo("America/Los_Angeles")),
            }
        )

    def stop_heart(self):
        """Stop the heartbeat thread and wait for it to exit."""
        if self.heartstopper is None:
            logger.info("No current heartbeat to stop")
            return
        self.heartstopper.set()
        # Block until the heartbeat loop observes the event and returns.
        self.heartbeat_thread.join()

    def record_failure(self, failure: str):
        """Record the cause of death for the current card.

        Args:
            failure: Human-readable description of the error.
        """
        if not self.current_card:
            logger.info("No current card for which to record failure")
            return

        self.current_card.update(  # type: ignore
            {
                "status": "failed",
                "last_update": datetime.now(ZoneInfo("America/Los_Angeles")),
                "failure": failure,
            }
        )

    def _send_heartbeat(self):
        """Periodically update the 'last_update' field for the current card.
        Listen for a heartstopper signal.

        Runs in its own thread. Sleeps in short intervals so that stop_heart()
        is noticed quickly, but only writes to Firestore every
        `update_interval` seconds.
        """

        if self.tachycardic:
            update_interval = 2
            sleep_interval = 1
        else:
            update_interval = 5 * 60  # 5 minutes expressed in seconds
            sleep_interval = 5  # Check the heartstopper every 5 seconds

        counter = 0
        while (
            self.current_card and self.heartstopper and not self.heartstopper.is_set()
        ):
            if counter >= update_interval:
                try:
                    # Perform the database update
                    self.current_card.update(
                        {
                            "last_update": datetime.now(
                                ZoneInfo("America/Los_Angeles")
                            ),
                        }
                    )
                    logger.debug("Heartbeat updated.")
                    counter = 0  # Reset counter after update
                except Exception as e:
                    logger.error(f"Error updating heartbeat: {e}")
                    break  # Exit the loop if there's an error

            time.sleep(sleep_interval)  # Sleep for the short interval

            # Increment counter by the number of seconds slept
            counter += sleep_interval

        logger.info("Heartbeat stopped.")

    def release_card(self):
        """Release the current card back to 'available' state.
        Used when a spot instance is being terminated to allow another worker to pick it up.
        """
        if not self.current_card:
            logger.info("No current card to release")
            return

        packet_filename = self.current_card.get().get("packet_filename")

        # Create a new available card so another worker can pick up the packet.
        self.db.collection(self.collection.id).document().set(
            {
                "status": "available",
                "packet_filename": packet_filename,
                "created": datetime.now(ZoneInfo("America/Los_Angeles")),
                "released_from": self.vm_name,
            }
        )

        # Update the current card to show it was released due to termination
        self.current_card.update(
            {
                "status": "released",
                "last_update": datetime.now(ZoneInfo("America/Los_Angeles")),
                "release_reason": "spot_termination",
            }
        )

        logger.info(
            f"Card {packet_filename} released back to available pool due to spot termination"
        )
        self.current_card = None
342
+
343
+
344
class GCSService:
    """Handles all GCS operations"""

    def __init__(
        self,
        bucket_name: str,
        gcs_input_folder: str,
        gcs_output_folder: str,
    ):
        """Bind to a bucket and remember the input/output folder prefixes."""
        client = storage.Client()
        self.bucket = client.bucket(bucket_name)
        self.gcs_input_folder = gcs_input_folder
        self.gcs_output_folder = gcs_output_folder

    def download_data(self, packet_filename: str):
        """Download the data from GCS."""
        logger.info(f"Downloading {packet_filename}")

        # Fetch the packet into the working directory under its original name.
        blob_path = f"{self.gcs_input_folder}/{packet_filename}"
        self.bucket.blob(blob_path).download_to_filename(packet_filename)

    def upload_results(self, path_to_upload: str):
        """Upload the results back to GCS.

        Args:
            path_to_upload: Path to the file or folder to upload
        """
        logger.info(f"Uploading {path_to_upload}")

        if os.path.isdir(path_to_upload):
            # Folder result: mirror the whole tree under the output prefix.
            gcs_folder = f"{self.gcs_output_folder}/{os.path.basename(path_to_upload)}"
            from dayhoff_tools.deployment.deploy_utils import upload_folder_to_gcs

            upload_folder_to_gcs(path_to_upload, self.bucket, gcs_folder)
            logger.info(f"Uploaded folder {path_to_upload} to {gcs_folder}")
            return

        if os.path.isfile(path_to_upload):
            # Single-file result: upload it directly under the output prefix.
            gcs_path = f"{self.gcs_output_folder}/{path_to_upload}"
            self.bucket.blob(gcs_path).upload_from_filename(path_to_upload)
            logger.info(f"Uploaded file {path_to_upload} to {gcs_path}")
            return

        raise FileNotFoundError(
            f"The path {path_to_upload} does not exist or is not accessible"
        )
391
+
392
+
393
class Operator:
    """Communicates with the Firestore and GCS services to constantly get
    data, then assigns it to the processor. Also manages errors and
    sending a constant heartbeat to Firestore."""

    def __init__(
        self,
        firestore_service: FirestoreService,
        gcs_service: GCSService,
        processor: Processor,
    ):
        """Wire up the services and install spot-termination handlers.

        Args:
            firestore_service: Card bookkeeping (draw/close/release/heartbeat).
            gcs_service: Packet download and result upload.
            processor: Does the actual work on each downloaded packet.
        """
        self.firestore = firestore_service
        self.gcs = gcs_service
        self.processor = processor
        # Daemon thread polling cloud metadata for termination notices.
        self.termination_checker = None
        self._setup_shutdown_handlers()

    def _setup_shutdown_handlers(self):
        """Set up handlers for spot instance termination and system shutdown signals."""
        # Register signal handlers for graceful shutdown
        signal.signal(signal.SIGTERM, self._handle_shutdown_signal)
        signal.signal(signal.SIGINT, self._handle_shutdown_signal)

        # Start a thread to check for AWS spot instance termination.
        # Daemon=True so it never blocks process exit.
        self.termination_checker = threading.Thread(target=self._check_for_termination)
        self.termination_checker.daemon = True
        self.termination_checker.start()

        logger.info("Spot instance termination handlers initialized")

    def _handle_shutdown_signal(self, signum, frame):
        """Handle shutdown signals by releasing card and stopping gracefully."""
        logger.warning(
            f"Received shutdown signal {signum}. Preparing for graceful shutdown."
        )
        # Flag checked by run() and _check_for_termination() to wind down.
        _shutdown_requested.set()

        # Release the current card if there is one
        if self.firestore.current_card:
            logger.info("Releasing card due to shutdown signal")
            self.firestore.release_card()
            self.firestore.stop_heart()

    def _check_for_termination(self):
        """Periodically check for AWS and GCP instance termination notices.

        - For AWS spot instances, uses IMDSv2 to check instance-action metadata
        - For GCP preemptible VMs, checks both maintenance-event and preempted metadata
        """
        while not _shutdown_requested.is_set():
            try:
                # Check AWS spot termination using IMDSv2 (token-based auth)
                token_response = requests.put(
                    AWS_IMDS_TOKEN_URL,
                    headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
                    timeout=2,
                )

                if token_response.status_code == 200:
                    # Use token to check instance-action
                    token = token_response.text
                    action_response = requests.get(
                        AWS_INSTANCE_ACTION_URL,
                        headers={"X-aws-ec2-metadata-token": token},
                        timeout=2,
                    )

                    # If 200 response and contains action, termination is imminent
                    if action_response.status_code == 200:
                        try:
                            action_data = json.loads(action_response.text)
                            action = action_data.get("action")
                            action_time = action_data.get("time")

                            logger.warning(
                                f"AWS Spot instance interruption notice received: "
                                f"action={action}, time={action_time}"
                            )

                            # Release the current card if there is one
                            if self.firestore.current_card:
                                logger.info(
                                    f"Releasing card due to spot termination notice (action: {action})"
                                )
                                self.firestore.release_card()
                                self.firestore.stop_heart()

                            _shutdown_requested.set()
                            break
                        except json.JSONDecodeError:
                            logger.warning(
                                f"Failed to parse instance-action response: {action_response.text}"
                            )
            except requests.RequestException:
                # This is expected on GCP or non-spot instances
                pass

            try:
                # Check GCP preemption warning (gives more lead time)
                gcp_warning_response = requests.get(
                    GCP_PREEMPTION_WARNING_URL,
                    headers={"Metadata-Flavor": "Google"},
                    timeout=2,
                )

                # If we get a response containing TERMINATE, preemption is coming soon
                if (
                    gcp_warning_response.status_code == 200
                    and "TERMINATE" in gcp_warning_response.text
                ):
                    logger.warning(
                        f"GCP preemptible VM maintenance event detected: {gcp_warning_response.text}"
                    )

                    # Release the current card if there is one
                    if self.firestore.current_card:
                        logger.info("Releasing card due to GCP preemption warning")
                        self.firestore.release_card()
                        self.firestore.stop_heart()

                    _shutdown_requested.set()
                    break

                # Check GCP actual preemption (fallback check)
                gcp_response = requests.get(
                    GCP_PREEMPTION_URL, headers={"Metadata-Flavor": "Google"}, timeout=2
                )

                # If we get a 200 response with "TRUE", termination is already happening
                if (
                    gcp_response.status_code == 200
                    and gcp_response.text.upper() == "TRUE"
                ):
                    logger.warning("GCP preemptible VM is being terminated now")

                    # Release the current card if there is one
                    if self.firestore.current_card:
                        logger.info("Releasing card due to GCP preemption in progress")
                        self.firestore.release_card()
                        self.firestore.stop_heart()

                    _shutdown_requested.set()
                    break
            except requests.RequestException:
                # This is expected on AWS or non-preemptible instances
                pass

            # Check every 5 seconds as recommended by AWS
            time.sleep(5)

    def run(self):
        """Main worker loop: draw a card, download, process, upload, repeat.

        Exits when no cards remain, when a shutdown/termination is requested,
        or on KeyboardInterrupt. Other exceptions mark the card failed and
        the loop moves on to the next card.
        """
        while not _shutdown_requested.is_set():
            try:
                packet_filename = self.firestore.draw_card()
                if packet_filename is None:
                    # No available cards left; this worker is done.
                    break
                self.firestore.start_heart()
                self.gcs.download_data(packet_filename)

                # Check if termination was requested while downloading
                if _shutdown_requested.is_set():
                    logger.warning("Shutdown requested during download, releasing card")
                    self.firestore.release_card()
                    self.firestore.stop_heart()
                    break

                output_path = self.processor.run(input_file=packet_filename)

                # Check if termination was requested during processing
                if _shutdown_requested.is_set():
                    logger.warning(
                        "Shutdown requested during processing, releasing card"
                    )
                    self.firestore.release_card()
                    self.firestore.stop_heart()
                    break

                self.gcs.upload_results(path_to_upload=output_path)
                self.firestore.close_card()
                self.firestore.stop_heart()

                # Clean up the output based on whether it's a file or directory
                if os.path.isfile(output_path):
                    os.remove(output_path)
                elif os.path.isdir(output_path):
                    import shutil

                    shutil.rmtree(output_path)
            except Exception as e:
                # Mark the card failed and keep looping to try the next card.
                logger.error(e)
                self.firestore.record_failure(str(e))
                self.firestore.stop_heart()
            except KeyboardInterrupt:
                # KeyboardInterrupt derives from BaseException, so the
                # `except Exception` clause above does not swallow it.
                logger.error("KeyboardInterrupt")
                self.firestore.record_failure("KeyboardInterrupt")
                self.firestore.stop_heart()
                break

        logger.info("Operator is done.")