gpu-dev 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpu_dev_cli/disks.py ADDED
@@ -0,0 +1,523 @@
1
+ """
2
+ Disk management for GPU Dev CLI
3
+ Handles named persistent disks with snapshot-first workflow
4
+ """
5
+
6
+ import boto3
7
+ import re
8
+ from decimal import Decimal
9
+ from typing import List, Dict, Optional, Tuple
10
+ from datetime import datetime, timedelta, timezone
11
+ from .config import Config
12
+
13
+
14
def get_ec2_client(config: Config):
    """Return an EC2 client created from ``config.session`` for ``config.aws_region``."""
    session = config.session
    return session.client('ec2', region_name=config.aws_region)
17
+
18
+
19
def get_s3_client(config: Config):
    """Return an S3 client created from ``config.session`` for ``config.aws_region``."""
    session = config.session
    return session.client('s3', region_name=config.aws_region)
22
+
23
+
24
def get_dynamodb_resource(config: Config):
    """Return a DynamoDB service resource created from ``config.session`` for ``config.aws_region``."""
    session = config.session
    return session.resource('dynamodb', region_name=config.aws_region)
27
+
28
+
29
def _first_matching_reservation(reservations_table, filter_expression: str, expression_values: Dict) -> Optional[Dict]:
    """Return the first reservation on the ``UserIndex`` that passes the filter, or None.

    DynamoDB applies ``FilterExpression`` after reading a page, so a page can
    come back empty while later pages still hold matches. This keeps following
    ``LastEvaluatedKey`` until an item is found or the index is exhausted.
    """
    query_kwargs = {
        "IndexName": "UserIndex",
        "KeyConditionExpression": "user_id = :user_id",
        "FilterExpression": filter_expression,
        "ExpressionAttributeNames": {"#status": "status"},
        "ExpressionAttributeValues": expression_values,
    }

    while True:
        response = reservations_table.query(**query_kwargs)
        items = response.get("Items", [])
        if items:
            # Only the first hit is needed, so stop paginating as soon as we have one
            return items[0]
        if "LastEvaluatedKey" not in response:
            return None
        query_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]


def get_disk_in_use_status(disk_name: str, user_id: str, config: Config) -> Tuple[bool, Optional[str]]:
    """
    Check if a disk is currently in use by any reservation.
    Returns (is_in_use, reservation_id)

    We check TWO sources to handle all race conditions:
    1. Disks table `in_use` field - set by Lambda when disk is attached, cleared after cleanup
    2. Reservations table - for in-progress reservations that haven't started disk setup yet

    This prevents race conditions during both spinning up (queued/pending) and
    winding down (cancelled but cleanup still running).

    If the disks table lookup fails, a warning is printed and the reservations
    table is still consulted. If the reservations lookup fails, a warning is
    printed and (False, None) is returned, so callers cannot tell "not in use"
    apart from "status unknown".
    """
    dynamodb = get_dynamodb_resource(config)

    try:
        # First check: disks table in_use field (most reliable for cleanup in progress)
        disks_table_name = config.disks_table if hasattr(config, 'disks_table') else f"{config.queue_name.rsplit('-', 1)[0]}-disks"
        disks_table = dynamodb.Table(disks_table_name)

        try:
            disk_response = disks_table.get_item(
                Key={'user_id': user_id, 'disk_name': disk_name}
            )
            disk_item = disk_response.get('Item', {})

            # Check if disk is marked as in_use in the disks table
            if disk_item.get('in_use', False):
                return True, disk_item.get('attached_to_reservation')
        except Exception as disk_check_error:
            # Best effort: report the failure, then fall through to the reservation check
            print(f"Warning: Could not check disks table for disk '{disk_name}': {disk_check_error}")

        # Second check: reservations table for in-progress reservations
        reservations_table = dynamodb.Table(config.reservations_table)

        # Check ALL in-progress statuses to prevent race conditions
        reservation = _first_matching_reservation(
            reservations_table,
            "disk_name = :disk_name AND #status IN (:active, :preparing, :queued, :pending)",
            {
                ":user_id": user_id,
                ":disk_name": disk_name,
                ":active": "active",
                ":preparing": "preparing",
                ":queued": "queued",
                ":pending": "pending",
            },
        )
        if reservation:
            return True, reservation["reservation_id"]

        # Special case: For "default" disk, also check for legacy reservations without disk_name field
        # (reservations created before named disk migration)
        # IMPORTANT: Only match legacy reservations that HAVE an ebs_volume_id
        # (reservations without disk_name AND without ebs_volume_id are non-persistent, not "default" disk)
        if disk_name == "default":
            legacy_reservation = _first_matching_reservation(
                reservations_table,
                "attribute_not_exists(disk_name) AND attribute_exists(ebs_volume_id) AND #status IN (:active, :preparing)",
                {
                    ":user_id": user_id,
                    ":active": "active",
                    ":preparing": "preparing",
                },
            )
            if legacy_reservation:
                return True, legacy_reservation["reservation_id"]

        return False, None

    except Exception as e:
        print(f"Warning: Could not query reservations: {e}")
        return False, None
149
+
150
+
151
def _parse_disk_timestamp(value: Optional[str]) -> Optional[datetime]:
    """Parse an ISO-8601 string into a timezone-aware datetime, or None when empty.

    Timezone-naive values (older records) are treated as UTC.
    Raises ValueError for strings that are not valid ISO-8601.
    """
    if not value:
        return None

    # datetime.fromisoformat() only accepts a trailing 'Z' from Python 3.11 on
    if isinstance(value, str) and value.endswith('Z'):
        value = value[:-1] + '+00:00'

    parsed = datetime.fromisoformat(value)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def list_disks(user_id: str, config: Config) -> List[Dict]:
    """
    List all disks for a user.

    Reads every item for `user_id` from the DynamoDB disks table and checks the
    current in-use status of each disk via get_disk_in_use_status().

    Returns a list of dicts, most recently used first, with keys: name, size_gb,
    disk_size, created_at, last_used, snapshot_count, pending_snapshot_count,
    in_use, is_backing_up, reservation_id, is_deleted, delete_date.
    Disks that have no last_used value sort last.
    """
    dynamodb = get_dynamodb_resource(config)

    # Query DynamoDB disks table for this user's disks (with pagination)
    disks_table_name = config.disks_table if hasattr(config, 'disks_table') else f"{config.queue_name.rsplit('-', 1)[0]}-disks"
    disks_table = dynamodb.Table(disks_table_name)

    query_kwargs = {
        "KeyConditionExpression": "user_id = :user_id",
        "ExpressionAttributeValues": {":user_id": user_id},
    }
    dynamodb_disks = []
    while True:
        response = disks_table.query(**query_kwargs)
        dynamodb_disks.extend(response.get('Items', []))
        if 'LastEvaluatedKey' not in response:
            break
        query_kwargs['ExclusiveStartKey'] = response['LastEvaluatedKey']

    # Process DynamoDB data
    disks = []
    for disk_item in dynamodb_disks:
        disk_name = disk_item['disk_name']

        # Check current in_use status (check dynamically from reservations table)
        is_in_use, reservation_id = get_disk_in_use_status(disk_name, user_id, config)

        disks.append({
            'name': disk_name,
            # Numeric attributes are converted to int; missing or zero values become 0
            'size_gb': int(disk_item.get('size_gb') or 0),
            'disk_size': disk_item.get('disk_size'),
            'created_at': _parse_disk_timestamp(disk_item.get('created_at')),
            'last_used': _parse_disk_timestamp(disk_item.get('last_used')),
            'snapshot_count': int(disk_item.get('snapshot_count') or 0),
            'pending_snapshot_count': int(disk_item.get('pending_snapshot_count') or 0),
            'in_use': is_in_use,
            'is_backing_up': disk_item.get('is_backing_up', False),
            'reservation_id': reservation_id,
            'is_deleted': disk_item.get('is_deleted', False),
            'delete_date': disk_item.get('delete_date'),
        })

    # Sort by last_used (most recent first); never-used disks fall back to the minimum datetime
    never_used = datetime.min.replace(tzinfo=timezone.utc)
    disks.sort(key=lambda d: d['last_used'] or never_used, reverse=True)

    return disks
232
+
233
+
234
def create_disk(disk_name: str, user_id: str, config: Config) -> Optional[str]:
    """
    Create a new disk by sending request to SQS queue.
    Lambda will create the disk entry in DynamoDB.

    The name is validated before any AWS call is made, and the request is
    rejected if the user already has a disk with that name.

    Returns operation_id on success, None on failure (an error is printed).
    """
    import json
    import uuid

    # Validate disk name (alphanumeric + hyphens + underscores).
    # fullmatch() is used because '$' in re.match() also matches just before a
    # trailing newline, which would let a name like "data\n" through.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', disk_name):
        print("Error: Disk name must contain only letters, numbers, hyphens, and underscores")
        return None

    # Check if disk already exists
    existing_disks = list_disks(user_id, config)
    if any(d['name'] == disk_name for d in existing_disks):
        print(f"Error: Disk '{disk_name}' already exists")
        return None

    # Generate operation ID for tracking
    operation_id = str(uuid.uuid4())

    # Send create request to SQS queue
    try:
        sqs_client = config.session.client('sqs', region_name=config.aws_region)
        queue_url = config.get_queue_url()

        # Create disk creation message
        message = {
            'action': 'create_disk',
            'operation_id': operation_id,
            'user_id': user_id,
            'disk_name': disk_name,
            'requested_at': datetime.now(timezone.utc).isoformat()
        }

        sqs_client.send_message(
            QueueUrl=queue_url,
            MessageBody=json.dumps(message)
        )

        return operation_id

    except Exception as e:
        print(f"Error sending create request: {e}")
        return None
281
+
282
+
283
def list_disk_content(disk_name: str, user_id: str, config: Config) -> Optional[str]:
    """
    Fetch and return the contents of the latest snapshot for a disk.

    The disk's DynamoDB item supplies an `s3://bucket/key` path in its
    `latest_snapshot_content_s3` attribute; the object at that path is
    downloaded and decoded as UTF-8.

    Returns contents string, or None if the disk, the path or the S3 object is
    missing or cannot be read (a message is printed in each case).
    """
    dynamodb = get_dynamodb_resource(config)

    # Get disk info from DynamoDB to get latest snapshot S3 path
    disks_table_name = config.disks_table if hasattr(config, 'disks_table') else f"{config.queue_name.rsplit('-', 1)[0]}-disks"
    disks_table = dynamodb.Table(disks_table_name)

    try:
        response = disks_table.get_item(
            Key={'user_id': user_id, 'disk_name': disk_name}
        )

        if 'Item' not in response:
            print(f"Disk '{disk_name}' not found")
            return None

        s3_path = response['Item'].get('latest_snapshot_content_s3')

        if not s3_path:
            print(f"No snapshot contents available for disk '{disk_name}'")
            print("This may be a newly created disk or a disk created before content tracking was added.")
            return None

    except Exception as e:
        print(f"Error fetching disk info from DynamoDB: {e}")
        return None

    # Parse S3 path (s3://bucket/key)
    if not s3_path.startswith('s3://'):
        print(f"Invalid S3 path format: {s3_path}")
        return None

    # Both bucket and key must be non-empty; "s3://bucket/" and "s3:///key" are rejected here
    # instead of surfacing later as an opaque S3 error
    bucket_name, separator, s3_key = s3_path[5:].partition('/')
    if not (separator and bucket_name and s3_key):
        print(f"Invalid S3 path format: {s3_path}")
        return None

    # The client is only built once we know there is something to download
    s3_client = get_s3_client(config)

    try:
        # Fetch contents from S3
        response = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
        body = response['Body']
        try:
            return body.read().decode('utf-8')
        finally:
            # Close the response stream even if reading or decoding fails
            body.close()
    except s3_client.exceptions.NoSuchKey:
        print(f"Contents file not found in S3: {s3_path}")
        return None
    except Exception as e:
        print(f"Error fetching contents from S3: {e}")
        return None
339
+
340
+
341
def delete_disk(disk_name: str, user_id: str, config: Config) -> Optional[str]:
    """
    Request a soft delete of a disk by posting a `delete_disk` message to the SQS queue.

    The request is refused when the disk does not exist or is currently in use.
    The message carries a `delete_date` 30 days from now (UTC, YYYY-MM-DD).

    Returns the generated operation_id on success, None on failure.
    """
    import json
    import uuid

    # Locate the disk by name among the user's disks
    target = None
    for candidate in list_disks(user_id, config):
        if candidate['name'] == disk_name:
            target = candidate
            break

    if not target:
        print(f"Error: Disk '{disk_name}' not found")
        return None

    # A disk attached to a reservation must not be deleted
    if target['in_use']:
        print(f"Error: Cannot delete disk '{disk_name}' - it is currently in use")
        print(f"Reservation ID: {target['reservation_id']}")
        return None

    # Deletion date is 30 days out
    purge_day = (datetime.now(timezone.utc) + timedelta(days=30)).strftime('%Y-%m-%d')

    # Tracking ID handed back to the caller
    operation_id = str(uuid.uuid4())

    try:
        sqs = config.session.client('sqs', region_name=config.aws_region)
        target_queue = config.get_queue_url()

        payload = json.dumps({
            'action': 'delete_disk',
            'operation_id': operation_id,
            'user_id': user_id,
            'disk_name': disk_name,
            'delete_date': purge_day,
            'requested_at': datetime.now(timezone.utc).isoformat()
        })

        sqs.send_message(QueueUrl=target_queue, MessageBody=payload)
    except Exception as e:
        print(f"Error sending delete request: {e}")
        return None

    return operation_id
396
+
397
+
398
def poll_disk_operation(
    operation_type: str,
    disk_name: str,
    user_id: str,
    config: Config,
    timeout_seconds: int = 60
) -> Tuple[bool, str]:
    """
    Poll DynamoDB (via list_disks) for disk operation completion.

    Args:
        operation_type: 'create' or 'delete'
        disk_name: Name of the disk
        user_id: User ID
        config: Config object
        timeout_seconds: Max time to wait

    Returns:
        Tuple of (success, message). Any other operation_type returns
        (False, message) immediately instead of polling until the timeout.
        When polling times out and the last attempt raised, that error is
        appended to the timeout message.
    """
    import time

    if operation_type not in ('create', 'delete'):
        return False, f"Unknown disk operation type '{operation_type}' (expected 'create' or 'delete')"

    poll_interval = 2  # seconds
    # monotonic() so wall-clock adjustments cannot stretch or shrink the wait
    deadline = time.monotonic() + timeout_seconds
    last_error = None

    while time.monotonic() < deadline:
        try:
            disks = list_disks(user_id, config)
            disk = next((d for d in disks if d['name'] == disk_name), None)
            last_error = None

            if operation_type == 'create':
                # For create, we're waiting for the disk to appear
                if disk is not None:
                    return True, f"Disk '{disk_name}' created successfully"

            elif disk is None:
                # Disk no longer in list (shouldn't happen with soft delete)
                return True, f"Disk '{disk_name}' deleted successfully"

            elif disk.get('is_deleted', False):
                # list_disks always sets 'delete_date' (possibly to None), so a
                # dict.get() default would never apply; test the value instead
                delete_date = disk.get('delete_date')
                when = f"on {delete_date}" if delete_date else "in 30 days"
                return True, f"Disk '{disk_name}' marked for deletion. Snapshots will be permanently deleted {when}"

        except Exception as e:
            # Continue polling on errors, but remember the error for the timeout message
            last_error = e

        # Never sleep past the deadline
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))

    # Timeout
    error_suffix = f" Last error: {last_error}" if last_error is not None else ""
    if operation_type == 'create':
        return False, f"Timed out waiting for disk '{disk_name}' to be created. It may still be processing.{error_suffix}"
    return False, f"Timed out waiting for disk '{disk_name}' deletion to complete. It may still be processing.{error_suffix}"
453
+
454
+
455
def rename_disk(old_name: str, new_name: str, user_id: str, config: Config) -> bool:
    """
    Rename a disk by updating disk_name tags on all its snapshots.

    The rename is refused when the new name is invalid or already taken, when
    the old disk does not exist, or when it is currently in use.

    Returns True only if every snapshot was retagged. Returns False on a
    refused request, when no snapshots are found, or when any tag update fails.

    NOTE(review): only EC2 snapshot tags are changed here. The DynamoDB disks
    table that list_disks() reads is not updated by this function - confirm
    that something else keeps it in sync.
    """
    # Validate new disk name. fullmatch() is used because '$' in re.match()
    # also matches just before a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', new_name):
        print("Error: Disk name must contain only letters, numbers, hyphens, and underscores")
        return False

    # Check if old disk exists
    disks = list_disks(user_id, config)
    old_disk = next((d for d in disks if d['name'] == old_name), None)

    if not old_disk:
        print(f"Error: Disk '{old_name}' not found")
        return False

    # Check if new name already exists
    if any(d['name'] == new_name for d in disks):
        print(f"Error: Disk '{new_name}' already exists")
        return False

    # Check if disk is in use
    if old_disk['in_use']:
        print(f"Error: Cannot rename disk '{old_name}' - it is currently in use")
        print(f"Reservation ID: {old_disk['reservation_id']}")
        return False

    ec2_client = get_ec2_client(config)

    print(f"Renaming disk '{old_name}' to '{new_name}'...")

    try:
        # Find all snapshots for this disk; paginate so none are missed
        paginator = ec2_client.get_paginator('describe_snapshots')
        snapshots = []
        for page in paginator.paginate(
            OwnerIds=["self"],
            Filters=[
                {"Name": "tag:gpu-dev-user", "Values": [user_id]},
                {"Name": "tag:disk_name", "Values": [old_name]},
            ]
        ):
            snapshots.extend(page.get('Snapshots', []))

        if not snapshots:
            print(f"Warning: No snapshots found for disk '{old_name}'")
            return False

        # Update disk_name tag on each snapshot, one call each so failures are reported per snapshot
        renamed_count = 0
        for snapshot in snapshots:
            snapshot_id = snapshot['SnapshotId']
            try:
                ec2_client.create_tags(
                    Resources=[snapshot_id],
                    Tags=[{"Key": "disk_name", "Value": new_name}]
                )
                print(f"  ✓ Updated snapshot {snapshot_id}")
                renamed_count += 1
            except Exception as e:
                print(f"  ✗ Error updating snapshot {snapshot_id}: {e}")

        # A partial rename leaves snapshots split across two names; do not report it as success
        if renamed_count != len(snapshots):
            print(f"Error: Only {renamed_count} of {len(snapshots)} snapshots were updated; disk '{old_name}' is only partially renamed")
            return False

        print(f"✓ Successfully renamed disk to '{new_name}' ({renamed_count} snapshots updated)")
        return True

    except Exception as e:
        print(f"Error renaming disk: {e}")
        return False