aws-bootstrap-g4dn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,839 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ GPU Throughput Benchmark for AWS EC2 Spot Instances (T4 GPU)
4
+
5
+ Measures PyTorch GPU training throughput with two benchmark modes:
6
+ 1. CNN on MNIST - lightweight, fast iteration
7
+ 2. Transformer on synthetic data - more compute-intensive
8
+
9
+ Reports: samples/sec, batches/sec, average batch time, and peak GPU memory.
10
+
11
+ Supports multiple precision modes with automatic fallback:
12
+ - FP16 AMP (default)
+ - BF16 AMP (Ampere+ only)
13
+ - FP32 (fallback if AMP fails)
14
+ - TF32 (Ampere+ only)
15
+ """
16
+
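+ # Example invocations (illustrative; the script filename is a placeholder,
+ # but the flags are the ones defined in main() below):
+ #
+ #   python gpu_benchmark.py --mode cnn --precision fp16
+ #   python gpu_benchmark.py --mode transformer --precision fp32 --benchmark-batches 50
+ #   python gpu_benchmark.py --mode both --diagnose
+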
17
+ from __future__ import annotations
18
+ import argparse
19
+ import os
20
+ import sys
21
+ import time
22
+ from contextlib import contextmanager
23
+ from dataclasses import dataclass
24
+ from enum import Enum
25
+ from typing import TYPE_CHECKING
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch.utils.data import DataLoader, TensorDataset
31
+ from torchvision import datasets, transforms
32
+ from tqdm import tqdm
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from collections.abc import Generator
37
+
38
+
39
+ # -----------------------------------------------------------------------------
40
+ # Diagnostic Functions
41
+ # -----------------------------------------------------------------------------
42
+
43
+
44
+ def run_cuda_diagnostics(device: torch.device) -> dict[str, bool]:
45
+ """
46
+ Run diagnostic tests to verify CUDA/cuBLAS functionality.
47
+ Returns dict of test_name -> passed.
48
+ """
49
+ results: dict[str, bool] = {}
50
+
51
+ if device.type != "cuda":
52
+ print(" Skipping CUDA diagnostics (CPU mode)")
53
+ return results
54
+
55
+ print("\n" + "-" * 40)
56
+ print("Running CUDA Diagnostics")
57
+ print("-" * 40)
58
+
59
+ # Test 1: Basic FP32 matmul
60
+ try:
61
+ a = torch.randn(256, 256, device=device)
62
+ b = torch.randn(256, 256, device=device)
63
+ _c = torch.mm(a, b)
64
+ torch.cuda.synchronize()
65
+ results["fp32_matmul"] = True
66
+ print(" ✓ FP32 matmul: PASSED")
67
+ except Exception as e:
68
+ results["fp32_matmul"] = False
69
+ print(f" ✗ FP32 matmul: FAILED - {e}")
70
+
71
+ # Test 2: FP16 matmul (no autocast)
72
+ try:
73
+ a = torch.randn(256, 256, device=device, dtype=torch.float16)
74
+ b = torch.randn(256, 256, device=device, dtype=torch.float16)
75
+ _c = torch.mm(a, b)
76
+ torch.cuda.synchronize()
77
+ results["fp16_matmul"] = True
78
+ print(" ✓ FP16 matmul: PASSED")
79
+ except Exception as e:
80
+ results["fp16_matmul"] = False
81
+ print(f" ✗ FP16 matmul: FAILED - {e}")
82
+
83
+ # Test 3: FP16 matmul with autocast
84
+ try:
85
+ a = torch.randn(256, 256, device=device)
86
+ b = torch.randn(256, 256, device=device)
87
+ with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
88
+ _c = torch.mm(a, b)
89
+ torch.cuda.synchronize()
90
+ results["fp16_autocast"] = True
91
+ print(" ✓ FP16 autocast matmul: PASSED")
92
+ except Exception as e:
93
+ results["fp16_autocast"] = False
94
+ print(f" ✗ FP16 autocast matmul: FAILED - {e}")
95
+
96
+ # Test 4: Linear layer with autocast (common GEMM pattern)
97
+ try:
98
+ linear = nn.Linear(512, 512).to(device)
99
+ x = torch.randn(64, 512, device=device)
100
+ with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
101
+ _y = linear(x)
102
+ torch.cuda.synchronize()
103
+ results["fp16_linear"] = True
104
+ print(" ✓ FP16 linear layer: PASSED")
105
+ except Exception as e:
106
+ results["fp16_linear"] = False
107
+ print(f" ✗ FP16 linear layer: FAILED - {e}")
108
+
109
+ # Test 5: Conv2d with autocast
110
+ try:
111
+ conv = nn.Conv2d(64, 128, 3, padding=1).to(device)
112
+ x = torch.randn(16, 64, 32, 32, device=device)
113
+ with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
114
+ _y = conv(x)
115
+ torch.cuda.synchronize()
116
+ results["fp16_conv2d"] = True
117
+ print(" ✓ FP16 conv2d: PASSED")
118
+ except Exception as e:
119
+ results["fp16_conv2d"] = False
120
+ print(f" ✗ FP16 conv2d: FAILED - {e}")
121
+
122
+ # Test 6: Batched matmul (transformer attention pattern)
123
+ try:
124
+ # Simulates attention: (batch, heads, seq, dim) @ (batch, heads, dim, seq)
125
+ a = torch.randn(8, 8, 128, 64, device=device, dtype=torch.float16)
126
+ b = torch.randn(8, 8, 64, 128, device=device, dtype=torch.float16)
127
+ _c = torch.matmul(a, b)
128
+ torch.cuda.synchronize()
129
+ results["fp16_batched_matmul"] = True
130
+ print(" ✓ FP16 batched matmul: PASSED")
131
+ except Exception as e:
132
+ results["fp16_batched_matmul"] = False
133
+ print(f" ✗ FP16 batched matmul: FAILED - {e}")
134
+
135
+ print("-" * 40)
136
+
137
+ # Summary
138
+ passed = sum(results.values())
139
+ total = len(results)
140
+ print(f"Diagnostics: {passed}/{total} tests passed")
141
+
142
+ if passed < total:
143
+ failed_tests = [k for k, v in results.items() if not v]
144
+ print(f"Failed tests: {', '.join(failed_tests)}")
145
+ print("\nRecommendation: Use --precision fp32 to bypass FP16 issues")
146
+
147
+ print("-" * 40 + "\n")
148
+ return results
149
+
150
+
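+ # Standalone usage sketch (illustrative): run the diagnostics and fall back
+ # to FP32 if any FP16 path fails.
+ #
+ #   results = run_cuda_diagnostics(torch.device("cuda"))
+ #   if not all(results.values()):
+ #       print("Some CUDA/cuBLAS tests failed; consider --precision fp32")
+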
151
+ class PrecisionMode(Enum):
152
+ """Supported precision modes for training."""
153
+
154
+ FP32 = "fp32"
155
+ FP16 = "fp16"
156
+ BF16 = "bf16"
157
+ TF32 = "tf32"
158
+
159
+
160
+ @dataclass(frozen=True)
161
+ class BenchmarkConfig:
162
+ """Configuration for benchmark runs."""
163
+
164
+ batch_size: int = 256
165
+ num_warmup_batches: int = 10
166
+ num_benchmark_batches: int = 100
167
+ num_workers: int = 4
168
+ pin_memory: bool = True
169
+ precision: PrecisionMode = PrecisionMode.FP16
170
+
171
+
172
+ @dataclass
173
+ class BenchmarkResult:
174
+ """Results from a benchmark run."""
175
+
176
+ model_name: str
177
+ total_samples: int
178
+ total_time_sec: float
179
+ peak_memory_mb: float
180
+ avg_batch_time_ms: float
181
+ precision_mode: str
182
+
183
+ @property
184
+ def samples_per_sec(self) -> float:
185
+ return self.total_samples / self.total_time_sec
186
+
187
+ @property
188
+ def batches_per_sec(self) -> float:
189
+ return 1000.0 / self.avg_batch_time_ms
190
+
191
+ def __str__(self) -> str:
192
+ return (
193
+ f"\n{'=' * 60}\n"
194
+ f"Benchmark Results: {self.model_name}\n"
195
+ f"{'=' * 60}\n"
196
+ f" Precision mode: {self.precision_mode}\n"
197
+ f" Total samples processed: {self.total_samples:,}\n"
198
+ f" Total time: {self.total_time_sec:.2f}s\n"
199
+ f" Throughput: {self.samples_per_sec:,.1f} samples/sec\n"
200
+ f" Throughput: {self.batches_per_sec:.1f} batches/sec\n"
201
+ f" Avg batch time: {self.avg_batch_time_ms:.2f}ms\n"
202
+ f" Peak GPU memory: {self.peak_memory_mb:.1f}MB\n"
203
+ f"{'=' * 60}\n"
204
+ )
205
+
206
+
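+ # Minimal sketch of how the derived metrics follow from the fields above
+ # (numbers are made up for illustration):
+ #
+ #   r = BenchmarkResult("demo", total_samples=25_600, total_time_sec=10.0,
+ #                       peak_memory_mb=1024.0, avg_batch_time_ms=100.0,
+ #                       precision_mode="fp16")
+ #   r.samples_per_sec   # 25_600 / 10.0 -> 2560.0
+ #   r.batches_per_sec   # 1000 / 100.0  -> 10.0
+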
207
+ # -----------------------------------------------------------------------------
208
+ # CNN Model for MNIST
209
+ # -----------------------------------------------------------------------------
210
+
211
+
212
+ class MNISTConvNet(nn.Module):
213
+ """
214
+ Simple but non-trivial CNN for MNIST.
215
+ ~1.2M parameters - enough to stress GPU without being excessive.
216
+ """
217
+
218
+ def __init__(self) -> None:
219
+ super().__init__()
220
+ self.features = nn.Sequential(
221
+ nn.Conv2d(1, 64, kernel_size=3, padding=1),
222
+ nn.BatchNorm2d(64),
223
+ nn.ReLU(inplace=True),
224
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
225
+ nn.BatchNorm2d(64),
226
+ nn.ReLU(inplace=True),
227
+ nn.MaxPool2d(2),
228
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
229
+ nn.BatchNorm2d(128),
230
+ nn.ReLU(inplace=True),
231
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
232
+ nn.BatchNorm2d(128),
233
+ nn.ReLU(inplace=True),
234
+ nn.MaxPool2d(2),
235
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
236
+ nn.BatchNorm2d(256),
237
+ nn.ReLU(inplace=True),
238
+ nn.AdaptiveAvgPool2d(1),
239
+ )
240
+ self.classifier = nn.Sequential(
241
+ nn.Flatten(),
242
+ nn.Linear(256, 256),
243
+ nn.ReLU(inplace=True),
244
+ nn.Dropout(0.5),
245
+ nn.Linear(256, 10),
246
+ )
247
+
248
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
249
+ x = self.features(x)
250
+ return self.classifier(x)
251
+
252
+
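+ # Quick check of the parameter count claimed in the docstring (illustrative,
+ # to be run interactively rather than at import time):
+ #
+ #   sum(p.numel() for p in MNISTConvNet().parameters())   # ~0.62M
+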
253
+ # -----------------------------------------------------------------------------
254
+ # Transformer Model (GPT-style decoder)
255
+ # -----------------------------------------------------------------------------
256
+
257
+
258
+ class TransformerBlock(nn.Module):
259
+ """Single transformer decoder block with pre-norm."""
260
+
261
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1) -> None:
262
+ super().__init__()
263
+ self.ln1 = nn.LayerNorm(d_model)
264
+ self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
265
+ self.ln2 = nn.LayerNorm(d_model)
266
+ self.ff = nn.Sequential(
267
+ nn.Linear(d_model, d_ff),
268
+ nn.GELU(),
269
+ nn.Dropout(dropout),
270
+ nn.Linear(d_ff, d_model),
271
+ nn.Dropout(dropout),
272
+ )
273
+
274
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
275
+ # Pre-norm architecture
276
+ normed = self.ln1(x)
277
+ attn_out, _ = self.attn(normed, normed, normed, attn_mask=attn_mask, need_weights=False)
278
+ x = x + attn_out
279
+ x = x + self.ff(self.ln2(x))
280
+ return x
281
+
282
+
283
+ class MiniGPT(nn.Module):
284
+ """
285
+ Small GPT-style transformer for benchmarking.
286
+ ~35M parameters (output head tied to token embeddings) - representative of real workloads.
287
+ """
288
+
289
+ def __init__(
290
+ self,
291
+ vocab_size: int = 32000,
292
+ d_model: int = 512,
293
+ n_heads: int = 8,
294
+ n_layers: int = 6,
295
+ d_ff: int = 2048,
296
+ max_seq_len: int = 256,
297
+ dropout: float = 0.1,
298
+ ) -> None:
299
+ super().__init__()
300
+ self.d_model = d_model
301
+ self.max_seq_len = max_seq_len
302
+
303
+ self.token_emb = nn.Embedding(vocab_size, d_model)
304
+ self.pos_emb = nn.Embedding(max_seq_len, d_model)
305
+ self.dropout = nn.Dropout(dropout)
306
+
307
+ self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
308
+
309
+ self.ln_f = nn.LayerNorm(d_model)
310
+ self.head = nn.Linear(d_model, vocab_size, bias=False)
311
+
312
+ # Weight tying
313
+ self.head.weight = self.token_emb.weight
314
+
315
+ self._init_weights()
316
+
317
+ def _init_weights(self) -> None:
318
+ for module in self.modules():
319
+ if isinstance(module, nn.Linear):
320
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
321
+ if module.bias is not None:
322
+ nn.init.zeros_(module.bias)
323
+ elif isinstance(module, nn.Embedding):
324
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
325
+
326
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
327
+ batch_size, seq_len = x.shape
328
+ assert seq_len <= self.max_seq_len, f"Sequence length {seq_len} > max {self.max_seq_len}"
329
+
330
+ positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
331
+
332
+ x = self.dropout(self.token_emb(x) + self.pos_emb(positions))
333
+
334
+ # Causal mask for autoregressive modeling
335
+ causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1)
336
+
337
+ for block in self.blocks:
338
+ x = block(x, attn_mask=causal_mask)
339
+
340
+ x = self.ln_f(x)
341
+ return self.head(x)
342
+
343
+
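+ # Shape sketch for MiniGPT with the default hyperparameters (illustrative):
+ #
+ #   tokens = torch.randint(0, 32000, (4, 256))   # (batch, seq_len)
+ #   logits = MiniGPT()(tokens)                   # (4, 256, 32000)
+ #
+ # Because the output head is weight-tied to token_emb, the ~16M-entry
+ # embedding matrix is counted only once by parameters().
+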
344
+ # -----------------------------------------------------------------------------
345
+ # Data Loading
346
+ # -----------------------------------------------------------------------------
347
+
348
+
349
+ def get_mnist_loader(config: BenchmarkConfig, device: torch.device) -> DataLoader:
350
+ """Load MNIST dataset with standard preprocessing."""
351
+ transform = transforms.Compose(
352
+ [
353
+ transforms.ToTensor(),
354
+ transforms.Normalize((0.1307,), (0.3081,)),
355
+ ]
356
+ )
357
+
358
+ dataset = datasets.MNIST(root="/tmp/data", train=True, download=True, transform=transform)
359
+
360
+ return DataLoader(
361
+ dataset,
362
+ batch_size=config.batch_size,
363
+ shuffle=True,
364
+ num_workers=config.num_workers,
365
+ pin_memory=config.pin_memory and device.type == "cuda",
366
+ persistent_workers=config.num_workers > 0,
367
+ )
368
+
369
+
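+ # Illustrative first batch from the MNIST loader (with the default config):
+ #
+ #   images, labels = next(iter(get_mnist_loader(BenchmarkConfig(), device)))
+ #   # images: (256, 1, 28, 28) float32, labels: (256,) int64
+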
370
+ def get_synthetic_text_loader(
371
+ config: BenchmarkConfig,
372
+ vocab_size: int = 32000,
373
+ seq_len: int = 256,
374
+ num_samples: int = 50000,
375
+ ) -> DataLoader:
376
+ """Generate synthetic token sequences for transformer benchmarking."""
377
+ # Uniform random token IDs (sufficient for a pure throughput benchmark)
378
+ data = torch.randint(0, vocab_size, (num_samples, seq_len))
379
+ # Labels are independent random tokens (throughput-only; not a real next-token objective)
380
+ labels = torch.randint(0, vocab_size, (num_samples, seq_len))
381
+
382
+ dataset = TensorDataset(data, labels)
383
+
384
+ return DataLoader(
385
+ dataset,
386
+ batch_size=config.batch_size,
387
+ shuffle=True,
388
+ num_workers=0, # Synthetic data is fast enough
389
+ pin_memory=True,
390
+ )
391
+
392
+
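+ # Illustrative use of the synthetic loader (defaults as above):
+ #
+ #   loader = get_synthetic_text_loader(BenchmarkConfig(batch_size=32))
+ #   tokens, labels = next(iter(loader))   # both (32, 256), dtype int64
+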
393
+ # -----------------------------------------------------------------------------
394
+ # Benchmark Runner
395
+ # -----------------------------------------------------------------------------
396
+
397
+
398
+ @contextmanager
399
+ def cuda_timer(device: torch.device) -> Generator[dict[str, float]]:
400
+ """Context manager for accurate CUDA timing using events."""
401
+ result: dict[str, float] = {}
402
+
403
+ if device.type == "cuda":
404
+ torch.cuda.synchronize(device)
405
+ start_event = torch.cuda.Event(enable_timing=True)
406
+ end_event = torch.cuda.Event(enable_timing=True)
407
+ start_event.record()
408
+ yield result
409
+ end_event.record()
410
+ torch.cuda.synchronize(device)
411
+ result["elapsed_ms"] = start_event.elapsed_time(end_event)
412
+ else:
413
+ start = time.perf_counter()
414
+ yield result
415
+ result["elapsed_ms"] = (time.perf_counter() - start) * 1000
416
+
417
+
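+ # Usage sketch for cuda_timer (illustrative): the yielded dict is populated
+ # only after the with-block exits.
+ #
+ #   with cuda_timer(device) as t:
+ #       _ = model(inputs)                 # work being timed
+ #   print(f"{t['elapsed_ms']:.2f} ms")
+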
418
+ def run_benchmark(
419
+ model: nn.Module,
420
+ loader: DataLoader,
421
+ config: BenchmarkConfig,
422
+ device: torch.device,
423
+ model_name: str,
424
+ precision: PrecisionMode,
425
+ is_lm: bool = False,
426
+ ) -> BenchmarkResult:
427
+ """
428
+ Run training benchmark with warmup phase.
429
+
430
+ Args:
431
+ model: PyTorch model to benchmark
432
+ loader: DataLoader providing batches
433
+ config: Benchmark configuration
434
+ device: Target device
435
+ model_name: Name for reporting
436
+ precision: Precision mode to use
437
+ is_lm: If True, flatten logits and targets to token level for the cross-entropy loss
438
+ """
439
+ model = model.to(device)
440
+ model.train()
441
+
442
+ # Configure precision-specific settings
443
+ use_amp = precision in (PrecisionMode.FP16, PrecisionMode.BF16)
444
+ amp_dtype = torch.float16 if precision == PrecisionMode.FP16 else torch.bfloat16
445
+
446
+ # GradScaler is only needed for FP16 (BF16 has sufficient dynamic range)
447
+ use_scaler = precision == PrecisionMode.FP16 and device.type == "cuda"
448
+ scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)
449
+
450
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
451
+
452
+ if device.type == "cuda":
453
+ torch.cuda.reset_peak_memory_stats(device)
454
+
455
+ data_iter = iter(loader)
456
+ batch_times: list[float] = []
457
+
458
+ total_batches = config.num_warmup_batches + config.num_benchmark_batches
459
+
460
+ print(f"\nRunning {model_name} benchmark...")
461
+ print(f" Precision: {precision.value}")
462
+ print(f" AMP enabled: {use_amp}")
463
+ print(f" GradScaler enabled: {use_scaler}")
464
+ print(f" Warmup batches: {config.num_warmup_batches}")
465
+ print(f" Benchmark batches: {config.num_benchmark_batches}")
466
+ print(f" Batch size: {config.batch_size}")
467
+
468
+ pbar = tqdm(range(total_batches), desc=model_name, unit="batch")
469
+ for batch_idx in pbar:
470
+ # Get next batch, cycling if needed
471
+ try:
472
+ batch = next(data_iter)
473
+ except StopIteration:
474
+ data_iter = iter(loader)
475
+ batch = next(data_iter)
476
+
477
+ inputs = batch[0].to(device, non_blocking=True)
478
+ targets = batch[1].to(device, non_blocking=True)
479
+
480
+ with cuda_timer(device) as timer:
481
+ optimizer.zero_grad(set_to_none=True)
482
+
483
+ # Use autocast only when AMP is enabled
484
+ with torch.amp.autocast(
485
+ device_type=device.type,
486
+ dtype=amp_dtype,
487
+ enabled=use_amp,
488
+ ):
489
+ outputs = model(inputs)
490
+
491
+ if is_lm:
492
+ # Reshape for cross-entropy: (batch * seq_len, vocab_size)
493
+ loss = F.cross_entropy(
494
+ outputs.view(-1, outputs.size(-1)),
495
+ targets.view(-1),
496
+ )
497
+ else:
498
+ loss = F.cross_entropy(outputs, targets)
499
+
500
+ # Backward pass with optional gradient scaling
501
+ scaler.scale(loss).backward()
502
+ scaler.step(optimizer)
503
+ scaler.update()
504
+
505
+ # Only record times after warmup
506
+ is_benchmark = batch_idx >= config.num_warmup_batches
507
+ if is_benchmark:
508
+ batch_times.append(timer["elapsed_ms"])
509
+
510
+ # Update progress bar
511
+ phase = "bench" if is_benchmark else "warmup"
512
+ postfix: dict[str, str] = {"phase": phase, "loss": f"{loss.item():.4f}"}
513
+ if is_benchmark and batch_times:
514
+ sps = config.batch_size / (batch_times[-1] / 1000)
515
+ postfix["samples/s"] = f"{sps:,.0f}"
516
+ if device.type == "cuda":
517
+ mem_mb = torch.cuda.memory_allocated(device) / (1024**2)
518
+ postfix["gpu_mem"] = f"{mem_mb:.0f}MB"
519
+ pbar.set_postfix(postfix)
520
+
521
+ # Compute statistics
522
+ avg_batch_time = sum(batch_times) / len(batch_times)
523
+ total_time = sum(batch_times) / 1000 # Convert to seconds
524
+ total_samples = config.num_benchmark_batches * config.batch_size
525
+
526
+ peak_memory = 0.0
527
+ if device.type == "cuda":
528
+ peak_memory = torch.cuda.max_memory_allocated(device) / (1024 * 1024)
529
+
530
+ return BenchmarkResult(
531
+ model_name=model_name,
532
+ total_samples=total_samples,
533
+ total_time_sec=total_time,
534
+ peak_memory_mb=peak_memory,
535
+ avg_batch_time_ms=avg_batch_time,
536
+ precision_mode=precision.value,
537
+ )
538
+
539
+
540
+ # -----------------------------------------------------------------------------
541
+ # System Information and GPU Configuration
542
+ # -----------------------------------------------------------------------------
543
+
544
+
545
+ def get_gpu_architecture(device: torch.device) -> tuple[int, int]:
546
+ """Get GPU compute capability (major, minor)."""
547
+ if device.type != "cuda":
548
+ return (0, 0)
549
+ props = torch.cuda.get_device_properties(device)
550
+ return (props.major, props.minor)
551
+
552
+
553
+ def configure_precision(device: torch.device, requested: PrecisionMode) -> PrecisionMode:
554
+ """
555
+ Configure and validate precision mode based on GPU capabilities.
556
+
557
+ GPU Architecture Reference:
558
+ - Turing (T4): sm_75 - Supports FP16 tensor cores, NO native BF16, NO TF32
559
+ - Ampere (A100, A10, 3090): sm_80/86 - Supports FP16, BF16, TF32
560
+ - Hopper (H100): sm_90 - Full support for all modes
561
+
562
+ Returns the actual precision mode that will be used.
563
+ """
564
+ if device.type != "cuda":
565
+ print(" CPU mode: Using FP32")
566
+ return PrecisionMode.FP32
567
+
568
+ major, minor = get_gpu_architecture(device)
569
+ sm_version = major * 10 + minor
570
+
571
+ print(f" GPU compute capability: sm_{sm_version}")
572
+
573
+ # =========================================================================
574
+ # CRITICAL: Disable problematic cuBLAS features that can cause GEMM errors
575
+ # These settings improve stability on older architectures like Turing (T4)
576
+ # =========================================================================
577
+
578
+ # Disable TF32 on non-Ampere hardware (TF32 is Ampere+ only)
579
+ if sm_version < 80:
580
+ torch.backends.cuda.matmul.allow_tf32 = False
581
+ torch.backends.cudnn.allow_tf32 = False
582
+ print(" TF32 disabled (requires sm_80+)")
583
+
584
+ # Disable reduced precision reductions in FP16 GEMMs
585
+ # This can cause overflow/execution failures on some cuBLAS versions
586
+ # See: https://docs.pytorch.org/docs/stable/notes/cuda.html
587
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
588
+ print(" FP16 reduced precision reduction disabled (for stability)")
589
+
590
+ # Also disable BF16 reduced precision reduction for consistency
591
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
592
+ print(" BF16 reduced precision reduction disabled (for stability)")
593
+
594
+ # TF32 requires Ampere or newer (sm_80+)
595
+ if requested == PrecisionMode.TF32:
596
+ if sm_version >= 80:
597
+ torch.backends.cuda.matmul.allow_tf32 = True
598
+ torch.backends.cudnn.allow_tf32 = True
599
+ print(" TF32 mode enabled for matmul and cuDNN")
600
+ return PrecisionMode.TF32
601
+ else:
602
+ print(f" WARNING: TF32 requires sm_80+, but GPU is sm_{sm_version}")
603
+ print(" Falling back to FP16 AMP")
604
+ requested = PrecisionMode.FP16
605
+
606
+ # BF16 requires Ampere or newer (sm_80+) for efficient operation
607
+ if requested == PrecisionMode.BF16:
608
+ if sm_version >= 80 and torch.cuda.is_bf16_supported():
609
+ print(" BF16 mode enabled")
610
+ return PrecisionMode.BF16
611
+ else:
612
+ print(f" WARNING: BF16 not efficiently supported on sm_{sm_version}")
613
+ print(" Falling back to FP16 AMP")
614
+ requested = PrecisionMode.FP16
615
+
616
+ # FP16 works on Volta (sm_70) and newer
617
+ if requested == PrecisionMode.FP16:
618
+ if sm_version >= 70:
619
+ print(" FP16 AMP mode enabled")
620
+ return PrecisionMode.FP16
621
+ else:
622
+ print(f" WARNING: FP16 tensor cores require sm_70+, but GPU is sm_{sm_version}")
623
+ print(" Falling back to FP32")
624
+ return PrecisionMode.FP32
625
+
626
+ # FP32 always works
627
+ print(" FP32 mode (no mixed precision)")
628
+ return PrecisionMode.FP32
629
+
630
+
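+ # Fallback sketch based on the rules above (illustrative): on a T4 (sm_75),
+ # requesting TF32 or BF16 drops down to FP16 AMP.
+ #
+ #   configure_precision(torch.device("cuda"), PrecisionMode.TF32)
+ #   # -> prints a warning and returns PrecisionMode.FP16 on sm_75 hardware
+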
631
+ def print_system_info(requested_precision: PrecisionMode) -> tuple[torch.device, PrecisionMode]:
632
+ """Print system and CUDA information, return device and actual precision mode."""
633
+ print("\n" + "=" * 60)
634
+ print("System Information")
635
+ print("=" * 60)
636
+ print(f"PyTorch version: {torch.__version__}")
637
+ print(f"Python version: {sys.version.split()[0]}")
638
+
639
+ if torch.cuda.is_available():
640
+ device = torch.device("cuda")
641
+ print("CUDA available: Yes")
642
+ print(f"CUDA version: {torch.version.cuda}")
643
+
644
+ cudnn_version = torch.backends.cudnn.version()
645
+ if cudnn_version:
646
+ print(f"cuDNN version: {cudnn_version}")
647
+
648
+ print(f"Device count: {torch.cuda.device_count()}")
649
+
650
+ for i in range(torch.cuda.device_count()):
651
+ props = torch.cuda.get_device_properties(i)
652
+ print(f"\nGPU {i}: {props.name}")
653
+ print(f" Compute capability: {props.major}.{props.minor}")
654
+ print(f" Total memory: {props.total_memory / (1024**3):.1f}GB")
655
+ print(f" SM count: {props.multi_processor_count}")
656
+
657
+ print("\nPrecision Configuration:")
658
+ actual_precision = configure_precision(device, requested_precision)
659
+
660
+ # Set deterministic cuBLAS workspace config for stability
661
+ # This can help avoid sporadic GEMM failures (most reliable when set before the first cuBLAS call)
662
+ if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
663
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
664
+ print(" Set CUBLAS_WORKSPACE_CONFIG=:4096:8 for stability")
665
+
666
+ else:
667
+ device = torch.device("cpu")
668
+ actual_precision = PrecisionMode.FP32
669
+ print("CUDA available: No (running on CPU)")
670
+ print("WARNING: GPU benchmark results will not be representative!")
671
+
672
+ print("=" * 60)
673
+ return device, actual_precision
674
+
675
+
676
+ # -----------------------------------------------------------------------------
677
+ # Main Entry Point
678
+ # -----------------------------------------------------------------------------
679
+
680
+
681
+ def main() -> None:
682
+ parser = argparse.ArgumentParser(
683
+ description="GPU Throughput Benchmark",
684
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
685
+ )
686
+ parser.add_argument(
687
+ "--mode",
688
+ choices=["cnn", "transformer", "both"],
689
+ default="both",
690
+ help="Benchmark mode: cnn (MNIST), transformer (synthetic LM), or both",
691
+ )
692
+ parser.add_argument(
693
+ "--batch-size",
694
+ type=int,
695
+ default=256,
696
+ help="Batch size for CNN training",
697
+ )
698
+ parser.add_argument(
699
+ "--transformer-batch-size",
700
+ type=int,
701
+ default=32,
702
+ help="Batch size for transformer training (smaller due to large vocab logits)",
703
+ )
704
+ parser.add_argument(
705
+ "--warmup-batches",
706
+ type=int,
707
+ default=10,
708
+ help="Number of warmup batches (not timed)",
709
+ )
710
+ parser.add_argument(
711
+ "--benchmark-batches",
712
+ type=int,
713
+ default=100,
714
+ help="Number of batches to benchmark",
715
+ )
716
+ parser.add_argument(
717
+ "--precision",
718
+ choices=["fp32", "fp16", "bf16", "tf32"],
719
+ default="fp16",
720
+ help="Precision mode: fp32 (full), fp16 (AMP), bf16 (AMP), tf32 (Ampere+)",
721
+ )
722
+ parser.add_argument(
723
+ "--diagnose",
724
+ action="store_true",
725
+ help="Run CUDA/cuBLAS diagnostic tests before benchmarking",
726
+ )
727
+ args = parser.parse_args()
728
+
729
+ requested_precision = PrecisionMode(args.precision)
730
+ device, actual_precision = print_system_info(requested_precision)
731
+
732
+ # Run diagnostics if requested
733
+ if args.diagnose:
734
+ diag_results = run_cuda_diagnostics(device)
735
+ # If FP16 tests fail, suggest using FP32
736
+ fp16_tests = ["fp16_matmul", "fp16_autocast", "fp16_linear", "fp16_batched_matmul"]
737
+ fp16_failures = [t for t in fp16_tests if t in diag_results and not diag_results[t]]
738
+ if fp16_failures and actual_precision == PrecisionMode.FP16:
739
+ print("WARNING: FP16 diagnostic tests failed. Switching to FP32.")
740
+ actual_precision = PrecisionMode.FP32
741
+
742
+ config = BenchmarkConfig(
743
+ batch_size=args.batch_size,
744
+ num_warmup_batches=args.warmup_batches,
745
+ num_benchmark_batches=args.benchmark_batches,
746
+ precision=actual_precision,
747
+ )
748
+
749
+ results: list[BenchmarkResult] = []
750
+
751
+ # CNN Benchmark
752
+ if args.mode in ("cnn", "both"):
753
+ model = MNISTConvNet()
754
+ param_count = sum(p.numel() for p in model.parameters())
755
+ print(f"\nMNIST CNN parameters: {param_count:,}")
756
+
757
+ loader = get_mnist_loader(config, device)
758
+
759
+ try:
760
+ result = run_benchmark(model, loader, config, device, "MNIST CNN", actual_precision, is_lm=False)
761
+ results.append(result)
762
+ print(result)
763
+ except RuntimeError as e:
764
+ if "CUBLAS" in str(e) or "cuBLAS" in str(e):
765
+ print(f"\n*** cuBLAS error encountered with {actual_precision.value} ***")
766
+ print(f"Error: {e}")
767
+ print("\nRetrying with FP32 (no AMP)...")
768
+
769
+ # Cleanup and retry with FP32
770
+ del model
771
+ if device.type == "cuda":
772
+ torch.cuda.empty_cache()
773
+
774
+ model = MNISTConvNet()
775
+ result = run_benchmark(model, loader, config, device, "MNIST CNN", PrecisionMode.FP32, is_lm=False)
776
+ results.append(result)
777
+ print(result)
778
+ else:
779
+ raise
780
+
781
+ # Cleanup
782
+ del model, loader
783
+ if device.type == "cuda":
784
+ torch.cuda.empty_cache()
785
+
786
+ # Transformer Benchmark
787
+ if args.mode in ("transformer", "both"):
788
+ transformer_config = BenchmarkConfig(
789
+ batch_size=args.transformer_batch_size,
790
+ num_warmup_batches=args.warmup_batches,
791
+ num_benchmark_batches=args.benchmark_batches,
792
+ precision=actual_precision,
793
+ )
794
+
795
+ model = MiniGPT()
796
+ param_count = sum(p.numel() for p in model.parameters())
797
+ print(f"\nMiniGPT parameters: {param_count:,}")
798
+
799
+ loader = get_synthetic_text_loader(transformer_config)
800
+
801
+ try:
802
+ result = run_benchmark(
803
+ model, loader, transformer_config, device, "MiniGPT Transformer", actual_precision, is_lm=True
804
+ )
805
+ results.append(result)
806
+ print(result)
807
+ except RuntimeError as e:
808
+ if "CUBLAS" in str(e) or "cuBLAS" in str(e):
809
+ print(f"\n*** cuBLAS error encountered with {actual_precision.value} ***")
810
+ print(f"Error: {e}")
811
+ print("\nRetrying with FP32 (no AMP)...")
812
+
813
+ # Cleanup and retry with FP32
814
+ del model
815
+ if device.type == "cuda":
816
+ torch.cuda.empty_cache()
817
+
818
+ model = MiniGPT()
819
+ result = run_benchmark(
820
+ model, loader, transformer_config, device, "MiniGPT Transformer", PrecisionMode.FP32, is_lm=True
821
+ )
822
+ results.append(result)
823
+ print(result)
824
+ else:
825
+ raise
826
+
827
+ # Summary
828
+ if len(results) > 1:
829
+ print("\n" + "=" * 60)
830
+ print("BENCHMARK SUMMARY")
831
+ print("=" * 60)
832
+ for r in results:
833
+ print(f"{r.model_name} ({r.precision_mode}):")
834
+ print(f" {r.samples_per_sec:,.1f} samples/sec | {r.peak_memory_mb:.0f}MB peak")
835
+ print("=" * 60)
836
+
837
+
838
+ if __name__ == "__main__":
839
+ main()