rslearn 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/train/dataset.py CHANGED
@@ -496,53 +496,6 @@ class SplitConfig:
             overlap_ratio: deprecated, use overlap_pixels instead
             load_all_patches: deprecated, use load_all_crops instead
         """
-        # Handle deprecated load_all_patches parameter
-        if load_all_patches is not None:
-            warnings.warn(
-                "load_all_patches is deprecated, use load_all_crops instead",
-                FutureWarning,
-                stacklevel=2,
-            )
-            if load_all_crops is not None:
-                raise ValueError(
-                    "Cannot specify both load_all_patches and load_all_crops"
-                )
-            load_all_crops = load_all_patches
-        # Handle deprecated patch_size parameter
-        if patch_size is not None:
-            warnings.warn(
-                "patch_size is deprecated, use crop_size instead",
-                FutureWarning,
-                stacklevel=2,
-            )
-            if crop_size is not None:
-                raise ValueError("Cannot specify both patch_size and crop_size")
-            crop_size = patch_size
-
-        # Normalize crop_size to tuple[int, int] | None
-        self.crop_size: tuple[int, int] | None = None
-        if crop_size is not None:
-            if isinstance(crop_size, int):
-                self.crop_size = (crop_size, crop_size)
-            else:
-                self.crop_size = crop_size
-
-        # Handle deprecated overlap_ratio parameter
-        if overlap_ratio is not None:
-            warnings.warn(
-                "overlap_ratio is deprecated, use overlap_pixels instead",
-                FutureWarning,
-                stacklevel=2,
-            )
-            if overlap_pixels is not None:
-                raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
-            if self.crop_size is None:
-                raise ValueError("overlap_ratio requires crop_size to be set")
-            overlap_pixels = round(self.crop_size[0] * overlap_ratio)
-
-        if overlap_pixels is not None and overlap_pixels < 0:
-            raise ValueError("overlap_pixels must be non-negative")
-
         self.groups = groups
         self.names = names
         self.tags = tags
@@ -555,13 +508,22 @@ class SplitConfig:
             output_layer_name_skip_inference_if_exists
         )

-        # Note that load_all_crops is handled by the RslearnDataModule rather than the
-        # ModelDataset.
-        self.load_all_crops = load_all_crops
-        self.overlap_pixels = overlap_pixels
+        # These have deprecated equivalents -- we store both raw values since we don't
+        # have a complete picture until the final merged SplitConfig is computed. We
+        # raise deprecation warnings in merge_and_validate and we disambiguate them in
+        # get_ functions (so the variables should never be accessed directly).
+        self._crop_size = crop_size
+        self._patch_size = patch_size
+        self._overlap_pixels = overlap_pixels
+        self._overlap_ratio = overlap_ratio
+        self._load_all_crops = load_all_crops
+        self._load_all_patches = load_all_patches

-    def update(self, other: "SplitConfig") -> "SplitConfig":
-        """Override settings in this SplitConfig with those in another.
+    def _merge(self, other: "SplitConfig") -> "SplitConfig":
+        """Merge settings from another SplitConfig into this one.
+
+        Args:
+            other: the config to merge in (its non-None values override self's)

         Returns:
             the resulting SplitConfig combining the settings.
@@ -574,9 +536,12 @@ class SplitConfig:
             num_patches=self.num_patches,
             transforms=self.transforms,
             sampler=self.sampler,
-            crop_size=self.crop_size,
-            overlap_pixels=self.overlap_pixels,
-            load_all_crops=self.load_all_crops,
+            crop_size=self._crop_size,
+            patch_size=self._patch_size,
+            overlap_pixels=self._overlap_pixels,
+            overlap_ratio=self._overlap_ratio,
+            load_all_crops=self._load_all_crops,
+            load_all_patches=self._load_all_patches,
             skip_targets=self.skip_targets,
             output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
@@ -594,12 +559,18 @@ class SplitConfig:
            result.transforms = other.transforms
         if other.sampler:
            result.sampler = other.sampler
-        if other.crop_size:
-            result.crop_size = other.crop_size
-        if other.overlap_pixels is not None:
-            result.overlap_pixels = other.overlap_pixels
-        if other.load_all_crops is not None:
-            result.load_all_crops = other.load_all_crops
+        if other._crop_size is not None:
+            result._crop_size = other._crop_size
+        if other._patch_size is not None:
+            result._patch_size = other._patch_size
+        if other._overlap_pixels is not None:
+            result._overlap_pixels = other._overlap_pixels
+        if other._overlap_ratio is not None:
+            result._overlap_ratio = other._overlap_ratio
+        if other._load_all_crops is not None:
+            result._load_all_crops = other._load_all_crops
+        if other._load_all_patches is not None:
+            result._load_all_patches = other._load_all_patches
         if other.skip_targets is not None:
            result.skip_targets = other.skip_targets
         if other.output_layer_name_skip_inference_if_exists is not None:
@@ -608,17 +579,90 @@ class SplitConfig:
         )
         return result

+    @staticmethod
+    def merge_and_validate(configs: list["SplitConfig"]) -> "SplitConfig":
+        """Merge a list of SplitConfigs and validate the result.
+
+        Args:
+            configs: list of SplitConfig to merge. Later configs override earlier ones.
+
+        Returns:
+            the merged and validated SplitConfig.
+        """
+        if not configs:
+            return SplitConfig()
+
+        result = configs[0]
+        for config in configs[1:]:
+            result = result._merge(config)
+
+        # Emit deprecation warnings
+        if result._patch_size is not None:
+            warnings.warn(
+                "patch_size is deprecated, use crop_size instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if result._overlap_ratio is not None:
+            warnings.warn(
+                "overlap_ratio is deprecated, use overlap_pixels instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if result._load_all_patches is not None:
+            warnings.warn(
+                "load_all_patches is deprecated, use load_all_crops instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+
+        # Check for conflicting parameters
+        if result._crop_size is not None and result._patch_size is not None:
+            raise ValueError("Cannot specify both crop_size and patch_size")
+        if result._overlap_pixels is not None and result._overlap_ratio is not None:
+            raise ValueError("Cannot specify both overlap_pixels and overlap_ratio")
+        if result._load_all_crops is not None and result._load_all_patches is not None:
+            raise ValueError("Cannot specify both load_all_crops and load_all_patches")
+
+        # Validate overlap_pixels is non-negative
+        if result._overlap_pixels is not None and result._overlap_pixels < 0:
+            raise ValueError("overlap_pixels must be non-negative")
+
+        # overlap_pixels requires load_all_crops.
+        if result.get_overlap_pixels() > 0 and not result.get_load_all_crops():
+            raise ValueError(
+                "overlap_pixels requires load_all_crops to be True since overlap is only used during sliding window inference"
+            )
+
+        return result
+
     def get_crop_size(self) -> tuple[int, int] | None:
-        """Get crop size as tuple."""
-        return self.crop_size
+        """Get crop size as tuple, handling deprecated patch_size."""
+        size = self._crop_size if self._crop_size is not None else self._patch_size
+        if size is None:
+            return None
+        if isinstance(size, int):
+            return (size, size)
+        return size

     def get_overlap_pixels(self) -> int:
-        """Get the overlap pixels (default 0)."""
-        return self.overlap_pixels if self.overlap_pixels is not None else 0
+        """Get the overlap pixels (default 0), handling deprecated overlap_ratio."""
+        if self._overlap_pixels is not None:
+            return self._overlap_pixels
+        if self._overlap_ratio is not None:
+            crop_size = self.get_crop_size()
+            if crop_size is None:
+                raise ValueError("overlap_ratio requires crop_size to be set")
+            return round(crop_size[0] * self._overlap_ratio)
+        return 0

     def get_load_all_crops(self) -> bool:
-        """Returns whether loading all patches is enabled (default False)."""
-        return True if self.load_all_crops is True else False
+        """Returns whether loading all crops is enabled (default False)."""
+        if self._load_all_crops is not None:
+            return self._load_all_crops
+        if self._load_all_patches is not None:
+            return self._load_all_patches
+        return False

     def get_skip_targets(self) -> bool:
         """Returns whether skip_targets is enabled (default False)."""
@@ -697,7 +741,7 @@ class ModelDataset(torch.utils.data.Dataset):
         task: Task,
         workers: int,
         name: str | None = None,
-        fix_patch_pick: bool = False,
+        fix_crop_pick: bool = False,
         index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Instantiate a new ModelDataset.
@@ -709,7 +753,7 @@ class ModelDataset(torch.utils.data.Dataset):
             task: the task to train on
             workers: number of workers to use for initializing the dataset
             name: name of the dataset
-            fix_patch_pick: if True, fix the patch pick to be the same every time
+            fix_crop_pick: if True, fix the crop pick to be the same every time
                 for a given window. Useful for testing (default: False)
             index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
@@ -718,14 +762,14 @@ class ModelDataset(torch.utils.data.Dataset):
         self.inputs = inputs
         self.task = task
         self.name = name
-        self.fix_patch_pick = fix_patch_pick
+        self.fix_crop_pick = fix_crop_pick
         if split_config.transforms:
             self.transforms = Sequential(*split_config.transforms)
         else:
             self.transforms = rslearn.train.transforms.transform.Identity()

         # Get normalized crop size from the SplitConfig.
-        # But if load all patches is enabled, this is handled by AllCropsDataset, so
+        # But if load_all_crops is enabled, this is handled by AllCropsDataset, so
         # here we instead load the entire windows.
         if split_config.get_load_all_crops():
             self.crop_size = None
@@ -952,7 +996,7 @@ class ModelDataset(torch.utils.data.Dataset):
         """Get a list of examples in the dataset.

         If load_all_crops is False, this is a list of Windows. Otherwise, this is a
-        list of (window, crop_bounds, (crop_idx, # patches)) tuples.
+        list of (window, crop_bounds, (crop_idx, # crops)) tuples.
         """
         if self.dataset_examples is None:
             logger.debug(
@@ -985,7 +1029,7 @@ class ModelDataset(torch.utils.data.Dataset):
         """
         dataset_examples = self.get_dataset_examples()
         example = dataset_examples[idx]
-        rng = random.Random(idx if self.fix_patch_pick else None)
+        rng = random.Random(idx if self.fix_crop_pick else None)

         # Select bounds to read.
         if self.crop_size:
rslearn/train/lightning_module.py CHANGED
@@ -6,12 +6,14 @@ from typing import Any

 import lightning as L
 import torch
+import wandb
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
 from PIL import Image
 from upath import UPath

 from rslearn.log_utils import get_logger

+from .metrics import NonScalarMetricOutput
 from .model_context import ModelContext, ModelOutput
 from .optimizer import AdamW, OptimizerFactory
 from .scheduler import PlateauScheduler, SchedulerFactory
@@ -210,15 +212,53 @@ class RslearnLightningModule(L.LightningModule):
             # Fail silently for single-dataset case, which is okay
             pass

+    def _log_non_scalar_metric(self, name: str, value: NonScalarMetricOutput) -> None:
+        """Log a non-scalar metric to wandb.
+
+        Args:
+            name: the metric name (e.g., "val_confusion_matrix")
+            value: the non-scalar metric output
+        """
+        # The non-scalar metrics are logged directly without Lightning,
+        # so we need to skip logging during sanity check.
+        if self.trainer.sanity_checking:
+            return
+
+        # Wandb is required for logging non-scalar metrics.
+        if not wandb.run:
+            logger.warning(
+                f"Weights & Biases is not initialized, skipping logging of {name}"
+            )
+            return
+
+        value.log_to_wandb(name)
+
     def on_validation_epoch_end(self) -> None:
         """Compute and log validation metrics at epoch end.

         We manually compute and log metrics here (instead of passing the MetricCollection
         to log_dict) because MetricCollection.compute() properly flattens dict-returning
         metrics, while log_dict expects each metric to return a scalar tensor.
+
+        Non-scalar metrics (like confusion matrices) are logged separately using
+        logger-specific APIs.
         """
         metrics = self.val_metrics.compute()
-        self.log_dict(metrics)
+
+        # Separate scalar and non-scalar metrics
+        scalar_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, NonScalarMetricOutput):
+                self._log_non_scalar_metric(k, v)
+            elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
+                raise ValueError(
+                    f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
+                    "Wrap it in a NonScalarMetricOutput subclass."
+                )
+            else:
+                scalar_metrics[k] = v
+
+        self.log_dict(scalar_metrics)
         self.val_metrics.reset()

     def on_test_epoch_end(self) -> None:
@@ -227,14 +267,30 @@ class RslearnLightningModule(L.LightningModule):
         We manually compute and log metrics here (instead of passing the MetricCollection
         to log_dict) because MetricCollection.compute() properly flattens dict-returning
         metrics, while log_dict expects each metric to return a scalar tensor.
+
+        Non-scalar metrics (like confusion matrices) are logged separately.
         """
         metrics = self.test_metrics.compute()
-        self.log_dict(metrics)
+
+        # Separate scalar and non-scalar metrics
+        scalar_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, NonScalarMetricOutput):
+                self._log_non_scalar_metric(k, v)
+            elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
+                raise ValueError(
+                    f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
+                    "Wrap it in a NonScalarMetricOutput subclass."
+                )
+            else:
+                scalar_metrics[k] = v
+
+        self.log_dict(scalar_metrics)
         self.test_metrics.reset()

         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics_dict = {k: v.item() for k, v in metrics.items()}
+                metrics_dict = {k: v.item() for k, v in scalar_metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
             logger.info(f"Saved metrics to {self.metrics_file}")
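Editor's note: for outputs other than confusion matrices, the contract is just `log_to_wandb`. A minimal sketch of a custom output type using the `NonScalarMetricOutput` base added in the new metrics module below (the `HistogramOutput` class here is hypothetical, not part of this release):

```python
from dataclasses import dataclass

import torch
import wandb

from rslearn.train.metrics import NonScalarMetricOutput


@dataclass
class HistogramOutput(NonScalarMetricOutput):
    """Hypothetical non-scalar output: a histogram of per-sample scores."""

    values: torch.Tensor  # 1D tensor of scores

    def log_to_wandb(self, name: str) -> None:
        # wandb.Histogram accepts a plain sequence of numbers.
        wandb.log({name: wandb.Histogram(self.values.detach().cpu().tolist())})
```

A metric whose compute() returns such an object is routed through _log_non_scalar_metric above instead of log_dict.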
 
rslearn/train/metrics.py ADDED
@@ -0,0 +1,162 @@
+"""Metric output classes for non-scalar metrics."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+import wandb
+from torchmetrics import Metric
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class NonScalarMetricOutput(ABC):
+    """Base class for non-scalar metric outputs that need special logging.
+
+    Subclasses should implement the log_to_wandb method to define how the metric
+    should be logged (only supports logging to Weights & Biases).
+    """
+
+    @abstractmethod
+    def log_to_wandb(self, name: str) -> None:
+        """Log this metric to wandb.
+
+        Args:
+            name: the metric name
+        """
+        pass
+
+
+@dataclass
+class ConfusionMatrixOutput(NonScalarMetricOutput):
+    """Confusion matrix metric output.
+
+    Args:
+        confusion_matrix: confusion matrix of shape (num_classes, num_classes)
+            where cm[i, j] is the count of samples with true label i and predicted
+            label j.
+        class_names: optional list of class names for axis labels
+    """
+
+    confusion_matrix: torch.Tensor
+    class_names: list[str] | None = None
+
+    def _expand_confusion_matrix(self) -> tuple[list[int], list[int]]:
+        """Expand confusion matrix to (preds, labels) pairs for wandb.
+
+        Returns:
+            Tuple of (preds, labels) as lists of integers.
+        """
+        cm = self.confusion_matrix.detach().cpu()
+
+        # Handle extra dimensions from distributed reduction
+        if cm.dim() > 2:
+            cm = cm.sum(dim=0)
+
+        total = cm.sum().item()
+        if total == 0:
+            return [], []
+
+        preds = []
+        labels = []
+        for true_label in range(cm.shape[0]):
+            for pred_label in range(cm.shape[1]):
+                count = cm[true_label, pred_label].item()
+                if count > 0:
+                    preds.extend([pred_label] * int(count))
+                    labels.extend([true_label] * int(count))
+
+        return preds, labels
+
+    def log_to_wandb(self, name: str) -> None:
+        """Log confusion matrix to wandb.
+
+        Args:
+            name: the metric name (e.g., "val_confusion_matrix")
+        """
+        preds, labels = self._expand_confusion_matrix()
+
+        if len(preds) == 0:
+            logger.warning(f"No samples to log for {name}")
+            return
+
+        num_classes = self.confusion_matrix.shape[0]
+        if self.class_names is None:
+            class_names = [str(i) for i in range(num_classes)]
+        else:
+            class_names = self.class_names
+
+        wandb.log(
+            {
+                name: wandb.plot.confusion_matrix(
+                    preds=preds,
+                    y_true=labels,
+                    class_names=class_names,
+                    title=name,
+                ),
+            },
+        )
+
+
+class ConfusionMatrixMetric(Metric):
+    """Confusion matrix metric that works on flattened inputs.
+
+    Expects preds of shape (N, C) and labels of shape (N,).
+    Should be wrapped by ClassificationMetric or SegmentationMetric
+    which handle the task-specific preprocessing.
+
+    Args:
+        num_classes: number of classes
+        class_names: optional list of class names for labeling
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        class_names: list[str] | None = None,
+    ):
+        """Initialize a new ConfusionMatrixMetric.
+
+        Args:
+            num_classes: number of classes
+            class_names: optional list of class names for labeling
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.class_names = class_names
+        self.add_state(
+            "confusion_matrix",
+            default=torch.zeros(num_classes, num_classes, dtype=torch.long),
+            dist_reduce_fx="sum",
+        )
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Args:
+            preds: predictions of shape (N, C) - probabilities
+            labels: ground truth of shape (N,) - class indices
+        """
+        if len(preds) == 0:
+            return
+
+        pred_classes = preds.argmax(dim=1)  # (N,)
+
+        for true_label in range(self.num_classes):
+            for pred_label in range(self.num_classes):
+                count = ((labels == true_label) & (pred_classes == pred_label)).sum()
+                self.confusion_matrix[true_label, pred_label] += count
+
+    def compute(self) -> ConfusionMatrixOutput:
+        """Returns the confusion matrix wrapped in ConfusionMatrixOutput."""
+        return ConfusionMatrixOutput(
+            confusion_matrix=self.confusion_matrix,
+            class_names=self.class_names,
+        )
+
+    def reset(self) -> None:
+        """Reset metric."""
+        super().reset()
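Editor's note: in isolation, the new metric accumulates counts in update() and returns the wrapper from compute(). A quick sketch with hypothetical values:

```python
import torch

from rslearn.train.metrics import ConfusionMatrixMetric

metric = ConfusionMatrixMetric(num_classes=2, class_names=["water", "land"])
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # (N, C) probabilities
labels = torch.tensor([0, 1, 1])  # (N,) class indices

metric.update(preds, labels)
output = metric.compute()  # ConfusionMatrixOutput, not a scalar tensor
print(output.confusion_matrix)
# tensor([[1, 0],
#         [1, 1]])  # row = true label, column = predicted label
```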
rslearn/train/tasks/classification.py CHANGED
@@ -16,6 +16,7 @@ from torchmetrics.classification import (
 )

 from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.metrics import ConfusionMatrixMetric
 from rslearn.train.model_context import (
     ModelContext,
     ModelOutput,
@@ -44,6 +45,7 @@ class ClassificationTask(BasicTask):
         f1_metric_kwargs: dict[str, Any] = {},
         positive_class: str | None = None,
         positive_class_threshold: float = 0.5,
+        enable_confusion_matrix: bool = False,
         **kwargs: Any,
     ):
         """Initialize a new ClassificationTask.
@@ -69,6 +71,8 @@ class ClassificationTask(BasicTask):
             positive_class: positive class name.
             positive_class_threshold: threshold for classifying the positive class in
                 binary classification (default 0.5).
+            enable_confusion_matrix: whether to compute the confusion matrix (default
+                False). If True, wandb must be initialized for logging.
             kwargs: other arguments to pass to BasicTask
         """
         super().__init__(**kwargs)
@@ -84,6 +88,7 @@ class ClassificationTask(BasicTask):
         self.f1_metric_kwargs = f1_metric_kwargs
         self.positive_class = positive_class
         self.positive_class_threshold = positive_class_threshold
+        self.enable_confusion_matrix = enable_confusion_matrix

         if self.positive_class_threshold != 0.5:
             # Must be binary classification
@@ -278,6 +283,14 @@ class ClassificationTask(BasicTask):
         )
         metrics["f1"] = ClassificationMetric(MulticlassF1Score(**kwargs))

+        if self.enable_confusion_matrix:
+            metrics["confusion_matrix"] = ClassificationMetric(
+                ConfusionMatrixMetric(
+                    num_classes=len(self.classes),
+                    class_names=self.classes,
+                ),
+            )
+
         return MetricCollection(metrics)
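Editor's note: enabling the metric from task construction might look like the following. The `property_name` and `classes` arguments mirror the existing ClassificationTask signature, and the `get_metrics` entry point is assumed from context, so treat this as a sketch rather than a verbatim recipe; wandb must be initialized for the matrix to actually be logged at epoch end.

```python
from rslearn.train.tasks.classification import ClassificationTask

task = ClassificationTask(
    property_name="category",
    classes=["water", "land", "urban"],
    enable_confusion_matrix=True,
)
metrics = task.get_metrics()  # now includes a "confusion_matrix" entry
```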
 
@@ -149,22 +149,28 @@ class PerPixelRegressionHead(Predictor):
     """Head for per-pixel regression task."""

     def __init__(
-        self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
+        self,
+        loss_mode: Literal["mse", "l1", "huber"] = "mse",
+        use_sigmoid: bool = False,
+        huber_delta: float = 1.0,
     ):
-        """Initialize a new RegressionHead.
+        """Initialize a new PerPixelRegressionHead.

         Args:
-            loss_mode: the loss function to use, either "mse" (default) or "l1".
+            loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
             use_sigmoid: whether to apply a sigmoid activation on the output. This
                 requires targets to be between 0-1.
+            huber_delta: delta parameter for Huber loss (only used when
+                loss_mode="huber").
         """
         super().__init__()

-        if loss_mode not in ["mse", "l1"]:
-            raise ValueError("invalid loss mode")
+        if loss_mode not in ["mse", "l1", "huber"]:
+            raise ValueError(f"invalid loss mode {loss_mode}")

         self.loss_mode = loss_mode
         self.use_sigmoid = use_sigmoid
+        self.huber_delta = huber_delta

     def forward(
         self,
@@ -217,8 +223,15 @@ class PerPixelRegressionHead(Predictor):
             scores = torch.square(outputs - labels)
         elif self.loss_mode == "l1":
             scores = torch.abs(outputs - labels)
+        elif self.loss_mode == "huber":
+            scores = torch.nn.functional.huber_loss(
+                outputs,
+                labels,
+                reduction="none",
+                delta=self.huber_delta,
+            )
         else:
-            assert False
+            raise ValueError(f"unknown loss mode {self.loss_mode}")

         # Compute average but only over valid pixels.
         mask_total = mask.sum()
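Editor's note: for reference, torch.nn.functional.huber_loss with reduction="none" is quadratic for errors up to delta and linear beyond it, which is what makes the new mode more robust to outlier pixels than MSE. A quick check:

```python
import torch
import torch.nn.functional as F

outputs = torch.tensor([0.0, 0.0])
labels = torch.tensor([0.5, 3.0])  # one small error, one large
delta = 1.0

per_element = F.huber_loss(outputs, labels, reduction="none", delta=delta)
# |0.5| <= delta -> 0.5 * 0.5**2            = 0.125
# |3.0| >  delta -> delta * (3.0 - delta/2) = 2.5
print(per_element)  # tensor([0.1250, 2.5000])
```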
@@ -196,18 +196,24 @@ class RegressionHead(Predictor):
     """Head for regression task."""

     def __init__(
-        self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
+        self,
+        loss_mode: Literal["mse", "l1", "huber"] = "mse",
+        use_sigmoid: bool = False,
+        huber_delta: float = 1.0,
     ):
         """Initialize a new RegressionHead.

         Args:
-            loss_mode: the loss function to use, either "mse" (default) or "l1".
+            loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
             use_sigmoid: whether to apply a sigmoid activation on the output. This
                 requires targets to be between 0-1.
+            huber_delta: delta parameter for Huber loss (only used when
+                loss_mode="huber").
         """
         super().__init__()
         self.loss_mode = loss_mode
         self.use_sigmoid = use_sigmoid
+        self.huber_delta = huber_delta

     def forward(
         self,
@@ -251,6 +257,16 @@ class RegressionHead(Predictor):
             losses["regress"] = torch.mean(torch.square(outputs - labels) * mask)
         elif self.loss_mode == "l1":
             losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
+        elif self.loss_mode == "huber":
+            losses["regress"] = torch.mean(
+                torch.nn.functional.huber_loss(
+                    outputs,
+                    labels,
+                    reduction="none",
+                    delta=self.huber_delta,
+                )
+                * mask
+            )
         else:
             raise ValueError(f"unknown loss mode {self.loss_mode}")
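Editor's note: configuring the head is then a one-liner. The import path below is assumed from rslearn's usual layout (these hunks omit their file headers), so treat it as a sketch; a smaller huber_delta behaves more like L1, a larger one more like MSE.

```python
# Hypothetical import path; the hunks above do not name the file.
from rslearn.train.tasks.regression import RegressionHead

head = RegressionHead(loss_mode="huber", huber_delta=0.5)
```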