sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/training/lightning_modules.py

@@ -1,6 +1,6 @@
 """This module has the LightningModule classes for all model types."""

-from typing import Optional, Union, Dict, Any
+from typing import Optional, Union, Dict, Any, List
 import time
 from torch import nn
 import numpy as np
@@ -33,6 +33,7 @@ from sleap_nn.inference.bottomup import (
 )
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.architectures.model import Model
+from sleap_nn.data.normalization import normalize_on_gpu
 from sleap_nn.training.losses import compute_ohkm_loss
 from loguru import logger
 from sleap_nn.training.utils import (
@@ -40,14 +41,26 @@ from sleap_nn.training.utils import (
     plot_confmaps,
     plot_img,
     plot_peaks,
+    VisualizationData,
 )
+import matplotlib
+
+matplotlib.use(
+    "Agg"
+)  # Use non-interactive backend to avoid tkinter issues on Windows CI
 import matplotlib.pyplot as plt
 from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
 from sleap_nn.config.trainer_config import (
+    CosineAnnealingWarmupConfig,
+    LinearWarmupLinearDecayConfig,
     LRSchedulerConfig,
     ReduceLROnPlateauConfig,
     StepLRConfig,
 )
+from sleap_nn.training.schedulers import (
+    LinearWarmupCosineAnnealingLR,
+    LinearWarmupLinearDecayLR,
+)
 from sleap_nn.config.get_config import get_backbone_config
 from sleap_nn.legacy_models import (
     load_legacy_model_weights,
@@ -177,6 +190,15 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}

+        # For epoch-averaged loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+        # For epoch-end evaluation
+        self.val_predictions: List[Dict] = []
+        self.val_ground_truth: List[Dict] = []
+        self._collect_val_predictions: bool = False
+
         # Initialization for encoder and decoder stacks.
         if self.init_weights == "xavier":
             self.model.apply(xavier_init_weights)
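The new `_collect_val_predictions` flag added above is a plain attribute on the module; the logic that flips it presumably lives in the training callbacks (`sleap_nn/training/callbacks.py`, +1258 lines, not shown in this hunk). Purely for orientation, a minimal Lightning callback that toggles such a flag could look like the sketch below; the class name and the every-N-epochs condition are hypothetical and not part of the sleap-nn API.

```python
# Hypothetical sketch only: enable prediction collection on selected epochs.
import lightning as L


class CollectValPredictions(L.Callback):
    def __init__(self, every_n_epochs: int = 5):
        self.every_n_epochs = every_n_epochs

    def on_validation_epoch_start(self, trainer, pl_module):
        # Only collect on some epochs to keep validation cheap.
        pl_module._collect_val_predictions = (
            trainer.current_epoch % self.every_n_epochs == 0
        )
```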
@@ -213,7 +235,9 @@ class LightningModel(L.LightningModule):
         elif self.pretrained_backbone_weights.endswith(".h5"):
             # load from sleap model weights
             load_legacy_model_weights(
-                self.model.backbone,
+                self.model.backbone,
+                self.pretrained_backbone_weights,
+                component="backbone",
             )

         else:
@@ -242,7 +266,9 @@ class LightningModel(L.LightningModule):
         elif self.pretrained_head_weights.endswith(".h5"):
             # load from sleap model weights
             load_legacy_model_weights(
-                self.model.head_layers,
+                self.model.head_layers,
+                self.pretrained_head_weights,
+                component="head",
             )

         else:
@@ -298,34 +324,82 @@ class LightningModel(L.LightningModule):
     def on_train_epoch_start(self):
         """Configure the train timer at the beginning of each epoch."""
         self.train_start_time = time.time()
+        # Reset epoch loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+    def _accumulate_loss(self, loss: torch.Tensor):
+        """Accumulate loss for epoch-averaged logging. Call this in training_step."""
+        self._epoch_loss_sum += loss.detach().item()
+        self._epoch_loss_count += 1

     def on_train_epoch_end(self):
         """Configure the train timer at the end of every epoch."""
         train_time = time.time() - self.train_start_time
         self.log(
-            "
+            "train/time",
             train_time,
             prog_bar=False,
             on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        # Log epoch explicitly for custom x-axis support in wandb
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        # Log epoch-averaged training loss
+        if self._epoch_loss_count > 0:
+            avg_loss = self._epoch_loss_sum / self._epoch_loss_count
+            self.log(
+                "train/loss",
+                avg_loss,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                sync_dist=True,
+            )
+        # Log current learning rate (useful for monitoring LR schedulers)
+        if self.trainer.optimizers:
+            lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+            self.log(
+                "train/lr",
+                lr,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                sync_dist=True,
+            )

     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
         self.val_start_time = time.time()
+        # Clear accumulated predictions for new epoch
+        self.val_predictions = []
+        self.val_ground_truth = []

     def on_validation_epoch_end(self):
         """Configure the val timer at the end of every epoch."""
         val_time = time.time() - self.val_start_time
         self.log(
-            "
+            "val/time",
             val_time,
             prog_bar=False,
             on_step=False,
             on_epoch=True,
-
+            sync_dist=True,
+        )
+        # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
+        # (mirrors what on_train_epoch_end does for train/* metrics)
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
             sync_dist=True,
         )

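The epoch-averaged "train/loss" above is just a running sum divided by the number of batches seen, reset at epoch start and flushed at epoch end. A stripped-down illustration of the same pattern, outside of any sleap-nn specifics, is sketched below to show how the three pieces interact; it is not the library's code.

```python
# Minimal sketch of the epoch-averaged loss pattern used above (illustrative only).
import torch


class EpochLossTracker:
    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def reset(self):  # call at epoch start
        self._sum, self._count = 0.0, 0

    def update(self, loss: torch.Tensor):  # call once per training step
        self._sum += loss.detach().item()
        self._count += 1

    def average(self) -> float:  # call at epoch end
        return self._sum / self._count if self._count else float("nan")


tracker = EpochLossTracker()
tracker.reset()
for step_loss in [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.5)]:
    tracker.update(step_loss)
print(tracker.average())  # ~0.7
```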
@@ -362,13 +436,51 @@ class LightningModel(L.LightningModule):
             lr_scheduler_cfg.step_lr = StepLRConfig()
         elif self.lr_scheduler == "reduce_lr_on_plateau":
             lr_scheduler_cfg.reduce_lr_on_plateau = ReduceLROnPlateauConfig()
+        elif self.lr_scheduler == "cosine_annealing_warmup":
+            lr_scheduler_cfg.cosine_annealing_warmup = CosineAnnealingWarmupConfig()
+        elif self.lr_scheduler == "linear_warmup_linear_decay":
+            lr_scheduler_cfg.linear_warmup_linear_decay = (
+                LinearWarmupLinearDecayConfig()
+            )

         elif isinstance(self.lr_scheduler, dict):
             lr_scheduler_cfg = self.lr_scheduler

             for k, v in self.lr_scheduler.items():
                 if v is not None:
-                    if k == "
+                    if k == "cosine_annealing_warmup":
+                        cfg = self.lr_scheduler.cosine_annealing_warmup
+                        # Use trainer's max_epochs if not specified in config
+                        max_epochs = (
+                            cfg.max_epochs
+                            if cfg.max_epochs is not None
+                            else self.trainer.max_epochs
+                        )
+                        scheduler = LinearWarmupCosineAnnealingLR(
+                            optimizer=optimizer,
+                            warmup_epochs=cfg.warmup_epochs,
+                            max_epochs=max_epochs,
+                            warmup_start_lr=cfg.warmup_start_lr,
+                            eta_min=cfg.eta_min,
+                        )
+                        break
+                    elif k == "linear_warmup_linear_decay":
+                        cfg = self.lr_scheduler.linear_warmup_linear_decay
+                        # Use trainer's max_epochs if not specified in config
+                        max_epochs = (
+                            cfg.max_epochs
+                            if cfg.max_epochs is not None
+                            else self.trainer.max_epochs
+                        )
+                        scheduler = LinearWarmupLinearDecayLR(
+                            optimizer=optimizer,
+                            warmup_epochs=cfg.warmup_epochs,
+                            max_epochs=max_epochs,
+                            warmup_start_lr=cfg.warmup_start_lr,
+                            end_lr=cfg.end_lr,
+                        )
+                        break
+                    elif k == "step_lr":
                         scheduler = torch.optim.lr_scheduler.StepLR(
                             optimizer=optimizer,
                             step_size=self.lr_scheduler.step_lr.step_size,
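`LinearWarmupCosineAnnealingLR` and `LinearWarmupLinearDecayLR` come from the new `sleap_nn/training/schedulers.py` (+191 lines in this release), whose implementation is not shown in this hunk. The usual shape of the first schedule, a linear ramp from `warmup_start_lr` to the base LR over `warmup_epochs` followed by cosine annealing down to `eta_min` at `max_epochs`, can be approximated with a plain `LambdaLR`, as in the hedged sketch below (an approximation of the technique, not the sleap-nn classes).

```python
# Hedged sketch: approximate linear warmup + cosine annealing with LambdaLR.
# Parameter names mirror the config fields used above.
import math
import torch


def warmup_cosine_lambda(warmup_epochs, max_epochs, base_lr, warmup_start_lr, eta_min):
    def factor(epoch):
        if epoch < warmup_epochs:
            lr = warmup_start_lr + (base_lr - warmup_start_lr) * epoch / max(1, warmup_epochs)
        else:
            t = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
            lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t))
        return lr / base_lr  # LambdaLR multiplies the optimizer's base LR by this factor

    return factor


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    warmup_cosine_lambda(
        warmup_epochs=5, max_epochs=100, base_lr=1e-3, warmup_start_lr=1e-5, eta_min=1e-6
    ),
)
```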
@@ -396,7 +508,7 @@ class LightningModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": scheduler,
-                "monitor": "
+                "monitor": "val/loss",
             },
         }

@@ -493,8 +605,15 @@ class SingleInstanceLightningModule(LightningModel):
         )
         self.node_names = self.head_configs.single_instance.confmaps.part_names

-    def
-    """
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample.
+
+        Args:
+            sample: A sample dictionary from the data pipeline.
+
+        Returns:
+            VisualizationData containing image, confmaps, peaks, etc.
+        """
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -502,27 +621,41 @@ class SingleInstanceLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.single_instance_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-
-
-        ) # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-
-
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-        plot_confmaps(
-        plot_peaks(gt_instances,
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["SingleInstanceConfmapsHead"]

     def training_step(self, batch, batch_idx):
@@ -531,6 +664,7 @@ class SingleInstanceLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)

         y_preds = self.model(X)["SingleInstanceConfmapsHead"]

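`normalize_on_gpu` is imported from `sleap_nn.data.normalization` (also touched in this release, +45 -2) and is applied to the batch after it reaches the device, so the dataloader can ship unnormalized images. Its exact behavior is not visible in this diff; a typical implementation of the idea, stated here only as an assumption, is sketched below.

```python
# Assumed behavior only: scale uint8 images to float32 in [0, 1] on whatever
# device the tensor already lives on. The real function is defined in
# sleap_nn/data/normalization.py and may differ.
import torch


def normalize_on_gpu_sketch(img: torch.Tensor) -> torch.Tensor:
    if img.dtype == torch.uint8:
        return img.to(torch.float32) / 255.0
    return img.to(torch.float32)
```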
@@ -554,23 +688,24 @@ class SingleInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=
-                on_step=
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
-                logger=True,
                 sync_dist=True,
             )
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=
-            logger=True,
+            on_epoch=False,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -579,6 +714,7 @@ class SingleInstanceLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)

         y_preds = self.model(X)["SingleInstanceConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
@@ -592,26 +728,60 @@ class SingleInstanceLightningModule(LightningModel):
                 loss_scale=self.loss_scale,
             )
             val_loss = val_loss + ohkm_loss
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
-            on_step=
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )

+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
+                inference_batch = {k: v for k, v in batch.items()}
+                if inference_batch["image"].ndim == 5:
+                    inference_batch["image"] = inference_batch["image"].squeeze(1)
+                inference_output = self.single_instance_inf_layer(inference_batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original image space (inference divides by eff_scale)
+                pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Transform GT from preprocessed to original image space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,
+                        "num_instances": num_inst,
+                    }
+                )
+

 class TopDownCenteredInstanceLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance Model.
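Each collected entry pairs a prediction record with a ground-truth record keyed by `video_idx`/`frame_idx`, with coordinates already mapped back to original image space, so an epoch-end hook can compute pose metrics without re-running inference. As a rough, hypothetical illustration of how such lists could be consumed (the actual epoch-end evaluation lives in `sleap_nn/training/callbacks.py`, which this hunk does not show), one could match records by frame and take a mean localization error:

```python
# Hypothetical consumer of val_predictions / val_ground_truth (illustration only).
import numpy as np


def mean_localization_error(val_predictions, val_ground_truth):
    gt_by_frame = {
        (g["video_idx"], g["frame_idx"]): g["gt_instances"] for g in val_ground_truth
    }
    errors = []
    for p in val_predictions:
        gt = gt_by_frame.get((p["video_idx"], p["frame_idx"]))
        if gt is None or len(gt) == 0:
            continue
        n = min(len(gt), len(p["pred_peaks"]))
        # Naive pairing by order; a real evaluator would match instances properly.
        d = np.linalg.norm(p["pred_peaks"][:n] - gt[:n], axis=-1)
        errors.append(np.nanmean(d))
    return float(np.nanmean(errors)) if errors else float("nan")
```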
@@ -705,8 +875,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):

         self.node_names = self.head_configs.centered_instance.confmaps.part_names

-    def
-    """
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -714,27 +884,41 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
         output = self.instance_peaks_inf_layer(ex)
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-
-
-        ) # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instance"].cpu().numpy()
-        confmaps = (
-
-
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-        plot_confmaps(
-        plot_peaks(gt_instances,
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["CenteredInstanceConfmapsHead"]

     def training_step(self, batch, batch_idx):
@@ -743,6 +927,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             torch.squeeze(batch["instance_image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CenteredInstanceConfmapsHead"]

@@ -766,24 +951,25 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=
-                on_step=
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
-                logger=True,
                 sync_dist=True,
             )

+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=
-            logger=True,
+            on_epoch=False,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -792,6 +978,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             torch.squeeze(batch["instance_image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
@@ -805,26 +992,71 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 loss_scale=self.loss_scale,
             )
             val_loss = val_loss + ohkm_loss
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
-        self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
-            on_step=
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )

+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
+

 class CentroidLightningModule(LightningModel):
     """Lightning Module for Centroid Model.
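The coordinate bookkeeping above is the core of the top-down case: peaks come back relative to the instance crop and already divided by `eff_scale`, so they only need the crop's top-left corner (also rescaled to original image space) added back. A compact restatement of that transform with made-up numbers, purely to illustrate the arithmetic, is shown below.

```python
# Illustration of the crop -> full-image transform used above (made-up values).
import numpy as np

eff_scale = 0.5                                 # image was downscaled 2x in preprocessing
bbox_top_left_prep = np.array([64.0, 32.0])     # crop corner in preprocessed space
pred_peaks_crop = np.array([[10.0, 12.0],       # peaks relative to the crop,
                            [20.0, 18.0]])      # already in original-resolution units

bbox_top_left_orig = bbox_top_left_prep / eff_scale   # -> [128., 64.]
pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
print(pred_peaks_full)   # [[138.  76.] [148.  82.]]
```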
@@ -916,9 +1148,10 @@ class CentroidLightningModule(LightningModel):
             output_stride=self.head_configs.centroid.confmaps.output_stride,
             input_scale=1.0,
         )
+        self.node_names = ["centroid"]

-    def
-    """
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -927,26 +1160,40 @@ class CentroidLightningModule(LightningModel):
         ex["image"] = ex["image"].unsqueeze(dim=0)
         gt_centroids = ex["centroids"].cpu().numpy()
         output = self.centroid_inf_layer(ex)
+
         peaks = output["centroids"][0].cpu().numpy()
-
-
-
-
-
-
+        centroid_vals = output["centroid_vals"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
+        confmaps = output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=centroid_vals,
+            gt_instances=gt_centroids,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-        plot_confmaps(
-        plot_peaks(
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["CentroidConfmapsHead"]

     def training_step(self, batch, batch_idx):
@@ -955,18 +1202,21 @@ class CentroidLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["centroids_confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CentroidConfmapsHead"]
         loss = nn.MSELoss()(y_preds, y)
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=
-            logger=True,
+            on_epoch=False,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -975,29 +1225,74 @@ class CentroidLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["centroids_confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
-            on_step=
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )

+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # Save GT centroids before inference (inference overwrites batch["centroids"])
+            batch["gt_centroids"] = batch["centroids"].clone()
+
+            with torch.no_grad():
+                inference_output = self.centroid_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are in original image space (inference divides by eff_scale)
+                # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
+                pred_centroids = (
+                    inference_output["centroids"][i].squeeze(0).cpu().numpy()
+                )
+                pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
+
+                # Transform GT centroids from preprocessed to original image space
+                # Use "gt_centroids" since inference overwrites "centroids" with predictions
+                gt_centroids_prep = (
+                    batch["gt_centroids"][i].cpu().numpy()
+                )  # (n_samples=1, max_inst, 2)
+                gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff  # (max_inst, 2)
+                num_inst = batch["num_instances"][i].item()
+
+                # Filter to valid instances (non-NaN)
+                valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
+                pred_centroids = pred_centroids[valid_pred_mask]
+                pred_vals = pred_vals[valid_pred_mask]
+
+                gt_centroids_valid = gt_centroids_orig[:num_inst]
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_centroids.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "pred_scores": pred_vals.reshape(-1, 1),  # (n_inst, 1)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_centroids_valid.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "num_instances": num_inst,
+                    }
+                )
+

 class BottomUpLightningModule(LightningModel):
     """Lightning Module for BottomUp Model.
@@ -1090,16 +1385,20 @@ class BottomUpLightningModule(LightningModel):
         self.bottomup_inf_layer = BottomUpInferenceModel(
             torch_model=self.forward,
             paf_scorer=paf_scorer,
-            peak_threshold=0.
+            peak_threshold=0.1,  # Lower threshold for epoch-end eval during training
             input_scale=1.0,
             return_confmaps=True,
             return_pafs=True,
             cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
+            max_peaks_per_node=100,  # Prevents combinatorial explosion in early training
         )
+        self.node_names = list(self.head_configs.bottomup.confmaps.part_names)

-    def
-
+    def get_visualization_data(
+        self, sample, include_pafs: bool = False
+    ) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
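The `max_peaks_per_node=100` cap matters because PAF grouping scores candidate connections between every source and destination peak for each edge, so the work per edge grows roughly with the product of the two peak counts; with a low peak threshold early in training, an uncapped model can emit far more spurious peaks per node. A back-of-the-envelope check (the skeleton size is a made-up example):

```python
# Rough cost estimate for PAF candidate scoring (illustrative arithmetic only).
peaks_per_node = 100      # capped value from the diff above
n_edges = 12              # hypothetical skeleton with 12 edges
candidate_pairs = n_edges * peaks_per_node * peaks_per_node
print(candidate_pairs)    # 120000 candidate connections to score per frame at the cap
```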
@@ -1107,54 +1406,65 @@ class BottomUpLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.bottomup_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"][0].cpu().numpy()
-
-
-        ) # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-
-
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        pred_pafs = None
+        if include_pafs:
+            pafs = output["pred_part_affinity_fields"].cpu().numpy()[0]
+            pred_pafs = pafs  # (h, w, 2*edges)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+            pred_pafs=pred_pafs,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-        plot_confmaps(
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
         plt.xlim(plt.xlim())
         plt.ylim(plt.ylim())
-        plot_peaks(gt_instances,
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def visualize_pafs_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
-
-        ex["eff_scale"] = torch.tensor([1.0])
-        for k, v in ex.items():
-            if isinstance(v, torch.Tensor):
-                ex[k] = v.to(device=self.device)
-        ex["image"] = ex["image"].unsqueeze(dim=0)
-        output = self.bottomup_inf_layer(ex)[0]
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        ) # convert from (C, H, W) to (H, W, C)
-        pafs = output["pred_part_affinity_fields"].cpu().numpy()[0] # (h, w, 2*edges)
+        """Visualize PAF predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample, include_pafs=True)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)

+        pafs = data.pred_pafs
         pafs = pafs.reshape((pafs.shape[0], pafs.shape[1], -1, 2))
         pafs_mag = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
-        plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] /
+        plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / data.image.shape[0])
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1166,6 +1476,7 @@ class BottomUpLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_paf = batch["part_affinity_fields"]
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         pafs = preds["PartAffinityFieldsHead"]
         confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1198,13 +1509,29 @@ class BottomUpLightningModule(LightningModel):
             "PartAffinityFieldsHead": pafs_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            sync_dist=True,
+        )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
+        self.log(
+            "train/confmaps_loss",
+            confmap_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        self.log(
+            "train/paf_loss",
+            pafs_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
         return loss
@@ -1214,6 +1541,7 @@ class BottomUpLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_paf = batch["part_affinity_fields"]
+        X = normalize_on_gpu(X)

         preds = self.model(X)
         pafs = preds["PartAffinityFieldsHead"]
@@ -1248,25 +1576,75 @@ class BottomUpLightningModule(LightningModel):
         }

         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "
-
+            "val/loss",
+            val_loss,
             prog_bar=True,
-            on_step=
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
-
-
-            on_step=True,
+            "val/confmaps_loss",
+            confmap_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        self.log(
+            "val/paf_loss",
+            pafs_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )


 class BottomUpMultiClassLightningModule(LightningModel):
@@ -1361,9 +1739,14 @@ class BottomUpMultiClassLightningModule(LightningModel):
             cms_output_stride=self.head_configs.multi_class_bottomup.confmaps.output_stride,
             class_maps_output_stride=self.head_configs.multi_class_bottomup.class_maps.output_stride,
         )
+        self.node_names = list(
+            self.head_configs.multi_class_bottomup.confmaps.part_names
+        )

-    def
-
+    def get_visualization_data(
+        self, sample, include_class_maps: bool = False
+    ) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1371,54 +1754,65 @@ class BottomUpMultiClassLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.bottomup_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"][0].cpu().numpy()
-
-
-        ) # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-
-
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        pred_class_maps = None
+        if include_class_maps:
+            pred_class_maps = (
+                output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
+            )
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+            pred_class_maps=pred_class_maps,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-        plot_confmaps(
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
         plt.xlim(plt.xlim())
         plt.ylim(plt.ylim())
-        plot_peaks(gt_instances,
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def visualize_class_maps_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
-
-        ex["eff_scale"] = torch.tensor([1.0])
-        for k, v in ex.items():
-            if isinstance(v, torch.Tensor):
-                ex[k] = v.to(device=self.device)
-        ex["image"] = ex["image"].unsqueeze(dim=0)
-        output = self.bottomup_inf_layer(ex)[0]
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        ) # convert from (C, H, W) to (H, W, C)
-        classmaps = (
-            output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
-        ) # (n_classes, h, w)
+        """Visualize class map predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample, include_class_maps=True)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-
-
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(
+            data.pred_class_maps,
+            output_scale=data.pred_class_maps.shape[0] / data.image.shape[0],
+        )
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1430,6 +1824,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classmap = torch.squeeze(batch["class_maps"], dim=1)
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         classmaps = preds["ClassMapsHead"]
         confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1453,15 +1848,84 @@ class BottomUpMultiClassLightningModule(LightningModel):
             "ClassMapsHead": classmaps_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            sync_dist=True,
+        )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
+        self.log(
+            "train/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/classmap_loss",
+            classmaps_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1

+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "train/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    sync_dist=True,
+                )
+
         return loss

     def validation_step(self, batch, batch_idx):
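The accuracy metric above samples the predicted and ground-truth class maps at each labeled keypoint and counts argmax agreement, looping in Python over instances and nodes. The same quantity can be computed in a vectorized way, as in the sketch below; this is an equivalent reformulation for illustration, with shapes taken from the comments in the diff, not the code being shipped.

```python
# Vectorized sketch of "classification accuracy at GT keypoint locations".
# instances: (B, max_inst, n_nodes, 2) as (x, y); classmaps / gt_classmaps:
# (B, n_classes, H, W); stride: class-map output stride.
import torch


def class_accuracy_at_keypoints(instances, num_instances, classmaps, gt_classmaps, stride):
    correct, total = 0, 0
    for b in range(instances.shape[0]):
        kps = instances[b, : int(num_instances[b])].reshape(-1, 2)   # (K, 2)
        valid = ~torch.isnan(kps).any(dim=1)
        kps = kps[valid]
        if kps.numel() == 0:
            continue
        x = (kps[:, 0] / stride).long().clamp(0, classmaps.shape[-1] - 1)
        y = (kps[:, 1] / stride).long().clamp(0, classmaps.shape[-2] - 1)
        pred = classmaps[b, :, y, x].argmax(dim=0)                    # (K,)
        gt = gt_classmaps[b, :, y, x].argmax(dim=0)
        correct += (pred == gt).sum().item()
        total += pred.numel()
    return correct / total if total else float("nan")
```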
@@ -1469,6 +1933,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
|
|
|
1469
1933
|
X = torch.squeeze(batch["image"], dim=1)
|
|
1470
1934
|
y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
|
|
1471
1935
|
y_classmap = torch.squeeze(batch["class_maps"], dim=1)
|
|
1936
|
+
X = normalize_on_gpu(X)
|
|
1472
1937
|
|
|
1473
1938
|
preds = self.model(X)
|
|
1474
1939
|
classmaps = preds["ClassMapsHead"]
|
|
@@ -1494,26 +1959,128 @@ class BottomUpMultiClassLightningModule(LightningModel):
         }

         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "
-
+            "val/loss",
+            val_loss,
             prog_bar=True,
-            on_step=
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
-
-
-
+            "val/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "val/classmap_loss",
+            classmaps_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )

+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "val/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    sync_dist=True,
+                )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+

 class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance ID Model.
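The class-accuracy check added above samples the predicted and ground-truth class maps one keypoint at a time with nested Python loops. For reference only, here is a vectorized sketch of the same sampling, assuming the shapes noted in the diff comments ((batch, max_inst, n_nodes, 2) instances and (batch, n_classes, H, W) maps). This is an illustration, not code from the package, and the function name is hypothetical.

import torch


def classmap_accuracy_sketch(
    instances: torch.Tensor,       # (batch, max_inst, n_nodes, 2) in input-image space
    num_instances: torch.Tensor,   # (batch,) valid-instance counts
    pred_classmaps: torch.Tensor,  # (batch, n_classes, H, W)
    gt_classmaps: torch.Tensor,    # (batch, n_classes, H, W)
    stride: int,
) -> torch.Tensor:
    """Fraction of visible GT keypoints whose argmax class matches between the two maps."""
    batch, max_inst, n_nodes, _ = instances.shape
    h, w = pred_classmaps.shape[-2:]

    # Mask out padded instances and NaN (invisible) keypoints.
    inst_mask = (
        torch.arange(max_inst, device=instances.device)[None, :] < num_instances[:, None]
    )  # (batch, max_inst)
    visible = ~torch.isnan(instances).any(dim=-1) & inst_mask[..., None]  # (batch, max_inst, n_nodes)

    # Keypoints -> class-map grid indices.
    xy = torch.nan_to_num(instances) / stride
    x_cm = xy[..., 0].long().clamp(0, w - 1)
    y_cm = xy[..., 1].long().clamp(0, h - 1)

    # Sample both maps at every keypoint with advanced indexing; result is (batch, max_inst, n_nodes, n_classes).
    b_idx = torch.arange(batch, device=instances.device)[:, None, None]
    pred_cls = pred_classmaps[b_idx, :, y_cm, x_cm].argmax(dim=-1)
    gt_cls = gt_classmaps[b_idx, :, y_cm, x_cm].argmax(dim=-1)

    matches = (pred_cls == gt_cls) & visible
    return matches.sum() / visible.sum().clamp(min=1)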
@@ -1607,8 +2174,8 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):

         self.node_names = self.head_configs.multi_class_topdown.confmaps.part_names

-    def
-        """
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1616,27 +2183,41 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
         output = self.instance_peaks_inf_layer(ex)
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-
-
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instance"].cpu().numpy()
-        confmaps = (
-
-
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(
-        plot_confmaps(
-        plot_peaks(gt_instances,
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "CenteredInstanceConfmapsHead": output["CenteredInstanceConfmapsHead"],
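The refactor above routes plotting through a `VisualizationData` container whose definition is elsewhere in the module and not part of this hunk. As an assumption about its shape rather than the actual definition, a minimal dataclass with the fields passed by `get_visualization_data` might look like the sketch below (the class name and field types are inferred, not taken from the package).

from dataclasses import dataclass, field
from typing import List

import numpy as np


@dataclass
class VisualizationDataSketch:
    """Illustrative container mirroring the fields constructed above (hypothetical definition)."""

    image: np.ndarray             # (H, W, C) instance crop or frame
    pred_confmaps: np.ndarray     # (H', W', n_nodes) predicted confidence maps
    pred_peaks: np.ndarray        # predicted keypoints, e.g. (n_inst, n_nodes, 2)
    pred_peak_values: np.ndarray  # peak confidence scores
    gt_instances: np.ndarray      # ground-truth keypoints
    node_names: List[str] = field(default_factory=list)
    output_scale: float = 1.0     # confmap resolution relative to the image
    is_paired: bool = True        # whether GT and predictions correspond one-to-one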
@@ -1648,6 +2229,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["instance_image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classvector = batch["class_vectors"]
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         classvector = preds["ClassVectorsHead"]
         confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -1679,22 +2261,50 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=
-                on_step=
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
-                logger=True,
                 sync_dist=True,
             )

+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            sync_dist=True,
+        )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
+        self.log(
+            "train/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/classvector_loss",
+            classvector_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+        self.log(
+            "train/class_accuracy",
+            class_accuracy,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         return loss
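The hunk above logs the raw `loss` every step and defers the epoch average to `self._accumulate_loss(...)`, which is defined elsewhere in this module and not shown here. As an illustration of that step-vs-epoch split only (the class, attribute names, and the logged key are assumptions, not the package's actual implementation):

import lightning as L
import torch


class EpochLossMixinSketch(L.LightningModule):
    """Hypothetical sketch of accumulating step losses and logging their epoch mean."""

    def __init__(self):
        super().__init__()
        self._step_losses: list[torch.Tensor] = []

    def _accumulate_loss(self, loss: torch.Tensor) -> None:
        # Detach so the stored values do not keep the autograd graph alive.
        self._step_losses.append(loss.detach())

    def on_train_epoch_end(self) -> None:
        if self._step_losses:
            epoch_loss = torch.stack(self._step_losses).mean()
            self.log("train/loss", epoch_loss, sync_dist=True)
            self._step_losses.clear()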
@@ -1704,6 +2314,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["instance_image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classvector = batch["class_vectors"]
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         classvector = preds["ClassVectorsHead"]
         confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -1727,23 +2338,94 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             "ClassVectorsHead": classvector_loss,
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "
-
+            "val/loss",
+            val_loss,
             prog_bar=True,
-            on_step=
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
-
-
-            on_step=True,
+            "val/confmaps_loss",
+            confmap_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        self.log(
+            "val/classvector_loss",
+            classvector_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "val/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                sync_dist=True,
+            )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )