cvic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cvic/__init__.py +0 -0
- cvic/common_cvic.py +466 -0
- cvic/cvic.py +848 -0
- cvic/tunic.py +1929 -0
- cvic-0.1.0.dist-info/METADATA +155 -0
- cvic-0.1.0.dist-info/RECORD +8 -0
- cvic-0.1.0.dist-info/WHEEL +4 -0
- cvic-0.1.0.dist-info/entry_points.txt +3 -0
cvic/__init__.py
ADDED
|
File without changes
|
cvic/common_cvic.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
"""common_cvic.py — Shared utilities for tunic.py and cvic.py."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import random
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import timm
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
from torch.utils.data import DataLoader, Subset
|
|
14
|
+
from torchvision import datasets, transforms
|
|
15
|
+
from torchvision.transforms import RandAugment
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
except ImportError:
|
|
20
|
+
def tqdm(it, **kwargs):
|
|
21
|
+
return it
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
import yaml
|
|
25
|
+
except ImportError:
|
|
26
|
+
yaml = None
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger("cvic")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ---------------------------------------------------------------------------
|
|
32
|
+
# Seeds / device
|
|
33
|
+
# ---------------------------------------------------------------------------
|
|
34
|
+
|
|
35
|
+
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with *seed*.

    When CUDA is available, every CUDA device's generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_amp_dtype() -> torch.dtype:
|
|
44
|
+
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_device(device_str: str) -> torch.device:
    """Turn a device string into a ``torch.device``.

    ``"auto"`` prefers CUDA, then Apple MPS, then CPU. Any other string is
    handed to ``torch.device`` unchanged.
    """
    if device_str != "auto":
        return torch.device(device_str)
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def format_duration(seconds: float) -> str:
    """Render *seconds* for humans.

    Durations above 300 s are shown in minutes with one decimal (``"5.0m"``).
    Anything up to and including 300 s is shown in seconds (``"12.3s"``).
    """
    if seconds <= 300:
        return f"{seconds:.1f}s"
    return f"{seconds / 60:.1f}m"
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ---------------------------------------------------------------------------
|
|
63
|
+
# Dataset helpers
|
|
64
|
+
# ---------------------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
def validate_dataset_path(data_path: Path):
    """Abort the process (exit status 1) unless *data_path* looks like a dataset root.

    A valid root exists and contains either a ``train/`` or a ``wds/train/``
    subdirectory. Problems are reported through the module logger before exiting.
    """
    if not data_path.exists():
        logger.error(f"Dataset path does not exist: {data_path}")
        sys.exit(1)
    candidates = (data_path / "train", data_path / "wds" / "train")
    if any(candidate.exists() for candidate in candidates):
        return
    logger.error(f"Expected a 'train/' or 'wds/train/' subdirectory in {data_path}")
    sys.exit(1)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def make_stratified_split(dataset, val_fraction: float = 0.2, seed: int = 42):
    """Split *dataset* into stratified ``(train_subset, val_subset)`` views.

    Labels are read from ``dataset.samples``, which must be a sequence of
    ``(item, label)`` pairs (presumably an ImageFolder-style dataset).

    Procedure:
      * Indices are grouped per label, in the order each label first appears.
      * Each group is shuffled with a private ``random.Random(seed)``, so the
        global RNG is untouched and the split is reproducible.
      * The first ``max(1, int(len(group) * val_fraction))`` indices of each
        group go to validation; the rest go to training.

    Note: because of the ``max(1, ...)``, a class with a single sample ends up
    entirely in the validation subset.
    """
    from collections import defaultdict

    by_label = defaultdict(list)
    for position, (_, label) in enumerate(dataset.samples):
        by_label[label].append(position)

    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for members in by_label.values():
        shuffled = members[:]  # shuffle a copy; the grouping lists stay intact
        rng.shuffle(shuffled)
        n_val = max(1, int(len(shuffled) * val_fraction))
        val_idx += shuffled[:n_val]
        train_idx += shuffled[n_val:]

    return Subset(dataset, train_idx), Subset(dataset, val_idx)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# ---------------------------------------------------------------------------
|
|
94
|
+
# Transforms
|
|
95
|
+
# ---------------------------------------------------------------------------
|
|
96
|
+
|
|
97
|
+
def build_transforms(img_size: int, randaug_magnitude: int = 0, randaug_num_ops: int = 2, is_train: bool = True):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Training (``is_train=True``):
      * ``RandomResizedCrop(img_size)`` followed by ``RandomHorizontalFlip()``;
      * then ``RandAugment(num_ops=randaug_num_ops, magnitude=randaug_magnitude)``,
        added only when ``randaug_magnitude > 0``.

    Evaluation: deterministic ``Resize(int(img_size * 256 / 224))`` followed by
    ``CenterCrop(img_size)``.

    Both pipelines finish with ``ToTensor()`` and per-channel normalization
    (the commonly used ImageNet mean/std values below).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    if not is_train:
        return transforms.Compose([
            transforms.Resize(int(img_size * 256 / 224)),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    steps = [transforms.RandomResizedCrop(img_size), transforms.RandomHorizontalFlip()]
    if randaug_magnitude > 0:
        # RandAugment works on PIL images, so it must run before ToTensor.
        steps.append(RandAugment(num_ops=randaug_num_ops, magnitude=randaug_magnitude))
    steps.extend([transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)])
    return transforms.Compose(steps)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ---------------------------------------------------------------------------
|
|
119
|
+
# In-memory cached dataset (decode once, augment fresh every epoch)
|
|
120
|
+
# ---------------------------------------------------------------------------
|
|
121
|
+
|
|
122
|
+
class CachedImageDataset(torch.utils.data.Dataset):
    """Dataset over images already decoded into RAM; the transform runs on every access.

    The costly part of loading (file I/O plus JPEG/PNG decoding) happens once,
    before this object is built. ``__getitem__`` only applies ``transform`` to
    the stored image.

    Because the transform is re-run on each access, any randomness it contains
    (RandomResizedCrop, RandomHorizontalFlip, RandAugment) is re-drawn each
    time. Every epoch therefore sees a fresh augmentation of every image, which
    is the standard training regime. For validation/test, pass a deterministic
    (``is_train=False``) transform.

    Several views can share one ``images`` list without extra memory, e.g. an
    augmented training view and a clean validation view.

    NOTE: post-transform tensors are deliberately *not* cached. Doing so would
    freeze the augmentation across epochs, weakening regularization and making
    results incomparable with a pipeline that augments fresh every epoch.
    """

    def __init__(self, images, labels, transform):
        # images[i] pairs with labels[i]; transform is applied lazily per item.
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        return self.transform(image), self.labels[idx]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def load_wds_images(data_root: str, split: str = "train"):
    """Decode an entire WebDataset split into memory ONCE as decoded PIL images.

    Returns ``(images, labels, num_classes)`` where ``images`` is a list of RGB
    ``PIL.Image`` and ``labels`` is a ``list[int]``. No transform is applied — wrap
    the result in :class:`CachedImageDataset` so augmentation is re-sampled fresh
    every epoch. Supports local paths and ``s3://`` (streamed via ``aws s3 cp``).

    Layout expected under *data_root*:
      * ``wds/dataset_info.json``, containing a ``classes`` list and a ``splits``
        mapping whose entries carry ``num_shards``;
      * ``wds/<split>/shard-NNNNNN.tar`` shards (zero-padded 6-digit index).

    Each sample needs:
      * a ``.cls`` entry: an in-range class index, or a class name from
        ``classes``;
      * an image under ``.png``, with ``.jpg`` / ``.jpeg`` accepted as
        fallbacks.

    If *split* is not listed, the first listed split is used and a warning is
    logged. The process exits with status 1 when webdataset is not installed or
    when the metadata lists no splits.
    """
    try:
        import webdataset as wds
    except ImportError:
        logger.error("webdataset not installed. Run: pip install webdataset")
        sys.exit(1)

    is_s3 = data_root.startswith("s3://")
    if is_s3:
        import subprocess
        info_url = data_root.rstrip("/") + "/wds/dataset_info.json"
        # argv list form: no shell is involved here, so no quoting is needed.
        r = subprocess.run(["aws", "s3", "cp", info_url, "-"],
                           capture_output=True, text=True, check=True)
        meta = json.loads(r.stdout)
    else:
        with open(Path(data_root) / "wds" / "dataset_info.json") as f:
            meta = json.load(f)

    classes = meta["classes"]
    class_to_idx = {c: i for i, c in enumerate(classes)}
    num_classes = len(classes)

    available = list(meta["splits"].keys())
    if not available:
        # Without this guard, available[0] below fails with an opaque IndexError.
        logger.error("dataset_info.json lists no splits")
        sys.exit(1)
    if split not in available:
        logger.warning(f"Split '{split}' not found. Available: {available}. Using '{available[0]}'")
        split = available[0]

    def decode_cls(b):
        """Map raw ``.cls`` bytes to a class index.

        An integer within ``[0, num_classes)`` is taken as the index itself.
        Anything else is looked up as a class name; an unknown name raises
        KeyError.
        """
        s = b.decode().strip()
        try:
            idx = int(s)
            if 0 <= idx < num_classes:
                return idx
        except ValueError:
            pass
        return class_to_idx[s]

    # webdataset >= 1.0 folds basichandlers into imagehandler, which maps ".cls" ->
    # int(data); intercept .cls to keep raw bytes so decode_cls handles string labels.
    _img_decoder = wds.autodecode.imagehandler("pil")

    def _decoder(key, data):
        if key.endswith(".cls"):
            return data
        return _img_decoder(key, data)

    n_shards = meta["splits"][split]["num_shards"]
    if is_s3:
        import shlex
        base = data_root.rstrip("/")
        urls = []
        for i in range(n_shards):
            # webdataset runs "pipe:" commands through a shell. Quote the S3
            # path so spaces or metacharacters in data_root/split cannot break
            # (or inject into) the command. Ordinary s3:// URLs are unchanged.
            src = shlex.quote(f"{base}/wds/{split}/shard-{i:06d}.tar")
            urls.append(f"pipe:aws s3 cp {src} -")
    else:
        d = Path(data_root) / "wds" / split
        urls = [str(d / f"shard-{i:06d}.tar") for i in range(n_shards)]

    dataset = (
        # NOTE(review): split_by_node presumably gives each distributed rank a
        # disjoint subset of shards, so under DDP each rank would cache only
        # part of the split — confirm that is intended for an in-memory cache.
        wds.WebDataset(urls, shardshuffle=False, nodesplitter=wds.split_by_node, empty_check=False)
        .decode(_decoder)
        # The first key present wins, so PNG shards decode exactly as before;
        # JPEG-only shards no longer fail.
        .to_tuple("png;jpg;jpeg", "cls")
        .map_tuple(lambda img: img.convert("RGB"), decode_cls)
    )

    images, labels = [], []
    for img, lbl in dataset:
        images.append(img)
        labels.append(lbl)
    return images, labels, num_classes
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# ---------------------------------------------------------------------------
|
|
229
|
+
# Mixup / CutMix
|
|
230
|
+
# ---------------------------------------------------------------------------
|
|
231
|
+
|
|
232
|
+
class MixupCutmixCollator:
    """``collate_fn`` that stacks a batch and applies batch-level Mixup or CutMix.

    Input is a list of ``(image_tensor, int_label)`` pairs, which is stacked
    into a batch.

    * Both alphas positive: Mixup or CutMix is chosen per batch with roughly
      equal probability.
    * Only one alpha positive: that method is always used.
    * When mixing: the batch is blended with a random permutation of itself,
      and soft targets of shape ``(batch, num_classes)`` are returned.
    * Both alphas ``<= 0``: the stacked images and ``long`` labels are returned
      unchanged.

    Randomness comes from the global ``random``, ``np.random`` and ``torch`` RNGs.
    """

    def __init__(self, mixup_alpha: float, cutmix_alpha: float, num_classes: int):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.num_classes = num_classes

    def __call__(self, batch):
        images, labels = zip(*batch)
        images = torch.stack(images)
        labels = torch.tensor(labels, dtype=torch.long)

        if self.mixup_alpha > 0 and self.cutmix_alpha > 0:
            use_cutmix = random.random() > 0.5
        elif self.cutmix_alpha > 0:
            use_cutmix = True
        elif self.mixup_alpha > 0:
            use_cutmix = False
        else:
            return images, labels

        if use_cutmix:
            lam = float(np.random.beta(self.cutmix_alpha, self.cutmix_alpha))
            # Border clipping and integer rounding make the pasted box's real
            # area differ from 1 - lam. Mix the labels with the lam that matches
            # the pixels actually pasted, not the raw Beta sample.
            images, labels_a, labels_b, lam = self._cutmix_with_lam(images, labels, lam)
        else:
            lam = float(np.random.beta(self.mixup_alpha, self.mixup_alpha))
            idx = torch.randperm(images.size(0))
            images = lam * images + (1 - lam) * images[idx]
            labels_a = labels
            labels_b = labels[idx]

        labels_a_oh = nn.functional.one_hot(labels_a, self.num_classes).float()
        labels_b_oh = nn.functional.one_hot(labels_b, self.num_classes).float()
        mixed_labels = lam * labels_a_oh + (1 - lam) * labels_b_oh
        return images, mixed_labels

    def _cutmix_with_lam(self, images, labels, lam):
        """Paste one random box from a shuffled copy of the batch into every image.

        The box side is ``sqrt(1 - lam)`` times the image side. It is centered
        at a uniformly random pixel and clipped to the image bounds.

        Returns ``(mixed_images, labels, shuffled_labels, lam_adjusted)``, where
        ``lam_adjusted`` is the fraction of each image NOT covered by the box.
        That fraction is the weight the image's own label should receive.
        """
        _, _, H, W = images.shape
        cut_rat = np.sqrt(1.0 - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        x1 = int(np.clip(cx - cut_w // 2, 0, W))
        x2 = int(np.clip(cx + cut_w // 2, 0, W))
        y1 = int(np.clip(cy - cut_h // 2, 0, H))
        y2 = int(np.clip(cy + cut_h // 2, 0, H))
        idx = torch.randperm(images.size(0))
        images = images.clone()
        images[:, :, y1:y2, x1:x2] = images[idx, :, y1:y2, x1:x2]
        lam_adjusted = 1.0 - ((x2 - x1) * (y2 - y1)) / (W * H)
        return images, labels, labels[idx], lam_adjusted

    def _cutmix(self, images, labels, lam):
        """Backward-compatible form of ``_cutmix_with_lam``.

        Returns ``(mixed_images, labels, shuffled_labels)`` and drops the
        adjusted lam.
        """
        images, labels_a, labels_b, _ = self._cutmix_with_lam(images, labels, lam)
        return images, labels_a, labels_b
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# ---------------------------------------------------------------------------
|
|
286
|
+
# Model
|
|
287
|
+
# ---------------------------------------------------------------------------
|
|
288
|
+
|
|
289
|
+
def create_model(model_name: str, num_classes: int, pretrained: bool, drop_rate: float) -> nn.Module:
    """Build ``model_name`` through ``timm.create_model`` with ``num_classes`` outputs.

    ``pretrained`` and ``drop_rate`` are forwarded to timm. Any exception timm
    raises is logged together with a few common model names, and the process
    exits with status 1.
    """
    timm_kwargs = dict(pretrained=pretrained, num_classes=num_classes, drop_rate=drop_rate)
    try:
        return timm.create_model(model_name, **timm_kwargs)
    except Exception as e:
        logger.error(f"Failed to create model '{model_name}': {e}")
        logger.error("Common alternatives: resnet50, efficientnet_b0, convnext_tiny, vit_small_patch16_224, mobilenetv3_large_100")
        sys.exit(1)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def freeze_backbone(model: nn.Module):
    """Freeze (``requires_grad = False``) every parameter outside the classifier head.

    A parameter counts as part of the head when the name of its *top-level*
    submodule contains ``"head"``, ``"fc"`` or ``"classifier"``. The top-level
    submodule name is the first dotted component of the parameter name, e.g.
    ``fc.weight``, ``head.fc.weight``, ``classifier.bias``, ``conv_head.weight``.

    Only the top-level component is matched. A substring test over the full name
    would also keep nested backbone layers such as ``blocks.0.mlp.fc1.weight``
    or ``layer1.0.se.fc1.weight`` trainable, leaving much of the backbone
    unfrozen.

    Head parameters are left untouched (not forced to ``True``).
    """
    head_keywords = ("head", "fc", "classifier")
    for name, param in model.named_parameters():
        top = name.split(".", 1)[0]
        if not any(kw in top for kw in head_keywords):
            param.requires_grad = False
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def unfreeze_all(model: nn.Module):
    """Make every parameter of *model* trainable again (``requires_grad = True``)."""
    for p in model.parameters():
        p.requires_grad_(True)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
# ---------------------------------------------------------------------------
|
|
314
|
+
# Optimizer / scheduler
|
|
315
|
+
# ---------------------------------------------------------------------------
|
|
316
|
+
|
|
317
|
+
def get_optimizer(model: nn.Module, optimizer_name: str, lr: float, weight_decay: float):
    """Create an optimizer over the parameters of *model* that require gradients.

    ``optimizer_name == "AdamW"`` selects ``torch.optim.AdamW``. Any other value
    falls back to SGD with momentum 0.9. ``lr`` and ``weight_decay`` are passed
    to either optimizer.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    if optimizer_name != "AdamW":
        return torch.optim.SGD(trainable, lr=lr, weight_decay=weight_decay, momentum=0.9)
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def build_scheduler(optimizer, epochs: int, steps_per_epoch: int, warmup_epochs: int = 5,
                    start_step: int = 0):
    """Return a per-step ``LambdaLR``: linear warmup, then cosine decay to zero.

    The multiplier applied to each param group's base LR is:
      * ``step / warmup_steps`` for the first
        ``warmup_steps = warmup_epochs * steps_per_epoch`` steps;
      * then ``0.5 * (1 + cos(pi * progress))``, where ``progress`` goes from
        0 to 1 between ``warmup_steps`` and ``total_steps = epochs * steps_per_epoch``.

    ``progress`` is clamped at 1, so stepping past ``total_steps`` holds the LR
    at 0 instead of letting the cosine swing back up.

    ``start_step > 0`` builds the scheduler as if resuming after ``start_step``
    scheduler steps.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / max(1, warmup_steps)
        progress = float(step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(progress, 1.0)  # past the planned end: stay at the floor
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    if start_step > 0:
        # LambdaLR only accepts last_epoch != -1 when each group already carries
        # "initial_lr". Seed it from the current lr unless it was restored with
        # the optimizer state.
        # NOTE(review): if the optimizer state was restored *without*
        # "initial_lr", group["lr"] is an already-scheduled LR, not the base LR.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        # Construction performs one step, landing last_epoch on start_step.
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=start_step - 1)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# ---------------------------------------------------------------------------
|
|
345
|
+
# Training / evaluation primitives
|
|
346
|
+
# ---------------------------------------------------------------------------
|
|
347
|
+
|
|
348
|
+
def train_one_epoch(model, loader, optimizer, scheduler, criterion, device,
                    use_soft_labels, trial_id="", epoch=0, epochs=0,
                    use_amp=False, show_progress=True):
    """Train *model* for one pass over *loader*; return ``(mean_loss, accuracy)``.

    Per batch:
      * forward pass, with autocast only when ``use_amp`` is set and the device
        is CUDA;
      * loss computation (see below);
      * backward pass through the GradScaler (scaling active only for float16);
      * gradient clipping to global norm 1.0, optimizer step, then one
        ``scheduler.step()``.

    Loss and accuracy:
      * Soft labels (``use_soft_labels`` and 2-D labels): soft-target
        cross-entropy over ``log_softmax``; accuracy compares output argmax
        with target argmax.
      * Otherwise: ``criterion`` is used and accuracy compares against the
        labels directly.

    Both returned values are averaged over samples. They are NaN when the
    loader is empty.
    """
    import sys
    import time

    model.train()
    amp_on = use_amp and device.type == "cuda"
    amp_dtype = get_amp_dtype()
    # Loss scaling is only enabled for float16 autocast.
    scaler = torch.amp.GradScaler("cuda", enabled=amp_on and amp_dtype == torch.float16)

    # Metrics are accumulated on the device so no batch forces a GPU->CPU sync.
    # .item() runs only for the throttled progress display and once at the end.
    loss_sum = torch.zeros((), dtype=torch.float64, device=device)  # float64 to match prior Python-float accumulation
    n_correct = torch.zeros((), dtype=torch.long, device=device)
    n_seen = 0

    epoch_str = f" epoch {epoch+1}/{epochs}" if epochs else ""
    desc = f"trial {trial_id}{epoch_str}" if trial_id else f"train{epoch_str}"
    # Refresh quickly on a terminal. When redirected to a file/pipe, refresh
    # rarely, otherwise log size grows with epoch wall-time.
    refresh = 0.1 if sys.stderr.isatty() else 10.0
    bar = tqdm(loader, leave=False, desc=desc, disable=not show_progress,
               mininterval=refresh,
               bar_format="{l_bar}{bar}| batch {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
    can_postfix = hasattr(bar, "set_postfix")  # the no-tqdm fallback is a plain iterable
    last_refresh = 0.0

    for images, labels in bar:
        images = images.to(device)
        soft = use_soft_labels and labels.dim() == 2
        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_on):
            outputs = model(images)
            labels = labels.to(device)
            if soft:
                loss = -(labels * nn.functional.log_softmax(outputs, dim=-1)).sum(dim=-1).mean()
            else:
                loss = criterion(outputs, labels)

        preds = outputs.argmax(dim=1)
        target = labels.argmax(dim=1) if soft else labels

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_size = images.size(0)
        loss_sum += loss.detach() * batch_size
        n_correct += (preds == target).sum()
        n_seen += batch_size

        if can_postfix:
            now = time.monotonic()
            if now - last_refresh >= refresh:
                last_refresh = now
                bar.set_postfix(loss=f"{(loss_sum / n_seen).item():.4f}",
                                acc=f"{(n_correct / n_seen).item():.4f}")

    return (loss_sum / n_seen).item(), (n_correct / n_seen).item()
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def check_class_distribution(labels: np.ndarray, n_classes: int, class_names: list[str] | None = None) -> list[int]:
|
|
417
|
+
"""Print per-class sample counts to stderr. Return list of unscorable class indices.
|
|
418
|
+
|
|
419
|
+
Does NOT call sys.exit() — the caller decides whether to abort.
|
|
420
|
+
"""
|
|
421
|
+
n = len(labels)
|
|
422
|
+
bad = []
|
|
423
|
+
lines = [f"\nValidation set: {n} samples across {n_classes} classes:"]
|
|
424
|
+
for c in range(n_classes):
|
|
425
|
+
count = int((labels == c).sum())
|
|
426
|
+
if count == 0:
|
|
427
|
+
note = " <- no positives — AUROC undefined"
|
|
428
|
+
bad.append(c)
|
|
429
|
+
elif count == n:
|
|
430
|
+
note = " <- no negatives — AUROC undefined"
|
|
431
|
+
bad.append(c)
|
|
432
|
+
else:
|
|
433
|
+
note = ""
|
|
434
|
+
label = f"{class_names[c]}" if class_names else f"class {c:3d}"
|
|
435
|
+
lines.append(f" {label:30s}: {count:5d} samples{note}" if class_names else f" {label}: {count:5d} samples{note}")
|
|
436
|
+
if bad:
|
|
437
|
+
lines.append(f"\n{len(bad)} class(es) cannot be scored.")
|
|
438
|
+
lines.append("Increase --val-fraction so every class has both positive and negative examples.")
|
|
439
|
+
print("\n".join(lines), file=sys.stderr)
|
|
440
|
+
return bad
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def _compute_auroc(probs: np.ndarray, labels: np.ndarray) -> float:
    """Macro one-vs-rest AUROC of predicted class scores.

    Args:
        probs: per-sample class scores, indexed as ``probs[:, c]``
            (presumably ``(n_samples, n_classes)`` softmax outputs).
        labels: integer class index per sample.

    Returns:
        * Two-column *probs*: binary AUROC of ``probs[:, 1]``.
        * Otherwise: the mean of per-class one-vs-rest AUROCs, each computed on
          the class's own probability column. Only classes present in *labels*
          are included; absent classes are skipped.
        * NaN when fewer than two classes are present, or when sklearn rejects
          the input with ValueError.
    """
    from sklearn.metrics import roc_auc_score

    present = np.unique(labels)
    try:
        if probs.shape[1] == 2:
            return roc_auc_score(labels, probs[:, 1])
        if len(present) < 2:
            return float("nan")
        # Score each present class against its full probability column.
        # Passing probs[:, present] to roc_auc_score(multi_class="ovr") breaks
        # whenever a class is missing, because sklearn's multiclass path
        # requires each score row to sum to 1 and a column subset no longer
        # does. That ValueError used to turn every such case into NaN.
        per_class = [roc_auc_score((labels == c).astype(int), probs[:, c]) for c in present]
        return float(np.mean(per_class))
    except ValueError:
        return float("nan")
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
# ---------------------------------------------------------------------------
|
|
458
|
+
# Search space overrides
|
|
459
|
+
# ---------------------------------------------------------------------------
|
|
460
|
+
|
|
461
|
+
def load_search_space_overrides(path: str) -> dict:
    """Read search-space overrides from the YAML file at *path*.

    Parsing uses ``yaml.safe_load``. A falsy result (e.g. an empty file) is
    returned as ``{}``. If PyYAML is not installed, the error is logged and the
    process exits with status 1.
    """
    if yaml is None:
        logger.error("PyYAML is required for --search-space. Install with: pip install pyyaml")
        sys.exit(1)
    with open(path) as fh:
        overrides = yaml.safe_load(fh)
    return overrides or {}
|