rslearn 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/dataset/storage/file.py +16 -12
  30. rslearn/models/concatenate_features.py +6 -1
  31. rslearn/tile_stores/default.py +4 -2
  32. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  33. rslearn/train/data_module.py +36 -33
  34. rslearn/train/dataset.py +159 -68
  35. rslearn/train/lightning_module.py +60 -4
  36. rslearn/train/metrics.py +162 -0
  37. rslearn/train/model_context.py +3 -3
  38. rslearn/train/prediction_writer.py +69 -41
  39. rslearn/train/tasks/classification.py +14 -1
  40. rslearn/train/tasks/detection.py +5 -5
  41. rslearn/train/tasks/per_pixel_regression.py +19 -6
  42. rslearn/train/tasks/regression.py +19 -3
  43. rslearn/train/tasks/segmentation.py +17 -0
  44. rslearn/utils/__init__.py +2 -0
  45. rslearn/utils/fsspec.py +51 -1
  46. rslearn/utils/geometry.py +21 -0
  47. rslearn/utils/m2m_api.py +251 -0
  48. rslearn/utils/retry_session.py +43 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
  51. rslearn/data_sources/earthdata_srtm.py +0 -282
  52. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
  53. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
  54. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
  55. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
  56. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py CHANGED
@@ -8,6 +8,7 @@ import random
 import tempfile
 import time
 import uuid
+import warnings
 from datetime import datetime
 from enum import StrEnum
 from typing import Any
@@ -456,11 +457,15 @@ class SplitConfig:
         num_patches: int | None = None,
         transforms: list[torch.nn.Module] | None = None,
         sampler: SamplerFactory | None = None,
+        crop_size: int | tuple[int, int] | None = None,
+        overlap_pixels: int | None = None,
+        load_all_crops: bool | None = None,
+        skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
+        # Deprecated parameters (for backwards compatibility)
         patch_size: int | tuple[int, int] | None = None,
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
-        skip_targets: bool | None = None,
-        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.

@@ -475,18 +480,21 @@ class SplitConfig:
             num_patches: limit this split to this many patches
             transforms: transforms to apply
             sampler: SamplerFactory for this split
-            patch_size: an optional square size or (width, height) tuple. If set, read
+            crop_size: an optional square size or (width, height) tuple. If set, read
                 crops of this size rather than entire windows.
-            overlap_ratio: an optional float between 0 and 1. If set, read patches with
-                this ratio of overlap.
-            load_all_patches: with patch_size set, rather than sampling a random patch
-                for each window, read all patches as separate sequential items in the
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference.
+            load_all_crops: with crop_size set, rather than sampling a random crop
+                for each window, read all crops as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
             output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
                 If set, windows that already
                 have this layer completed will be skipped (useful for resuming
                 partial inference runs).
+            patch_size: deprecated, use crop_size instead
+            overlap_ratio: deprecated, use overlap_pixels instead
+            load_all_patches: deprecated, use load_all_crops instead
         """
         self.groups = groups
         self.names = names
@@ -495,22 +503,27 @@ class SplitConfig:
         self.num_patches = num_patches
         self.transforms = transforms
         self.sampler = sampler
-        self.patch_size = patch_size
         self.skip_targets = skip_targets
         self.output_layer_name_skip_inference_if_exists = (
             output_layer_name_skip_inference_if_exists
         )

-        # Note that load_all_patches are handled by the RslearnDataModule rather than
-        # the ModelDataset.
-        self.load_all_patches = load_all_patches
-        self.overlap_ratio = overlap_ratio
+        # These have deprecated equivalents -- we store both raw values since we don't
+        # have a complete picture until the final merged SplitConfig is computed. We
+        # raise deprecation warnings in merge_and_validate and we disambiguate them in
+        # get_ functions (so the variables should never be accessed directly).
+        self._crop_size = crop_size
+        self._patch_size = patch_size
+        self._overlap_pixels = overlap_pixels
+        self._overlap_ratio = overlap_ratio
+        self._load_all_crops = load_all_crops
+        self._load_all_patches = load_all_patches

-        if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
-            raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
+    def _merge(self, other: "SplitConfig") -> "SplitConfig":
+        """Merge settings from another SplitConfig into this one.

-    def update(self, other: "SplitConfig") -> "SplitConfig":
-        """Override settings in this SplitConfig with those in another.
+        Args:
+            other: the config to merge in (its non-None values override self's)

         Returns:
             the resulting SplitConfig combining the settings.
@@ -523,9 +536,12 @@ class SplitConfig:
             num_patches=self.num_patches,
             transforms=self.transforms,
             sampler=self.sampler,
-            patch_size=self.patch_size,
-            overlap_ratio=self.overlap_ratio,
-            load_all_patches=self.load_all_patches,
+            crop_size=self._crop_size,
+            patch_size=self._patch_size,
+            overlap_pixels=self._overlap_pixels,
+            overlap_ratio=self._overlap_ratio,
+            load_all_crops=self._load_all_crops,
+            load_all_patches=self._load_all_patches,
             skip_targets=self.skip_targets,
             output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
@@ -543,12 +559,18 @@ class SplitConfig:
             result.transforms = other.transforms
         if other.sampler:
             result.sampler = other.sampler
-        if other.patch_size:
-            result.patch_size = other.patch_size
-        if other.overlap_ratio is not None:
-            result.overlap_ratio = other.overlap_ratio
-        if other.load_all_patches is not None:
-            result.load_all_patches = other.load_all_patches
+        if other._crop_size is not None:
+            result._crop_size = other._crop_size
+        if other._patch_size is not None:
+            result._patch_size = other._patch_size
+        if other._overlap_pixels is not None:
+            result._overlap_pixels = other._overlap_pixels
+        if other._overlap_ratio is not None:
+            result._overlap_ratio = other._overlap_ratio
+        if other._load_all_crops is not None:
+            result._load_all_crops = other._load_all_crops
+        if other._load_all_patches is not None:
+            result._load_all_patches = other._load_all_patches
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
         if other.output_layer_name_skip_inference_if_exists is not None:
@@ -557,21 +579,90 @@ class SplitConfig:
             )
         return result

-    def get_patch_size(self) -> tuple[int, int] | None:
-        """Get patch size normalized to int tuple."""
-        if self.patch_size is None:
-            return None
-        if isinstance(self.patch_size, int):
-            return (self.patch_size, self.patch_size)
-        return self.patch_size
+    @staticmethod
+    def merge_and_validate(configs: list["SplitConfig"]) -> "SplitConfig":
+        """Merge a list of SplitConfigs and validate the result.

-    def get_overlap_ratio(self) -> float:
-        """Get the overlap ratio (default 0)."""
-        return self.overlap_ratio if self.overlap_ratio is not None else 0.0
+        Args:
+            configs: list of SplitConfig to merge. Later configs override earlier ones.

-    def get_load_all_patches(self) -> bool:
-        """Returns whether loading all patches is enabled (default False)."""
-        return True if self.load_all_patches is True else False
+        Returns:
+            the merged and validated SplitConfig.
+        """
+        if not configs:
+            return SplitConfig()
+
+        result = configs[0]
+        for config in configs[1:]:
+            result = result._merge(config)
+
+        # Emit deprecation warnings
+        if result._patch_size is not None:
+            warnings.warn(
+                "patch_size is deprecated, use crop_size instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if result._overlap_ratio is not None:
+            warnings.warn(
+                "overlap_ratio is deprecated, use overlap_pixels instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if result._load_all_patches is not None:
+            warnings.warn(
+                "load_all_patches is deprecated, use load_all_crops instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+
+        # Check for conflicting parameters
+        if result._crop_size is not None and result._patch_size is not None:
+            raise ValueError("Cannot specify both crop_size and patch_size")
+        if result._overlap_pixels is not None and result._overlap_ratio is not None:
+            raise ValueError("Cannot specify both overlap_pixels and overlap_ratio")
+        if result._load_all_crops is not None and result._load_all_patches is not None:
+            raise ValueError("Cannot specify both load_all_crops and load_all_patches")
+
+        # Validate overlap_pixels is non-negative
+        if result._overlap_pixels is not None and result._overlap_pixels < 0:
+            raise ValueError("overlap_pixels must be non-negative")
+
+        # overlap_pixels requires load_all_crops.
+        if result.get_overlap_pixels() > 0 and not result.get_load_all_crops():
+            raise ValueError(
+                "overlap_pixels requires load_all_crops to be True, since overlap is only used during sliding window inference"
+            )
+
+        return result
+
+    def get_crop_size(self) -> tuple[int, int] | None:
+        """Get crop size as tuple, handling deprecated patch_size."""
+        size = self._crop_size if self._crop_size is not None else self._patch_size
+        if size is None:
+            return None
+        if isinstance(size, int):
+            return (size, size)
+        return size
+
+    def get_overlap_pixels(self) -> int:
+        """Get the overlap pixels (default 0), handling deprecated overlap_ratio."""
+        if self._overlap_pixels is not None:
+            return self._overlap_pixels
+        if self._overlap_ratio is not None:
+            crop_size = self.get_crop_size()
+            if crop_size is None:
+                raise ValueError("overlap_ratio requires crop_size to be set")
+            return round(crop_size[0] * self._overlap_ratio)
+        return 0
+
+    def get_load_all_crops(self) -> bool:
+        """Returns whether loading all crops is enabled (default False)."""
+        if self._load_all_crops is not None:
+            return self._load_all_crops
+        if self._load_all_patches is not None:
+            return self._load_all_patches
+        return False

     def get_skip_targets(self) -> bool:
         """Returns whether skip_targets is enabled (default False)."""
@@ -650,7 +741,7 @@ class ModelDataset(torch.utils.data.Dataset):
         task: Task,
         workers: int,
         name: str | None = None,
-        fix_patch_pick: bool = False,
+        fix_crop_pick: bool = False,
         index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Instantiate a new ModelDataset.
@@ -662,7 +753,7 @@ class ModelDataset(torch.utils.data.Dataset):
             task: the task to train on
             workers: number of workers to use for initializing the dataset
             name: name of the dataset
-            fix_patch_pick: if True, fix the patch pick to be the same every time
+            fix_crop_pick: if True, fix the crop pick to be the same every time
                 for a given window. Useful for testing (default: False)
             index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
@@ -671,19 +762,19 @@ class ModelDataset(torch.utils.data.Dataset):
         self.inputs = inputs
         self.task = task
         self.name = name
-        self.fix_patch_pick = fix_patch_pick
+        self.fix_crop_pick = fix_crop_pick
         if split_config.transforms:
             self.transforms = Sequential(*split_config.transforms)
         else:
             self.transforms = rslearn.train.transforms.transform.Identity()

-        # Get normalized patch size from the SplitConfig.
-        # But if load all patches is enabled, this is handled by AllPatchesDataset, so
+        # Get normalized crop size from the SplitConfig.
+        # But if load_all_crops is enabled, this is handled by AllCropsDataset, so
         # here we instead load the entire windows.
-        if split_config.get_load_all_patches():
-            self.patch_size = None
+        if split_config.get_load_all_crops():
+            self.crop_size = None
         else:
-            self.patch_size = split_config.get_patch_size()
+            self.crop_size = split_config.get_crop_size()

         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -904,8 +995,8 @@ class ModelDataset(torch.utils.data.Dataset):
     def get_dataset_examples(self) -> list[Window]:
         """Get a list of examples in the dataset.

-        If load_all_patches is False, this is a list of Windows. Otherwise, this is a
-        list of (window, patch_bounds, (patch_idx, # patches)) tuples.
+        If load_all_crops is False, this is a list of Windows. Otherwise, this is a
+        list of (window, crop_bounds, (crop_idx, # crops)) tuples.
         """
         if self.dataset_examples is None:
             logger.debug(
@@ -938,37 +1029,37 @@ class ModelDataset(torch.utils.data.Dataset):
         """
         dataset_examples = self.get_dataset_examples()
         example = dataset_examples[idx]
-        rng = random.Random(idx if self.fix_patch_pick else None)
+        rng = random.Random(idx if self.fix_crop_pick else None)

         # Select bounds to read.
-        if self.patch_size:
+        if self.crop_size:
             window = example

-            def get_patch_range(n_patch: int, n_window: int) -> list[int]:
-                if n_patch > n_window:
+            def get_crop_range(n_crop: int, n_window: int) -> list[int]:
+                if n_crop > n_window:
                     # Select arbitrary range containing the entire window.
-                    # Basically arbitrarily padding the window to get to patch size.
-                    start = rng.randint(n_window - n_patch, 0)
-                    return [start, start + n_patch]
+                    # Basically arbitrarily padding the window to get to crop size.
+                    start = rng.randint(n_window - n_crop, 0)
+                    return [start, start + n_crop]

                 else:
-                    # Select arbitrary patch within the window.
-                    start = rng.randint(0, n_window - n_patch)
-                    return [start, start + n_patch]
+                    # Select arbitrary crop within the window.
+                    start = rng.randint(0, n_window - n_crop)
+                    return [start, start + n_crop]

             window_size = (
                 window.bounds[2] - window.bounds[0],
                 window.bounds[3] - window.bounds[1],
             )
-            patch_ranges = [
-                get_patch_range(self.patch_size[0], window_size[0]),
-                get_patch_range(self.patch_size[1], window_size[1]),
+            crop_ranges = [
+                get_crop_range(self.crop_size[0], window_size[0]),
+                get_crop_range(self.crop_size[1], window_size[1]),
             ]
             bounds = (
-                window.bounds[0] + patch_ranges[0][0],
-                window.bounds[1] + patch_ranges[1][0],
-                window.bounds[0] + patch_ranges[0][1],
-                window.bounds[1] + patch_ranges[1][1],
+                window.bounds[0] + crop_ranges[0][0],
+                window.bounds[1] + crop_ranges[1][0],
+                window.bounds[0] + crop_ranges[0][1],
+                window.bounds[1] + crop_ranges[1][1],
             )

         else:
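For reference, a standalone sketch (not part of the diff) of the crop selection math in get_crop_range above; pick_range is a hypothetical name used only for illustration:

import random

def pick_range(n_crop: int, n_window: int, rng: random.Random) -> list[int]:
    if n_crop > n_window:
        # Window smaller than the crop: pad on an arbitrary side, e.g. a
        # 200-wide window with a 256 crop draws start from [-56, 0].
        start = rng.randint(n_window - n_crop, 0)
    else:
        # Window at least as large as the crop: e.g. a 1000-wide window with
        # a 256 crop draws start from [0, 744].
        start = rng.randint(0, n_window - n_crop)
    return [start, start + n_crop]

# fix_crop_pick=True corresponds to seeding the generator with the example
# index, so the same window always yields the same crop.
rng = random.Random(42)
print(pick_range(256, 1000, rng))  # deterministic given the seed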
@@ -990,9 +1081,9 @@ class ModelDataset(torch.utils.data.Dataset):
                 window_group=window.group,
                 window_name=window.name,
                 window_bounds=window.bounds,
-                patch_bounds=bounds,
-                patch_idx=0,
-                num_patches_in_window=1,
+                crop_bounds=bounds,
+                crop_idx=0,
+                num_crops_in_window=1,
                 time_range=window.time_range,
                 projection=window.projection,
                 dataset_source=self.name,
rslearn/train/lightning_module.py CHANGED
@@ -6,12 +6,14 @@ from typing import Any

 import lightning as L
 import torch
+import wandb
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
 from PIL import Image
 from upath import UPath

 from rslearn.log_utils import get_logger

+from .metrics import NonScalarMetricOutput
 from .model_context import ModelContext, ModelOutput
 from .optimizer import AdamW, OptimizerFactory
 from .scheduler import PlateauScheduler, SchedulerFactory
@@ -210,15 +212,53 @@ class RslearnLightningModule(L.LightningModule):
            # Fail silently for single-dataset case, which is okay
            pass

+    def _log_non_scalar_metric(self, name: str, value: NonScalarMetricOutput) -> None:
+        """Log a non-scalar metric to wandb.
+
+        Args:
+            name: the metric name (e.g., "val_confusion_matrix")
+            value: the non-scalar metric output
+        """
+        # Non-scalar metrics are logged directly without Lightning, so we need
+        # to skip logging during the sanity check.
+        if self.trainer.sanity_checking:
+            return
+
+        # Wandb is required for logging non-scalar metrics.
+        if not wandb.run:
+            logger.warning(
+                f"Weights & Biases is not initialized, skipping logging of {name}"
+            )
+            return
+
+        value.log_to_wandb(name)
+
     def on_validation_epoch_end(self) -> None:
         """Compute and log validation metrics at epoch end.

         We manually compute and log metrics here (instead of passing the MetricCollection
         to log_dict) because MetricCollection.compute() properly flattens dict-returning
         metrics, while log_dict expects each metric to return a scalar tensor.
+
+        Non-scalar metrics (like confusion matrices) are logged separately using
+        logger-specific APIs.
         """
         metrics = self.val_metrics.compute()
-        self.log_dict(metrics)
+
+        # Separate scalar and non-scalar metrics
+        scalar_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, NonScalarMetricOutput):
+                self._log_non_scalar_metric(k, v)
+            elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
+                raise ValueError(
+                    f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
+                    "Wrap it in a NonScalarMetricOutput subclass."
+                )
+            else:
+                scalar_metrics[k] = v
+
+        self.log_dict(scalar_metrics)
         self.val_metrics.reset()

     def on_test_epoch_end(self) -> None:
@@ -227,14 +267,30 @@ class RslearnLightningModule(L.LightningModule):
         We manually compute and log metrics here (instead of passing the MetricCollection
         to log_dict) because MetricCollection.compute() properly flattens dict-returning
         metrics, while log_dict expects each metric to return a scalar tensor.
+
+        Non-scalar metrics (like confusion matrices) are logged separately.
         """
         metrics = self.test_metrics.compute()
-        self.log_dict(metrics)
+
+        # Separate scalar and non-scalar metrics
+        scalar_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, NonScalarMetricOutput):
+                self._log_non_scalar_metric(k, v)
+            elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
+                raise ValueError(
+                    f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
+                    "Wrap it in a NonScalarMetricOutput subclass."
+                )
+            else:
+                scalar_metrics[k] = v
+
+        self.log_dict(scalar_metrics)
         self.test_metrics.reset()

         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics_dict = {k: v.item() for k, v in metrics.items()}
+                metrics_dict = {k: v.item() for k, v in scalar_metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
             logger.info(f"Saved metrics to {self.metrics_file}")

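The scalar/non-scalar split above routes any metric whose compute() returns a NonScalarMetricOutput (from the new rslearn/train/metrics.py below) to wandb-specific logging. An illustrative sketch of extending it (not part of the diff; HistogramOutput is a hypothetical subclass):

from dataclasses import dataclass

import torch
import wandb

from rslearn.train.metrics import NonScalarMetricOutput


@dataclass
class HistogramOutput(NonScalarMetricOutput):
    """Hypothetical non-scalar output: a distribution of per-sample errors."""

    values: torch.Tensor

    def log_to_wandb(self, name: str) -> None:
        # wandb.Histogram accepts a sequence of scalar values.
        wandb.log({name: wandb.Histogram(self.values.detach().cpu().tolist())})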
@@ -365,7 +421,7 @@ class RslearnLightningModule(L.LightningModule):
         for image_suffix, image in images.items():
             out_fname = os.path.join(
                 self.visualize_dir,
-                f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
+                f"{metadata.window_name}_{metadata.crop_bounds[0]}_{metadata.crop_bounds[1]}_{image_suffix}.png",
             )
             Image.fromarray(image).save(out_fname)

rslearn/train/metrics.py ADDED
@@ -0,0 +1,162 @@
+"""Metric output classes for non-scalar metrics."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+import wandb
+from torchmetrics import Metric
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class NonScalarMetricOutput(ABC):
+    """Base class for non-scalar metric outputs that need special logging.
+
+    Subclasses should implement the log_to_wandb method to define how the metric
+    should be logged (only supports logging to Weights & Biases).
+    """
+
+    @abstractmethod
+    def log_to_wandb(self, name: str) -> None:
+        """Log this metric to wandb.
+
+        Args:
+            name: the metric name
+        """
+        pass
+
+
+@dataclass
+class ConfusionMatrixOutput(NonScalarMetricOutput):
+    """Confusion matrix metric output.
+
+    Args:
+        confusion_matrix: confusion matrix of shape (num_classes, num_classes)
+            where cm[i, j] is the count of samples with true label i and predicted
+            label j.
+        class_names: optional list of class names for axis labels
+    """
+
+    confusion_matrix: torch.Tensor
+    class_names: list[str] | None = None
+
+    def _expand_confusion_matrix(self) -> tuple[list[int], list[int]]:
+        """Expand confusion matrix to (preds, labels) pairs for wandb.
+
+        Returns:
+            Tuple of (preds, labels) as lists of integers.
+        """
+        cm = self.confusion_matrix.detach().cpu()
+
+        # Handle extra dimensions from distributed reduction
+        if cm.dim() > 2:
+            cm = cm.sum(dim=0)
+
+        total = cm.sum().item()
+        if total == 0:
+            return [], []
+
+        preds = []
+        labels = []
+        for true_label in range(cm.shape[0]):
+            for pred_label in range(cm.shape[1]):
+                count = cm[true_label, pred_label].item()
+                if count > 0:
+                    preds.extend([pred_label] * int(count))
+                    labels.extend([true_label] * int(count))
+
+        return preds, labels
+
+    def log_to_wandb(self, name: str) -> None:
+        """Log confusion matrix to wandb.
+
+        Args:
+            name: the metric name (e.g., "val_confusion_matrix")
+        """
+        preds, labels = self._expand_confusion_matrix()
+
+        if len(preds) == 0:
+            logger.warning(f"No samples to log for {name}")
+            return
+
+        num_classes = self.confusion_matrix.shape[0]
+        if self.class_names is None:
+            class_names = [str(i) for i in range(num_classes)]
+        else:
+            class_names = self.class_names
+
+        wandb.log(
+            {
+                name: wandb.plot.confusion_matrix(
+                    preds=preds,
+                    y_true=labels,
+                    class_names=class_names,
+                    title=name,
+                ),
+            },
+        )
+
+
+class ConfusionMatrixMetric(Metric):
+    """Confusion matrix metric that works on flattened inputs.
+
+    Expects preds of shape (N, C) and labels of shape (N,).
+    Should be wrapped by ClassificationMetric or SegmentationMetric,
+    which handle the task-specific preprocessing.
+
+    Args:
+        num_classes: number of classes
+        class_names: optional list of class names for labeling
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        class_names: list[str] | None = None,
+    ):
+        """Initialize a new ConfusionMatrixMetric.
+
+        Args:
+            num_classes: number of classes
+            class_names: optional list of class names for labeling
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.class_names = class_names
+        self.add_state(
+            "confusion_matrix",
+            default=torch.zeros(num_classes, num_classes, dtype=torch.long),
+            dist_reduce_fx="sum",
+        )
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Args:
+            preds: predictions of shape (N, C) - probabilities
+            labels: ground truth of shape (N,) - class indices
+        """
+        if len(preds) == 0:
+            return
+
+        pred_classes = preds.argmax(dim=1)  # (N,)
+
+        for true_label in range(self.num_classes):
+            for pred_label in range(self.num_classes):
+                count = ((labels == true_label) & (pred_classes == pred_label)).sum()
+                self.confusion_matrix[true_label, pred_label] += count
+
+    def compute(self) -> ConfusionMatrixOutput:
+        """Returns the confusion matrix wrapped in ConfusionMatrixOutput."""
+        return ConfusionMatrixOutput(
+            confusion_matrix=self.confusion_matrix,
+            class_names=self.class_names,
+        )
+
+    def reset(self) -> None:
+        """Reset metric."""
+        super().reset()
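A quick usage sketch (not part of the diff) showing the update/compute cycle of ConfusionMatrixMetric defined above:

import torch

from rslearn.train.metrics import ConfusionMatrixMetric

metric = ConfusionMatrixMetric(num_classes=2, class_names=["water", "land"])
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # (N, C) probabilities
labels = torch.tensor([0, 1, 1])  # (N,) class indices
metric.update(preds, labels)

out = metric.compute()  # ConfusionMatrixOutput
print(out.confusion_matrix)  # tensor([[1, 0], [1, 1]])
# Within an active wandb run, this renders the interactive plot:
# out.log_to_wandb("val_confusion_matrix")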
rslearn/train/model_context.py CHANGED
@@ -67,9 +67,9 @@ class SampleMetadata:
     window_group: str
     window_name: str
     window_bounds: PixelBounds
-    patch_bounds: PixelBounds
-    patch_idx: int
-    num_patches_in_window: int
+    crop_bounds: PixelBounds
+    crop_idx: int
+    num_crops_in_window: int
     time_range: tuple[datetime, datetime] | None
     projection: Projection