rslearn 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +1 -1
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/storage/file.py +16 -12
- rslearn/models/concatenate_features.py +6 -1
- rslearn/tile_stores/default.py +4 -2
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +36 -33
- rslearn/train/dataset.py +159 -68
- rslearn/train/lightning_module.py +60 -4
- rslearn/train/metrics.py +162 -0
- rslearn/train/model_context.py +3 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +14 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +19 -6
- rslearn/train/tasks/regression.py +19 -3
- rslearn/train/tasks/segmentation.py +17 -0
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/fsspec.py +51 -1
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
```diff
@@ -8,6 +8,7 @@ import random
 import tempfile
 import time
 import uuid
+import warnings
 from datetime import datetime
 from enum import StrEnum
 from typing import Any
@@ -456,11 +457,15 @@ class SplitConfig:
         num_patches: int | None = None,
         transforms: list[torch.nn.Module] | None = None,
         sampler: SamplerFactory | None = None,
+        crop_size: int | tuple[int, int] | None = None,
+        overlap_pixels: int | None = None,
+        load_all_crops: bool | None = None,
+        skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
+        # Deprecated parameters (for backwards compatibility)
         patch_size: int | tuple[int, int] | None = None,
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
-        skip_targets: bool | None = None,
-        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.
 
@@ -475,18 +480,21 @@ class SplitConfig:
             num_patches: limit this split to this many patches
             transforms: transforms to apply
             sampler: SamplerFactory for this split
-            patch_size: an optional square size or (width, height) tuple. If set, read
+            crop_size: an optional square size or (width, height) tuple. If set, read
                 crops of this size rather than entire windows.
-            overlap_ratio: …
-                …
-            load_all_patches: with patch_size set, rather than sampling a random patch
-                for each window, read all patches as separate sequential items in the
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference.
+            load_all_crops: with crop_size set, rather than sampling a random crop
+                for each window, read all crops as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
             output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
                 If set, windows that already
                 have this layer completed will be skipped (useful for resuming
                 partial inference runs).
+            patch_size: deprecated, use crop_size instead
+            overlap_ratio: deprecated, use overlap_pixels instead
+            load_all_patches: deprecated, use load_all_crops instead
         """
         self.groups = groups
         self.names = names
@@ -495,22 +503,27 @@ class SplitConfig:
         self.num_patches = num_patches
         self.transforms = transforms
         self.sampler = sampler
-        self.patch_size = patch_size
         self.skip_targets = skip_targets
         self.output_layer_name_skip_inference_if_exists = (
             output_layer_name_skip_inference_if_exists
         )
 
-        # …
-        # the …
-        …
-        …
+        # These have deprecated equivalents -- we store both raw values since we don't
+        # have a complete picture until the final merged SplitConfig is computed. We
+        # raise deprecation warnings in merge_and_validate and we disambiguate them in
+        # get_ functions (so the variables should never be accessed directly).
+        self._crop_size = crop_size
+        self._patch_size = patch_size
+        self._overlap_pixels = overlap_pixels
+        self._overlap_ratio = overlap_ratio
+        self._load_all_crops = load_all_crops
+        self._load_all_patches = load_all_patches
 
-    …
-    …
+    def _merge(self, other: "SplitConfig") -> "SplitConfig":
+        """Merge settings from another SplitConfig into this one.
 
-        …
-        …
+        Args:
+            other: the config to merge in (its non-None values override self's)
 
         Returns:
             the resulting SplitConfig combining the settings.
@@ -523,9 +536,12 @@ class SplitConfig:
             num_patches=self.num_patches,
             transforms=self.transforms,
             sampler=self.sampler,
-            patch_size=self.patch_size,
-            overlap_ratio=self.overlap_ratio,
-            load_all_patches=self.load_all_patches,
+            crop_size=self._crop_size,
+            patch_size=self._patch_size,
+            overlap_pixels=self._overlap_pixels,
+            overlap_ratio=self._overlap_ratio,
+            load_all_crops=self._load_all_crops,
+            load_all_patches=self._load_all_patches,
             skip_targets=self.skip_targets,
             output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
@@ -543,12 +559,18 @@ class SplitConfig:
             result.transforms = other.transforms
         if other.sampler:
             result.sampler = other.sampler
-        if other.patch_size is not None:
-            result.patch_size = other.patch_size
-        if other.overlap_ratio is not None:
-            result.overlap_ratio = other.overlap_ratio
-        if other.load_all_patches is not None:
-            result.load_all_patches = other.load_all_patches
+        if other._crop_size is not None:
+            result._crop_size = other._crop_size
+        if other._patch_size is not None:
+            result._patch_size = other._patch_size
+        if other._overlap_pixels is not None:
+            result._overlap_pixels = other._overlap_pixels
+        if other._overlap_ratio is not None:
+            result._overlap_ratio = other._overlap_ratio
+        if other._load_all_crops is not None:
+            result._load_all_crops = other._load_all_crops
+        if other._load_all_patches is not None:
+            result._load_all_patches = other._load_all_patches
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
         if other.output_layer_name_skip_inference_if_exists is not None:
@@ -557,21 +579,90 @@ class SplitConfig:
             )
         return result
 
-    def get_patch_size(self) -> tuple[int, int] | None:
-        """…"""
-        if self.patch_size is None:
-            return None
-        if isinstance(self.patch_size, int):
-            return (self.patch_size, self.patch_size)
-        return self.patch_size
+    @staticmethod
+    def merge_and_validate(configs: list["SplitConfig"]) -> "SplitConfig":
+        """Merge a list of SplitConfigs and validate the result.
 
-    def get_overlap_ratio(self) -> float:
-        """…"""
-        return self.overlap_ratio if self.overlap_ratio is not None else 0.0
+        Args:
+            configs: list of SplitConfig to merge. Later configs override earlier ones.
 
-    def get_load_all_patches(self) -> bool:
-        """…"""
-        …
+        Returns:
+            the merged and validated SplitConfig.
+        """
+        if not configs:
+            return SplitConfig()
+
+        result = configs[0]
+        for config in configs[1:]:
+            result = result._merge(config)
+
+        # Emit deprecation warnings
+        if result._patch_size is not None:
+            warnings.warn(
+                "patch_size is deprecated, use crop_size instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if result._overlap_ratio is not None:
+            warnings.warn(
+                "overlap_ratio is deprecated, use overlap_pixels instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+        if result._load_all_patches is not None:
+            warnings.warn(
+                "load_all_patches is deprecated, use load_all_crops instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+
+        # Check for conflicting parameters
+        if result._crop_size is not None and result._patch_size is not None:
+            raise ValueError("Cannot specify both crop_size and patch_size")
+        if result._overlap_pixels is not None and result._overlap_ratio is not None:
+            raise ValueError("Cannot specify both overlap_pixels and overlap_ratio")
+        if result._load_all_crops is not None and result._load_all_patches is not None:
+            raise ValueError("Cannot specify both load_all_crops and load_all_patches")
+
+        # Validate overlap_pixels is non-negative
+        if result._overlap_pixels is not None and result._overlap_pixels < 0:
+            raise ValueError("overlap_pixels must be non-negative")
+
+        # overlap_pixels requires load_all_crops.
+        if result.get_overlap_pixels() > 0 and not result.get_load_all_crops():
+            raise ValueError(
+                "overlap_pixels requires load_all_crops to be True since overlap is only used during sliding window inference"
+            )
+
+        return result
+
+    def get_crop_size(self) -> tuple[int, int] | None:
+        """Get crop size as tuple, handling deprecated patch_size."""
+        size = self._crop_size if self._crop_size is not None else self._patch_size
+        if size is None:
+            return None
+        if isinstance(size, int):
+            return (size, size)
+        return size
+
+    def get_overlap_pixels(self) -> int:
+        """Get the overlap pixels (default 0), handling deprecated overlap_ratio."""
+        if self._overlap_pixels is not None:
+            return self._overlap_pixels
+        if self._overlap_ratio is not None:
+            crop_size = self.get_crop_size()
+            if crop_size is None:
+                raise ValueError("overlap_ratio requires crop_size to be set")
+            return round(crop_size[0] * self._overlap_ratio)
+        return 0
+
+    def get_load_all_crops(self) -> bool:
+        """Returns whether loading all crops is enabled (default False)."""
+        if self._load_all_crops is not None:
+            return self._load_all_crops
+        if self._load_all_patches is not None:
+            return self._load_all_patches
+        return False
 
     def get_skip_targets(self) -> bool:
         """Returns whether skip_targets is enabled (default False)."""
```
```diff
@@ -650,7 +741,7 @@ class ModelDataset(torch.utils.data.Dataset):
         task: Task,
         workers: int,
         name: str | None = None,
-        fix_patch_pick: bool = False,
+        fix_crop_pick: bool = False,
         index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Instantiate a new ModelDataset.
@@ -662,7 +753,7 @@ class ModelDataset(torch.utils.data.Dataset):
             task: the task to train on
             workers: number of workers to use for initializing the dataset
             name: name of the dataset
-            fix_patch_pick: if True, fix the patch pick to be the same every time
+            fix_crop_pick: if True, fix the crop pick to be the same every time
                 for a given window. Useful for testing (default: False)
             index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
@@ -671,19 +762,19 @@ class ModelDataset(torch.utils.data.Dataset):
         self.inputs = inputs
         self.task = task
         self.name = name
-        self.fix_patch_pick = fix_patch_pick
+        self.fix_crop_pick = fix_crop_pick
         if split_config.transforms:
             self.transforms = Sequential(*split_config.transforms)
         else:
             self.transforms = rslearn.train.transforms.transform.Identity()
 
-        # Get normalized patch size from the SplitConfig.
-        # But if load_all_patches is enabled, this is handled by AllPatchesDataset, so
+        # Get normalized crop size from the SplitConfig.
+        # But if load_all_crops is enabled, this is handled by AllCropsDataset, so
         # here we instead load the entire windows.
-        if split_config.get_load_all_patches():
-            self.patch_size = None
+        if split_config.get_load_all_crops():
+            self.crop_size = None
         else:
-            self.patch_size = split_config.get_patch_size()
+            self.crop_size = split_config.get_crop_size()
 
         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -904,8 +995,8 @@ class ModelDataset(torch.utils.data.Dataset):
     def get_dataset_examples(self) -> list[Window]:
         """Get a list of examples in the dataset.
 
-        If load_all_patches is False, this is a list of Windows. Otherwise, this is a
-        list of (window, patch_bounds, …) tuples.
+        If load_all_crops is False, this is a list of Windows. Otherwise, this is a
+        list of (window, crop_bounds, (crop_idx, # crops)) tuples.
         """
         if self.dataset_examples is None:
             logger.debug(
@@ -938,37 +1029,37 @@ class ModelDataset(torch.utils.data.Dataset):
         """
        dataset_examples = self.get_dataset_examples()
         example = dataset_examples[idx]
-        rng = random.Random(idx if self.fix_patch_pick else None)
+        rng = random.Random(idx if self.fix_crop_pick else None)
 
         # Select bounds to read.
-        if self.patch_size:
+        if self.crop_size:
             window = example
 
-            def get_patch_range(n_patch: int, n_window: int) -> list[int]:
-                if n_patch > n_window:
+            def get_crop_range(n_crop: int, n_window: int) -> list[int]:
+                if n_crop > n_window:
                     # Select arbitrary range containing the entire window.
-                    # Basically arbitrarily padding the window to get to patch size.
-                    start = rng.randint(n_window - n_patch, 0)
-                    return [start, start + n_patch]
+                    # Basically arbitrarily padding the window to get to crop size.
+                    start = rng.randint(n_window - n_crop, 0)
+                    return [start, start + n_crop]
 
                 else:
-                    # Select arbitrary patch within the window.
-                    start = rng.randint(0, n_window - n_patch)
-                    return [start, start + n_patch]
+                    # Select arbitrary crop within the window.
+                    start = rng.randint(0, n_window - n_crop)
+                    return [start, start + n_crop]
 
             window_size = (
                 window.bounds[2] - window.bounds[0],
                 window.bounds[3] - window.bounds[1],
             )
-            patch_ranges = [
-                get_patch_range(self.patch_size[0], window_size[0]),
-                get_patch_range(self.patch_size[1], window_size[1]),
+            crop_ranges = [
+                get_crop_range(self.crop_size[0], window_size[0]),
+                get_crop_range(self.crop_size[1], window_size[1]),
             ]
             bounds = (
-                window.bounds[0] + patch_ranges[0][0],
-                window.bounds[1] + patch_ranges[1][0],
-                window.bounds[0] + patch_ranges[0][1],
-                window.bounds[1] + patch_ranges[1][1],
+                window.bounds[0] + crop_ranges[0][0],
+                window.bounds[1] + crop_ranges[1][0],
+                window.bounds[0] + crop_ranges[0][1],
+                window.bounds[1] + crop_ranges[1][1],
             )
 
         else:
@@ -990,9 +1081,9 @@ class ModelDataset(torch.utils.data.Dataset):
                 window_group=window.group,
                 window_name=window.name,
                 window_bounds=window.bounds,
-                patch_bounds=bounds,
-                patch_idx=0,
-                num_patches_in_window=1,
+                crop_bounds=bounds,
+                crop_idx=0,
+                num_crops_in_window=1,
                 time_range=window.time_range,
                 projection=window.projection,
                 dataset_source=self.name,
```
rslearn/train/lightning_module.py
CHANGED
```diff
@@ -6,12 +6,14 @@ from typing import Any
 
 import lightning as L
 import torch
+import wandb
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
 from PIL import Image
 from upath import UPath
 
 from rslearn.log_utils import get_logger
 
+from .metrics import NonScalarMetricOutput
 from .model_context import ModelContext, ModelOutput
 from .optimizer import AdamW, OptimizerFactory
 from .scheduler import PlateauScheduler, SchedulerFactory
@@ -210,15 +212,53 @@ class RslearnLightningModule(L.LightningModule):
             # Fail silently for single-dataset case, which is okay
             pass
 
+    def _log_non_scalar_metric(self, name: str, value: NonScalarMetricOutput) -> None:
+        """Log a non-scalar metric to wandb.
+
+        Args:
+            name: the metric name (e.g., "val_confusion_matrix")
+            value: the non-scalar metric output
+        """
+        # The non-scalar metrics are logged directly without Lightning,
+        # so we need to skip logging during sanity check.
+        if self.trainer.sanity_checking:
+            return
+
+        # Wandb is required for logging non-scalar metrics.
+        if not wandb.run:
+            logger.warning(
+                f"Weights & Biases is not initialized, skipping logging of {name}"
+            )
+            return
+
+        value.log_to_wandb(name)
+
     def on_validation_epoch_end(self) -> None:
         """Compute and log validation metrics at epoch end.
 
         We manually compute and log metrics here (instead of passing the MetricCollection
         to log_dict) because MetricCollection.compute() properly flattens dict-returning
         metrics, while log_dict expects each metric to return a scalar tensor.
+
+        Non-scalar metrics (like confusion matrices) are logged separately using
+        logger-specific APIs.
         """
         metrics = self.val_metrics.compute()
-        self.log_dict(metrics)
+
+        # Separate scalar and non-scalar metrics
+        scalar_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, NonScalarMetricOutput):
+                self._log_non_scalar_metric(k, v)
+            elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
+                raise ValueError(
+                    f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
+                    "Wrap it in a NonScalarMetricOutput subclass."
+                )
+            else:
+                scalar_metrics[k] = v
+
+        self.log_dict(scalar_metrics)
         self.val_metrics.reset()
 
     def on_test_epoch_end(self) -> None:
@@ -227,14 +267,30 @@ class RslearnLightningModule(L.LightningModule):
         We manually compute and log metrics here (instead of passing the MetricCollection
         to log_dict) because MetricCollection.compute() properly flattens dict-returning
         metrics, while log_dict expects each metric to return a scalar tensor.
+
+        Non-scalar metrics (like confusion matrices) are logged separately.
         """
         metrics = self.test_metrics.compute()
-        self.log_dict(metrics)
+
+        # Separate scalar and non-scalar metrics
+        scalar_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, NonScalarMetricOutput):
+                self._log_non_scalar_metric(k, v)
+            elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
+                raise ValueError(
+                    f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
+                    "Wrap it in a NonScalarMetricOutput subclass."
+                )
+            else:
+                scalar_metrics[k] = v
+
+        self.log_dict(scalar_metrics)
         self.test_metrics.reset()
 
         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics_dict = {k: v.item() for k, v in metrics.items()}
+                metrics_dict = {k: v.item() for k, v in scalar_metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
             logger.info(f"Saved metrics to {self.metrics_file}")
 
@@ -365,7 +421,7 @@ class RslearnLightningModule(L.LightningModule):
             for image_suffix, image in images.items():
                 out_fname = os.path.join(
                     self.visualize_dir,
-                    f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
+                    f"{metadata.window_name}_{metadata.crop_bounds[0]}_{metadata.crop_bounds[1]}_{image_suffix}.png",
                 )
                 Image.fromarray(image).save(out_fname)
```
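The epoch-end hooks now route every NonScalarMetricOutput through _log_non_scalar_metric and reject bare non-scalar tensors outright. A hypothetical subclass sketch (HistogramOutput is not part of the release) showing the interface a custom metric output needs:

```python
from dataclasses import dataclass

import wandb

from rslearn.train.metrics import NonScalarMetricOutput


@dataclass
class HistogramOutput(NonScalarMetricOutput):
    """Hypothetical example: log per-class sample counts as a W&B table."""

    counts: list[int]

    def log_to_wandb(self, name: str) -> None:
        # Build a two-column table and hand it to the active wandb run.
        table = wandb.Table(
            columns=["class", "count"],
            data=[[i, c] for i, c in enumerate(self.counts)],
        )
        wandb.log({name: table})
```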
rslearn/train/metrics.py
ADDED
```diff
@@ -0,0 +1,162 @@
+"""Metric output classes for non-scalar metrics."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import torch
+import wandb
+from torchmetrics import Metric
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class NonScalarMetricOutput(ABC):
+    """Base class for non-scalar metric outputs that need special logging.
+
+    Subclasses should implement the log_to_wandb method to define how the metric
+    should be logged (only supports logging to Weights & Biases).
+    """
+
+    @abstractmethod
+    def log_to_wandb(self, name: str) -> None:
+        """Log this metric to wandb.
+
+        Args:
+            name: the metric name
+        """
+        pass
+
+
+@dataclass
+class ConfusionMatrixOutput(NonScalarMetricOutput):
+    """Confusion matrix metric output.
+
+    Args:
+        confusion_matrix: confusion matrix of shape (num_classes, num_classes)
+            where cm[i, j] is the count of samples with true label i and predicted
+            label j.
+        class_names: optional list of class names for axis labels
+    """
+
+    confusion_matrix: torch.Tensor
+    class_names: list[str] | None = None
+
+    def _expand_confusion_matrix(self) -> tuple[list[int], list[int]]:
+        """Expand confusion matrix to (preds, labels) pairs for wandb.
+
+        Returns:
+            Tuple of (preds, labels) as lists of integers.
+        """
+        cm = self.confusion_matrix.detach().cpu()
+
+        # Handle extra dimensions from distributed reduction
+        if cm.dim() > 2:
+            cm = cm.sum(dim=0)
+
+        total = cm.sum().item()
+        if total == 0:
+            return [], []
+
+        preds = []
+        labels = []
+        for true_label in range(cm.shape[0]):
+            for pred_label in range(cm.shape[1]):
+                count = cm[true_label, pred_label].item()
+                if count > 0:
+                    preds.extend([pred_label] * int(count))
+                    labels.extend([true_label] * int(count))
+
+        return preds, labels
+
+    def log_to_wandb(self, name: str) -> None:
+        """Log confusion matrix to wandb.
+
+        Args:
+            name: the metric name (e.g., "val_confusion_matrix")
+        """
+        preds, labels = self._expand_confusion_matrix()
+
+        if len(preds) == 0:
+            logger.warning(f"No samples to log for {name}")
+            return
+
+        num_classes = self.confusion_matrix.shape[0]
+        if self.class_names is None:
+            class_names = [str(i) for i in range(num_classes)]
+        else:
+            class_names = self.class_names
+
+        wandb.log(
+            {
+                name: wandb.plot.confusion_matrix(
+                    preds=preds,
+                    y_true=labels,
+                    class_names=class_names,
+                    title=name,
+                ),
+            },
+        )
+
+
+class ConfusionMatrixMetric(Metric):
+    """Confusion matrix metric that works on flattened inputs.
+
+    Expects preds of shape (N, C) and labels of shape (N,).
+    Should be wrapped by ClassificationMetric or SegmentationMetric
+    which handle the task-specific preprocessing.
+
+    Args:
+        num_classes: number of classes
+        class_names: optional list of class names for labeling
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        class_names: list[str] | None = None,
+    ):
+        """Initialize a new ConfusionMatrixMetric.
+
+        Args:
+            num_classes: number of classes
+            class_names: optional list of class names for labeling
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.class_names = class_names
+        self.add_state(
+            "confusion_matrix",
+            default=torch.zeros(num_classes, num_classes, dtype=torch.long),
+            dist_reduce_fx="sum",
+        )
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Args:
+            preds: predictions of shape (N, C) - probabilities
+            labels: ground truth of shape (N,) - class indices
+        """
+        if len(preds) == 0:
+            return
+
+        pred_classes = preds.argmax(dim=1)  # (N,)
+
+        for true_label in range(self.num_classes):
+            for pred_label in range(self.num_classes):
+                count = ((labels == true_label) & (pred_classes == pred_label)).sum()
+                self.confusion_matrix[true_label, pred_label] += count
+
+    def compute(self) -> ConfusionMatrixOutput:
+        """Returns the confusion matrix wrapped in ConfusionMatrixOutput."""
+        return ConfusionMatrixOutput(
+            confusion_matrix=self.confusion_matrix,
+            class_names=self.class_names,
+        )
+
+    def reset(self) -> None:
+        """Reset metric."""
+        super().reset()
```
rslearn/train/model_context.py
CHANGED
```diff
@@ -67,9 +67,9 @@ class SampleMetadata:
     window_group: str
     window_name: str
     window_bounds: PixelBounds
-    patch_bounds: PixelBounds
-    patch_idx: int
-    num_patches_in_window: int
+    crop_bounds: PixelBounds
+    crop_idx: int
+    num_crops_in_window: int
     time_range: tuple[datetime, datetime] | None
     projection: Projection
```
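SampleMetadata is consumed by tasks and the prediction writer, so downstream code reading the removed fields needs the one-for-one renames. The old names below are inferred from the patch_-to-crop_ rename applied throughout this release:

```python
# Before (0.0.26)                     After (0.0.28)
# metadata.patch_bounds           ->  metadata.crop_bounds
# metadata.patch_idx              ->  metadata.crop_idx
# metadata.num_patches_in_window  ->  metadata.num_crops_in_window
```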