cvic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cvic/__init__.py ADDED
File without changes
cvic/common_cvic.py ADDED
@@ -0,0 +1,466 @@
1
+ """common_cvic.py — Shared utilities for tunic.py and cvic.py."""
2
+
3
+ import json
4
+ import logging
5
+ import random
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import timm
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.utils.data import DataLoader, Subset
14
+ from torchvision import datasets, transforms
15
+ from torchvision.transforms import RandAugment
16
+
17
# tqdm is an optional dependency. When it is missing, fall back to a
# pass-through with the same call shape so callers need no special casing.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(it, **kwargs):
        """Return ``it`` unchanged, ignoring every progress-bar keyword argument."""
        return it
22
+
23
+ try:
24
+ import yaml
25
+ except ImportError:
26
+ yaml = None
27
+
28
+ logger = logging.getLogger("cvic")
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Seeds / device
33
+ # ---------------------------------------------------------------------------
34
+
35
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators with ``seed``.

    All CUDA generators are seeded too when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
41
+
42
+
43
def get_amp_dtype() -> torch.dtype:
    """Return the autocast dtype: bfloat16 if the CUDA device supports it, else float16.

    The bf16 capability is only queried when CUDA is actually available. This
    function is also called on CPU/MPS runs (``train_one_epoch`` calls it
    unconditionally), and on hosts without CUDA some PyTorch releases inspect
    the current CUDA device inside ``is_bf16_supported()`` and raise instead of
    returning ``False``. Without CUDA the result is ``torch.float16``.
    """
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16
45
+
46
+
47
def get_device(device_str: str) -> torch.device:
    """Turn a device string into a ``torch.device``.

    ``"auto"`` picks CUDA first, then MPS, then CPU, based on availability.
    Any other value is handed to ``torch.device`` as-is.
    """
    if device_str != "auto":
        return torch.device(device_str)

    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)
55
+
56
+
57
def format_duration(seconds: float) -> str:
    """Format a duration for display.

    Values strictly above 300 seconds are shown in minutes (``"5.2m"``);
    everything else is shown in seconds (``"42.0s"``). Both use one decimal.
    """
    if seconds > 300:
        minutes = seconds / 60
        return f"{minutes:.1f}m"
    return f"{seconds:.1f}s"
60
+
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # Dataset helpers
64
+ # ---------------------------------------------------------------------------
65
+
66
def validate_dataset_path(data_path: Path):
    """Abort the process unless ``data_path`` looks like a usable dataset root.

    The path must exist and contain either a ``train/`` or a ``wds/train/``
    subdirectory. On failure an error is logged and ``sys.exit(1)`` is called;
    on success the function returns ``None``.
    """
    if not data_path.exists():
        logger.error(f"Dataset path does not exist: {data_path}")
        sys.exit(1)

    candidates = (data_path / "train", data_path / "wds" / "train")
    has_train_dir = any(candidate.exists() for candidate in candidates)
    if not has_train_dir:
        logger.error(f"Expected a 'train/' or 'wds/train/' subdirectory in {data_path}")
        sys.exit(1)
73
+
74
+
75
def make_stratified_split(dataset, val_fraction: float = 0.2, seed: int = 42):
    """Split ``dataset`` into train/validation ``Subset`` objects, class by class.

    ``dataset.samples`` must be a sequence of ``(item, label)`` pairs. Indices
    of each class are shuffled with a private ``random.Random(seed)`` and the
    first ``max(1, int(n * val_fraction))`` go to validation, the rest to
    training. Every class therefore contributes at least one validation
    sample (a class with a single sample ends up entirely in validation).

    Returns:
        ``(train_subset, val_subset)``.
    """
    from collections import defaultdict

    by_class = defaultdict(list)
    for position, (_, target) in enumerate(dataset.samples):
        by_class[target].append(position)

    shuffler = random.Random(seed)
    train_indices, val_indices = [], []
    for members in by_class.values():
        shuffled = list(members)
        shuffler.shuffle(shuffled)
        n_val = max(1, int(len(shuffled) * val_fraction))
        val_indices += shuffled[:n_val]
        train_indices += shuffled[n_val:]

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    return train_subset, val_subset
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Transforms
95
+ # ---------------------------------------------------------------------------
96
+
97
def build_transforms(img_size: int, randaug_magnitude: int = 0, randaug_num_ops: int = 2, is_train: bool = True):
    """Build the torchvision preprocessing pipeline for one split.

    Training: ``RandomResizedCrop(img_size)``, horizontal flip, optional
    ``RandAugment`` (only when ``randaug_magnitude > 0``), then tensor
    conversion and normalization.
    Evaluation: resize to ``int(img_size * 256 / 224)``, centre-crop to
    ``img_size``, then tensor conversion and normalization.

    Returns:
        A ``transforms.Compose`` instance.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    if not is_train:
        resize_to = int(img_size * 256 / 224)
        return transforms.Compose([
            transforms.Resize(resize_to),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    steps = [
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
    ]
    if randaug_magnitude > 0:
        steps.append(RandAugment(num_ops=randaug_num_ops, magnitude=randaug_magnitude))
    steps.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(steps)
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # In-memory cached dataset (decode once, augment fresh every epoch)
120
+ # ---------------------------------------------------------------------------
121
+
122
class CachedImageDataset(torch.utils.data.Dataset):
    """Dataset over already-decoded, in-memory images with a lazily applied transform.

    The images are decoded once elsewhere and kept in RAM; this class only
    stores references to them. The transform is called inside
    ``__getitem__``, so any random component it has (random crop, flip,
    RandAugment, ...) is re-drawn on every access — each epoch sees a fresh
    augmentation of every image. For validation/test, pass a deterministic
    transform (built with ``is_train=False``).

    Several instances may wrap the *same* ``images`` list (for example an
    augmented training view and a clean validation view) without duplicating
    the image data.

    NOTE: the post-transform tensors are intentionally never cached. Doing so
    would freeze the augmentation across epochs, weakening regularization and
    making results incomparable with methods that augment per epoch.
    """

    def __init__(self, images, labels, transform):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Transform first, then read the label — one fresh augmentation per access.
        transformed = self.transform(self.images[idx])
        return transformed, self.labels[idx]
151
+
152
+
153
def load_wds_images(data_root: str, split: str = "train"):
    """Decode an entire WebDataset split into memory ONCE as decoded PIL images.

    The split is described by ``<data_root>/wds/dataset_info.json``, which must
    provide ``classes`` (list of class names) and ``splits`` (mapping of split
    name to a dict with ``num_shards``). Shards are read from
    ``<data_root>/wds/<split>/shard-NNNNNN.tar``.

    Args:
        data_root: Local directory or ``s3://`` prefix. S3 objects are streamed
            through the ``aws s3 cp`` command-line tool.
        split: Split name. If it is not listed in the metadata, a warning is
            logged and the first listed split is used instead.

    Returns:
        ``(images, labels, num_classes)`` where ``images`` is a list of RGB
        ``PIL.Image`` and ``labels`` is a ``list[int]``. No transform is
        applied — wrap the result in :class:`CachedImageDataset` so
        augmentation is re-sampled fresh every epoch.

    Raises:
        SystemExit: if ``webdataset`` is not installed.
        subprocess.CalledProcessError: if fetching the S3 metadata fails
            (``check=True``).
        KeyError: if a ``.cls`` entry is neither an in-range integer nor a
            known class name.
    """
    try:
        import webdataset as wds
    except ImportError:
        logger.error("webdataset not installed. Run: pip install webdataset")
        sys.exit(1)

    # --- Load the dataset metadata (class list, split -> shard count). ---
    is_s3 = data_root.startswith("s3://")
    if is_s3:
        import subprocess
        info_url = data_root.rstrip("/") + "/wds/dataset_info.json"
        # "-" makes `aws s3 cp` write the object to stdout.
        r = subprocess.run(["aws", "s3", "cp", info_url, "-"],
                           capture_output=True, text=True, check=True)
        meta = json.loads(r.stdout)
    else:
        with open(Path(data_root) / "wds" / "dataset_info.json") as f:
            meta = json.load(f)

    classes = meta["classes"]
    class_to_idx = {c: i for i, c in enumerate(classes)}
    num_classes = len(classes)

    # Fall back to the first listed split rather than failing outright.
    available = list(meta["splits"].keys())
    if split not in available:
        logger.warning(f"Split '{split}' not found. Available: {available}. Using '{available[0]}'")
        split = available[0]

    def decode_cls(b):
        """Map raw ``.cls`` bytes to a class index.

        An integer within ``[0, num_classes)`` is used directly; anything else
        (including an out-of-range integer) is looked up as a class name.
        """
        s = b.decode().strip()
        try:
            idx = int(s)
            if 0 <= idx < num_classes:
                return idx
        except ValueError:
            pass
        return class_to_idx[s]

    # webdataset >= 1.0 folds basichandlers into imagehandler, which maps ".cls" ->
    # int(data); intercept .cls to keep raw bytes so decode_cls handles string labels.
    _img_decoder = wds.autodecode.imagehandler("pil")

    def _decoder(key, data):
        """Pass ``.cls`` payloads through untouched; decode everything else as an image."""
        if key.endswith(".cls"):
            return data
        return _img_decoder(key, data)

    # --- Build one URL per shard; S3 shards are piped through the AWS CLI. ---
    n_shards = meta["splits"][split]["num_shards"]
    if is_s3:
        base = data_root.rstrip("/")
        urls = [f"pipe:aws s3 cp {base}/wds/{split}/shard-{i:06d}.tar -" for i in range(n_shards)]
    else:
        d = Path(data_root) / "wds" / split
        urls = [str(d / f"shard-{i:06d}.tar") for i in range(n_shards)]

    # NOTE(review): only the "png" and "cls" entries of each sample are read;
    # confirm that shards never store images under another extension (e.g. jpg).
    dataset = (
        wds.WebDataset(urls, shardshuffle=False, nodesplitter=wds.split_by_node, empty_check=False)
        .decode(_decoder)
        .to_tuple("png", "cls")
        .map_tuple(lambda img: img.convert("RGB"), decode_cls)
    )

    # Materialize the whole split in memory, in shard order (no shuffling).
    images, labels = [], []
    for img, lbl in dataset:
        images.append(img)
        labels.append(lbl)
    return images, labels, num_classes
226
+
227
+
228
+ # ---------------------------------------------------------------------------
229
+ # Mixup / CutMix
230
+ # ---------------------------------------------------------------------------
231
+
232
class MixupCutmixCollator:
    """Batch collate function that applies Mixup or CutMix and emits soft labels.

    Behaviour per call:

    * both alphas ``> 0``: CutMix or Mixup is chosen with equal probability;
    * only one alpha ``> 0``: that augmentation is always used;
    * both alphas ``<= 0``: the batch is returned unmixed with integer labels.

    When mixing is applied the returned labels are a float tensor of shape
    ``(batch, num_classes)`` holding ``lam * one_hot(a) + (1 - lam) * one_hot(b)``.

    For CutMix, ``lam`` is the fraction of each image that was *actually* kept
    after the pasted box has been rounded and clipped to the image borders, so
    the label weights always agree with the pixels.
    """

    def __init__(self, mixup_alpha: float, cutmix_alpha: float, num_classes: int):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.num_classes = num_classes

    def __call__(self, batch):
        """Collate ``batch`` (a sequence of ``(image_tensor, label)``) and mix it.

        Returns:
            ``(images, labels)`` — ``labels`` is a ``long`` vector when no
            mixing is configured, otherwise a ``(batch, num_classes)`` float
            tensor of mixed one-hot targets.
        """
        images, labels = zip(*batch)
        images = torch.stack(images)
        labels = torch.tensor(labels, dtype=torch.long)

        if self.mixup_alpha > 0 and self.cutmix_alpha > 0:
            use_cutmix = random.random() > 0.5
        elif self.cutmix_alpha > 0:
            use_cutmix = True
        elif self.mixup_alpha > 0:
            use_cutmix = False
        else:
            return images, labels

        if use_cutmix:
            lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            # Use the area-corrected lambda: the sampled value only describes
            # the *requested* box, not the one that fits inside the image.
            images, labels_a, labels_b, lam = self._cutmix_with_lam(images, labels, lam)
        else:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            idx = torch.randperm(images.size(0))
            images = lam * images + (1 - lam) * images[idx]
            labels_a = labels
            labels_b = labels[idx]

        lam = float(lam)
        labels_a_oh = nn.functional.one_hot(labels_a, self.num_classes).float()
        labels_b_oh = nn.functional.one_hot(labels_b, self.num_classes).float()
        mixed_labels = lam * labels_a_oh + (1 - lam) * labels_b_oh
        return images, mixed_labels

    def _cutmix(self, images, labels, lam):
        """Apply CutMix and return ``(images, labels_a, labels_b)``.

        Kept with its original three-value return for existing callers; use
        :meth:`_cutmix_with_lam` when the corrected mixing ratio is needed.
        """
        images, labels_a, labels_b, _ = self._cutmix_with_lam(images, labels, lam)
        return images, labels_a, labels_b

    def _cutmix_with_lam(self, images, labels, lam):
        """Paste one random box from a permuted batch into ``images``.

        The box has side ratio ``sqrt(1 - lam)`` around a uniformly drawn
        centre and is clipped to the image. The same box is used for every
        sample in the batch.

        Returns:
            ``(images, labels_a, labels_b, lam)`` where ``lam`` is recomputed
            as ``1 - box_area / image_area`` from the clipped box.
        """
        _, _, H, W = images.shape
        cut_rat = np.sqrt(1.0 - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        x1 = int(np.clip(cx - cut_w // 2, 0, W))
        x2 = int(np.clip(cx + cut_w // 2, 0, W))
        y1 = int(np.clip(cy - cut_h // 2, 0, H))
        y2 = int(np.clip(cy + cut_h // 2, 0, H))
        idx = torch.randperm(images.size(0))
        images = images.clone()  # never modify the caller's stacked batch in place
        images[:, :, y1:y2, x1:x2] = images[idx, :, y1:y2, x1:x2]
        lam = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
        return images, labels, labels[idx], lam
283
+
284
+
285
+ # ---------------------------------------------------------------------------
286
+ # Model
287
+ # ---------------------------------------------------------------------------
288
+
289
def create_model(model_name: str, num_classes: int, pretrained: bool, drop_rate: float) -> nn.Module:
    """Instantiate a timm model, exiting the process if construction fails.

    Any exception raised by ``timm.create_model`` is logged together with a
    short list of alternative model names, after which ``sys.exit(1)`` is
    called.
    """
    build_kwargs = dict(pretrained=pretrained, num_classes=num_classes, drop_rate=drop_rate)
    try:
        return timm.create_model(model_name, **build_kwargs)
    except Exception as e:
        logger.error(f"Failed to create model '{model_name}': {e}")
        logger.error("Common alternatives: resnet50, efficientnet_b0, convnext_tiny, vit_small_patch16_224, mobilenetv3_large_100")
        sys.exit(1)
298
+
299
+
300
def freeze_backbone(model: nn.Module):
    """Set ``requires_grad = False`` on every parameter outside the classifier head.

    A parameter belongs to the head when any dot-separated component of its
    name

    * contains ``"head"`` or ``"classifier"`` (e.g. ``head``, ``conv_head``,
      ``head_dist``), or
    * is exactly ``"fc"`` or starts with ``"fc_"`` (e.g. ``fc``, ``fc_norm``).

    ``"fc"`` is deliberately NOT matched as a plain substring: backbone layers
    such as ``blocks.0.mlp.fc1`` or ``layer1.0.se.fc2`` contain it, and a
    substring test would leave them trainable. Parameters that are already
    frozen stay frozen; head parameters are not modified.
    """
    substring_keywords = ("head", "classifier")

    def is_head_component(part: str) -> bool:
        if part == "fc" or part.startswith("fc_"):
            return True
        return any(kw in part for kw in substring_keywords)

    for name, param in model.named_parameters():
        if not any(is_head_component(part) for part in name.split(".")):
            param.requires_grad = False
306
+
307
+
308
def unfreeze_all(model: nn.Module):
    """Make every parameter of ``model`` trainable again."""
    for _, weight in model.named_parameters():
        weight.requires_grad_(True)
311
+
312
+
313
+ # ---------------------------------------------------------------------------
314
+ # Optimizer / scheduler
315
+ # ---------------------------------------------------------------------------
316
+
317
def get_optimizer(model: nn.Module, optimizer_name: str, lr: float, weight_decay: float):
    """Create an optimizer over the trainable parameters of ``model``.

    ``optimizer_name == "AdamW"`` yields ``torch.optim.AdamW``; any other name
    yields ``torch.optim.SGD`` with momentum 0.9. Parameters whose
    ``requires_grad`` is ``False`` are left out.
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    if optimizer_name != "AdamW":
        return torch.optim.SGD(trainable, lr=lr, weight_decay=weight_decay, momentum=0.9)
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
323
+
324
+
325
def build_scheduler(optimizer, epochs: int, steps_per_epoch: int, warmup_epochs: int = 5,
                    start_step: int = 0):
    """Build a per-step ``LambdaLR`` with linear warmup followed by cosine decay.

    The LR multiplier rises linearly from 0 to 1 over
    ``warmup_epochs * steps_per_epoch`` steps, then follows half a cosine down
    to 0 at ``epochs * steps_per_epoch`` steps. Beyond that point it stays at
    0: the cosine progress is clamped to 1 so that stepping past the planned
    total (e.g. an extra step or an over-long resumed run) cannot make the
    learning rate climb back up.

    Args:
        optimizer: Optimizer whose parameter groups are scheduled.
        epochs: Planned number of epochs.
        steps_per_epoch: Scheduler steps (batches) per epoch.
        warmup_epochs: Number of warmup epochs.
        start_step: Global step to resume from. When ``> 0`` the scheduler is
            created with ``last_epoch=start_step - 1`` and each parameter
            group gets ``initial_lr`` defaulted to its current ``lr``, which
            ``LambdaLR`` requires when resuming.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` meant to be stepped once per batch.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / max(1, warmup_steps)
        progress = float(step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so the cosine never wraps around after the final planned step.
        progress = min(1.0, progress)
        # Plain float keeps NumPy scalar types out of the optimizer's lr.
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    if start_step > 0:
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=start_step - 1)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
342
+
343
+
344
+ # ---------------------------------------------------------------------------
345
+ # Training / evaluation primitives
346
+ # ---------------------------------------------------------------------------
347
+
348
def train_one_epoch(model, loader, optimizer, scheduler, criterion, device,
                    use_soft_labels, trial_id="", epoch=0, epochs=0,
                    use_amp=False, show_progress=True):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        model: Module to train; put into ``train()`` mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch (not per epoch).
        criterion: Loss used for hard (index) labels.
        device: ``torch.device`` batches are moved to.
        use_soft_labels: When true and a batch's labels are 2-D, the loss is
            the soft-target cross-entropy computed inline and accuracy is
            measured against ``labels.argmax(dim=1)``.
        trial_id: Optional identifier shown in the progress description.
        epoch: Zero-based epoch index (display only).
        epochs: Total epoch count (display only; omitted from the text if 0).
        use_amp: Request mixed precision; only honoured on CUDA devices.
        show_progress: Whether the progress bar is enabled.

    Returns:
        Tuple of Python floats: sample-weighted mean loss and accuracy.
    """
    import sys
    import time
    model.train()
    # Autocast is only used on CUDA, whatever the caller requested.
    use_amp = use_amp and device.type == "cuda"
    amp_dtype = get_amp_dtype()
    # Loss scaling is enabled only for float16 autocast; with bfloat16 (or AMP
    # off) the scaler is disabled and its calls pass through.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16)

    # Accumulate metrics on-device so there is no GPU->CPU sync per batch; we sync
    # only for the (throttled) progress display and once at epoch end. A per-batch
    # .item() forces the CPU to block on each step, which can throttle a fast GPU.
    total_loss = torch.zeros((), dtype=torch.float64, device=device) # float64 to match prior Python-float accumulation
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    epoch_str = f" epoch {epoch+1}/{epochs}" if epochs else ""
    desc = f"trial {trial_id}{epoch_str}" if trial_id else f"train{epoch_str}"
    # Refresh fast on a terminal, but coarsely when redirected to a file/pipe —
    # otherwise log size scales with epoch wall-time (slower GPUs → huge logs).
    _mininterval = 0.1 if sys.stderr.isatty() else 10.0
    bar = tqdm(loader, leave=False, desc=desc, disable=not show_progress,
               mininterval=_mininterval,
               bar_format="{l_bar}{bar}| batch {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
    # The no-tqdm fallback returns the bare loader, which has no set_postfix.
    _can_postfix = hasattr(bar, "set_postfix")
    _last_postfix = 0.0

    for images, labels in bar:
        images = images.to(device)
        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            outputs = model(images)
            if use_soft_labels and labels.dim() == 2:
                # Soft-target cross-entropy: labels are per-class probabilities.
                labels = labels.to(device)
                loss = -(labels * nn.functional.log_softmax(outputs, dim=-1)).sum(dim=-1).mean()
            else:
                labels = labels.to(device)
                loss = criterion(outputs, labels)

        preds = outputs.argmax(dim=1)
        # For soft labels, score predictions against the dominant class.
        target = labels.argmax(dim=1) if (use_soft_labels and labels.dim() == 2) else labels

        # Unscale before clipping so the norm threshold applies to true gradients.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # On-device accumulation — no .item(), so no per-batch sync.
        bs = images.size(0)
        total_loss += loss.detach() * bs
        correct += (preds == target).sum()
        total += bs

        # Throttled live display: sync at most once per refresh interval.
        if _can_postfix:
            now = time.monotonic()
            if now - _last_postfix >= _mininterval:
                _last_postfix = now
                bar.set_postfix(loss=f"{(total_loss / total).item():.4f}",
                                acc=f"{(correct / total).item():.4f}")

    # Single sync for the returned epoch metrics.
    return (total_loss / total).item(), (correct / total).item()
414
+
415
+
416
+ def check_class_distribution(labels: np.ndarray, n_classes: int, class_names: list[str] | None = None) -> list[int]:
417
+ """Print per-class sample counts to stderr. Return list of unscorable class indices.
418
+
419
+ Does NOT call sys.exit() — the caller decides whether to abort.
420
+ """
421
+ n = len(labels)
422
+ bad = []
423
+ lines = [f"\nValidation set: {n} samples across {n_classes} classes:"]
424
+ for c in range(n_classes):
425
+ count = int((labels == c).sum())
426
+ if count == 0:
427
+ note = " <- no positives — AUROC undefined"
428
+ bad.append(c)
429
+ elif count == n:
430
+ note = " <- no negatives — AUROC undefined"
431
+ bad.append(c)
432
+ else:
433
+ note = ""
434
+ label = f"{class_names[c]}" if class_names else f"class {c:3d}"
435
+ lines.append(f" {label:30s}: {count:5d} samples{note}" if class_names else f" {label}: {count:5d} samples{note}")
436
+ if bad:
437
+ lines.append(f"\n{len(bad)} class(es) cannot be scored.")
438
+ lines.append("Increase --val-fraction so every class has both positive and negative examples.")
439
+ print("\n".join(lines), file=sys.stderr)
440
+ return bad
441
+
442
+
443
def _compute_auroc(probs: np.ndarray, labels: np.ndarray) -> float:
    """Compute AUROC from class probabilities, returning NaN when it is undefined.

    Args:
        probs: Array of shape ``(n_samples, n_classes)`` with per-class scores.
        labels: Integer class indices of shape ``(n_samples,)``.

    Returns:
        * two-column ``probs``: binary AUROC of column 1;
        * otherwise: macro one-vs-rest AUROC over the classes that actually
          occur in ``labels``;
        * ``nan`` if fewer than two classes occur or scikit-learn rejects the
          input with ``ValueError``.

    When some classes are absent from ``labels`` their columns are dropped and
    the remaining columns are renormalised per row. scikit-learn requires
    multiclass scores to sum to 1 across the supplied classes, so without the
    renormalisation that call raises ``ValueError`` and the metric would
    silently become NaN. If exactly two classes remain, the problem is scored
    as a binary one, because scikit-learn does not accept a 2-D score for
    binary targets.
    """
    from sklearn.metrics import roc_auc_score

    present = np.unique(labels)
    if len(present) < 2:
        # AUROC needs both positives and negatives.
        return float("nan")
    try:
        if probs.shape[1] == 2:
            return roc_auc_score(labels, probs[:, 1])

        scores = probs[:, present]
        if len(present) < probs.shape[1]:
            # Restore the sum-to-one property over the classes that remain.
            row_sums = scores.sum(axis=1, keepdims=True)
            scores = scores / np.where(row_sums > 0, row_sums, 1.0)

        if len(present) == 2:
            return roc_auc_score(labels == present[1], scores[:, 1])
        return roc_auc_score(labels, scores, multi_class="ovr",
                             average="macro", labels=present)
    except ValueError:
        return float("nan")
455
+
456
+
457
+ # ---------------------------------------------------------------------------
458
+ # Search space overrides
459
+ # ---------------------------------------------------------------------------
460
+
461
def load_search_space_overrides(path: str) -> dict:
    """Read search-space overrides from a YAML file.

    Exits the process with status 1 (after logging an error) when PyYAML is
    not installed. An empty document yields an empty dict.
    """
    if yaml is None:
        logger.error("PyYAML is required for --search-space. Install with: pip install pyyaml")
        sys.exit(1)

    with open(path) as handle:
        overrides = yaml.safe_load(handle)
    return overrides if overrides else {}