rslearn 0.0.27__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/dataset/storage/file.py +16 -12
- rslearn/tile_stores/default.py +4 -2
- rslearn/train/data_module.py +10 -7
- rslearn/train/dataset.py +118 -74
- rslearn/train/lightning_module.py +59 -3
- rslearn/train/metrics.py +162 -0
- rslearn/train/tasks/classification.py +13 -0
- rslearn/train/tasks/per_pixel_regression.py +19 -6
- rslearn/train/tasks/regression.py +18 -2
- rslearn/train/tasks/segmentation.py +17 -0
- rslearn/utils/fsspec.py +51 -1
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/METADATA +1 -1
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/RECORD +18 -17
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.27.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
rslearn/dataset/storage/file.py
CHANGED
|
@@ -15,7 +15,7 @@ from rslearn.dataset.window import (
|
|
|
15
15
|
get_window_layer_dir,
|
|
16
16
|
)
|
|
17
17
|
from rslearn.log_utils import get_logger
|
|
18
|
-
from rslearn.utils.fsspec import open_atomic
|
|
18
|
+
from rslearn.utils.fsspec import iter_nonhidden_subdirs, open_atomic
|
|
19
19
|
from rslearn.utils.mp import star_imap_unordered
|
|
20
20
|
|
|
21
21
|
from .storage import WindowStorage, WindowStorageFactory
|
|
@@ -77,8 +77,8 @@ class FileWindowStorage(WindowStorage):
|
|
|
77
77
|
window_dirs = []
|
|
78
78
|
if not groups:
|
|
79
79
|
groups = []
|
|
80
|
-
for
|
|
81
|
-
groups.append(
|
|
80
|
+
for group_dir in iter_nonhidden_subdirs(self.path / "windows"):
|
|
81
|
+
groups.append(group_dir.name)
|
|
82
82
|
for group in groups:
|
|
83
83
|
group_dir = self.path / "windows" / group
|
|
84
84
|
if not group_dir.exists():
|
|
@@ -86,16 +86,20 @@ class FileWindowStorage(WindowStorage):
|
|
|
86
86
|
f"Skipping group directory {group_dir} since it does not exist"
|
|
87
87
|
)
|
|
88
88
|
continue
|
|
89
|
+
if not group_dir.is_dir():
|
|
90
|
+
logger.warning(
|
|
91
|
+
f"Skipping group path {group_dir} since it is not a directory"
|
|
92
|
+
)
|
|
93
|
+
continue
|
|
89
94
|
if names:
|
|
90
|
-
|
|
95
|
+
for window_name in names:
|
|
96
|
+
window_dir = group_dir / window_name
|
|
97
|
+
if not window_dir.is_dir():
|
|
98
|
+
continue
|
|
99
|
+
window_dirs.append(window_dir)
|
|
91
100
|
else:
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
cur_names.append(p.name)
|
|
95
|
-
|
|
96
|
-
for window_name in cur_names:
|
|
97
|
-
window_dir = group_dir / window_name
|
|
98
|
-
window_dirs.append(window_dir)
|
|
101
|
+
for window_dir in iter_nonhidden_subdirs(group_dir):
|
|
102
|
+
window_dirs.append(window_dir)
|
|
99
103
|
|
|
100
104
|
if workers == 0:
|
|
101
105
|
windows = [load_window(self, window_dir) for window_dir in window_dirs]
|
|
@@ -162,7 +166,7 @@ class FileWindowStorage(WindowStorage):
|
|
|
162
166
|
return []
|
|
163
167
|
|
|
164
168
|
completed_layers = []
|
|
165
|
-
for layer_dir in layers_directory
|
|
169
|
+
for layer_dir in iter_nonhidden_subdirs(layers_directory):
|
|
166
170
|
layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
|
|
167
171
|
if not self.is_layer_completed(group, name, layer_name, group_idx):
|
|
168
172
|
continue
|
rslearn/tile_stores/default.py
CHANGED
|
@@ -15,6 +15,8 @@ from upath import UPath
|
|
|
15
15
|
from rslearn.const import WGS84_PROJECTION
|
|
16
16
|
from rslearn.utils.feature import Feature
|
|
17
17
|
from rslearn.utils.fsspec import (
|
|
18
|
+
iter_nonhidden_files,
|
|
19
|
+
iter_nonhidden_subdirs,
|
|
18
20
|
join_upath,
|
|
19
21
|
open_atomic,
|
|
20
22
|
open_rasterio_upath_reader,
|
|
@@ -129,7 +131,7 @@ class DefaultTileStore(TileStore):
|
|
|
129
131
|
ValueError: if no file is found.
|
|
130
132
|
"""
|
|
131
133
|
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
|
|
132
|
-
for fname in raster_dir
|
|
134
|
+
for fname in iter_nonhidden_files(raster_dir):
|
|
133
135
|
# Ignore completed sentinel files, bands files, as well as temporary files created by
|
|
134
136
|
# open_atomic (in case this tile store is on local filesystem).
|
|
135
137
|
if fname.name == COMPLETED_FNAME:
|
|
@@ -175,7 +177,7 @@ class DefaultTileStore(TileStore):
|
|
|
175
177
|
return []
|
|
176
178
|
|
|
177
179
|
bands: list[list[str]] = []
|
|
178
|
-
for raster_dir in item_dir
|
|
180
|
+
for raster_dir in iter_nonhidden_subdirs(item_dir):
|
|
179
181
|
if not (raster_dir / BANDS_FNAME).exists():
|
|
180
182
|
# This is likely a legacy directory where the bands are only encoded in
|
|
181
183
|
# the directory name, so we have to rely on that.
|
rslearn/train/data_module.py
CHANGED
|
@@ -108,10 +108,10 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
108
108
|
self.use_in_memory_all_crops_dataset = use_in_memory_all_crops_dataset
|
|
109
109
|
self.index_mode = index_mode
|
|
110
110
|
self.split_configs = {
|
|
111
|
-
"train":
|
|
112
|
-
"val":
|
|
113
|
-
"test":
|
|
114
|
-
"predict":
|
|
111
|
+
"train": SplitConfig.merge_and_validate([default_config, train_config]),
|
|
112
|
+
"val": SplitConfig.merge_and_validate([default_config, val_config]),
|
|
113
|
+
"test": SplitConfig.merge_and_validate([default_config, test_config]),
|
|
114
|
+
"predict": SplitConfig.merge_and_validate([default_config, predict_config]),
|
|
115
115
|
}
|
|
116
116
|
|
|
117
117
|
def setup(
|
|
@@ -141,7 +141,7 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
141
141
|
task=self.task,
|
|
142
142
|
workers=self.init_workers,
|
|
143
143
|
name=self.name,
|
|
144
|
-
|
|
144
|
+
fix_crop_pick=(split != "train"),
|
|
145
145
|
index_mode=self.index_mode,
|
|
146
146
|
)
|
|
147
147
|
logger.info(f"got {len(dataset)} examples in split {split}")
|
|
@@ -203,13 +203,16 @@ class RslearnDataModule(L.LightningDataModule):
|
|
|
203
203
|
# Enable persistent workers unless we are using main process.
|
|
204
204
|
persistent_workers = self.num_workers > 0
|
|
205
205
|
|
|
206
|
-
# If using all
|
|
206
|
+
# If using all crops, limit number of workers to the number of windows.
|
|
207
207
|
# Otherwise it has to distribute the same window to different workers which can
|
|
208
208
|
# cause issues for RslearnWriter.
|
|
209
209
|
# If the number of windows is 0, then we can set positive number of workers
|
|
210
210
|
# since they won't yield anything anyway.
|
|
211
211
|
num_workers = self.num_workers
|
|
212
|
-
if
|
|
212
|
+
if (
|
|
213
|
+
split_config.get_load_all_crops()
|
|
214
|
+
and len(dataset.get_dataset_examples()) > 0
|
|
215
|
+
):
|
|
213
216
|
num_workers = min(num_workers, len(dataset.get_dataset_examples()))
|
|
214
217
|
|
|
215
218
|
kwargs: dict[str, Any] = dict(
|
rslearn/train/dataset.py
CHANGED
|
@@ -496,53 +496,6 @@ class SplitConfig:
|
|
|
496
496
|
overlap_ratio: deprecated, use overlap_pixels instead
|
|
497
497
|
load_all_patches: deprecated, use load_all_crops instead
|
|
498
498
|
"""
|
|
499
|
-
# Handle deprecated load_all_patches parameter
|
|
500
|
-
if load_all_patches is not None:
|
|
501
|
-
warnings.warn(
|
|
502
|
-
"load_all_patches is deprecated, use load_all_crops instead",
|
|
503
|
-
FutureWarning,
|
|
504
|
-
stacklevel=2,
|
|
505
|
-
)
|
|
506
|
-
if load_all_crops is not None:
|
|
507
|
-
raise ValueError(
|
|
508
|
-
"Cannot specify both load_all_patches and load_all_crops"
|
|
509
|
-
)
|
|
510
|
-
load_all_crops = load_all_patches
|
|
511
|
-
# Handle deprecated patch_size parameter
|
|
512
|
-
if patch_size is not None:
|
|
513
|
-
warnings.warn(
|
|
514
|
-
"patch_size is deprecated, use crop_size instead",
|
|
515
|
-
FutureWarning,
|
|
516
|
-
stacklevel=2,
|
|
517
|
-
)
|
|
518
|
-
if crop_size is not None:
|
|
519
|
-
raise ValueError("Cannot specify both patch_size and crop_size")
|
|
520
|
-
crop_size = patch_size
|
|
521
|
-
|
|
522
|
-
# Normalize crop_size to tuple[int, int] | None
|
|
523
|
-
self.crop_size: tuple[int, int] | None = None
|
|
524
|
-
if crop_size is not None:
|
|
525
|
-
if isinstance(crop_size, int):
|
|
526
|
-
self.crop_size = (crop_size, crop_size)
|
|
527
|
-
else:
|
|
528
|
-
self.crop_size = crop_size
|
|
529
|
-
|
|
530
|
-
# Handle deprecated overlap_ratio parameter
|
|
531
|
-
if overlap_ratio is not None:
|
|
532
|
-
warnings.warn(
|
|
533
|
-
"overlap_ratio is deprecated, use overlap_pixels instead",
|
|
534
|
-
FutureWarning,
|
|
535
|
-
stacklevel=2,
|
|
536
|
-
)
|
|
537
|
-
if overlap_pixels is not None:
|
|
538
|
-
raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
|
|
539
|
-
if self.crop_size is None:
|
|
540
|
-
raise ValueError("overlap_ratio requires crop_size to be set")
|
|
541
|
-
overlap_pixels = round(self.crop_size[0] * overlap_ratio)
|
|
542
|
-
|
|
543
|
-
if overlap_pixels is not None and overlap_pixels < 0:
|
|
544
|
-
raise ValueError("overlap_pixels must be non-negative")
|
|
545
|
-
|
|
546
499
|
self.groups = groups
|
|
547
500
|
self.names = names
|
|
548
501
|
self.tags = tags
|
|
@@ -555,13 +508,22 @@ class SplitConfig:
|
|
|
555
508
|
output_layer_name_skip_inference_if_exists
|
|
556
509
|
)
|
|
557
510
|
|
|
558
|
-
#
|
|
559
|
-
#
|
|
560
|
-
|
|
561
|
-
|
|
511
|
+
# These have deprecated equivalents -- we store both raw values since we don't
|
|
512
|
+
# have a complete picture until the final merged SplitConfig is computed. We
|
|
513
|
+
# raise deprecation warnings in merge_and_validate and we disambiguate them in
|
|
514
|
+
# get_ functions (so the variables should never be accessed directly).
|
|
515
|
+
self._crop_size = crop_size
|
|
516
|
+
self._patch_size = patch_size
|
|
517
|
+
self._overlap_pixels = overlap_pixels
|
|
518
|
+
self._overlap_ratio = overlap_ratio
|
|
519
|
+
self._load_all_crops = load_all_crops
|
|
520
|
+
self._load_all_patches = load_all_patches
|
|
562
521
|
|
|
563
|
-
def
|
|
564
|
-
"""
|
|
522
|
+
def _merge(self, other: "SplitConfig") -> "SplitConfig":
|
|
523
|
+
"""Merge settings from another SplitConfig into this one.
|
|
524
|
+
|
|
525
|
+
Args:
|
|
526
|
+
other: the config to merge in (its non-None values override self's)
|
|
565
527
|
|
|
566
528
|
Returns:
|
|
567
529
|
the resulting SplitConfig combining the settings.
|
|
@@ -574,9 +536,12 @@ class SplitConfig:
|
|
|
574
536
|
num_patches=self.num_patches,
|
|
575
537
|
transforms=self.transforms,
|
|
576
538
|
sampler=self.sampler,
|
|
577
|
-
crop_size=self.
|
|
578
|
-
|
|
579
|
-
|
|
539
|
+
crop_size=self._crop_size,
|
|
540
|
+
patch_size=self._patch_size,
|
|
541
|
+
overlap_pixels=self._overlap_pixels,
|
|
542
|
+
overlap_ratio=self._overlap_ratio,
|
|
543
|
+
load_all_crops=self._load_all_crops,
|
|
544
|
+
load_all_patches=self._load_all_patches,
|
|
580
545
|
skip_targets=self.skip_targets,
|
|
581
546
|
output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
|
|
582
547
|
)
|
|
@@ -594,12 +559,18 @@ class SplitConfig:
|
|
|
594
559
|
result.transforms = other.transforms
|
|
595
560
|
if other.sampler:
|
|
596
561
|
result.sampler = other.sampler
|
|
597
|
-
if other.
|
|
598
|
-
result.
|
|
599
|
-
if other.
|
|
600
|
-
result.
|
|
601
|
-
if other.
|
|
602
|
-
result.
|
|
562
|
+
if other._crop_size is not None:
|
|
563
|
+
result._crop_size = other._crop_size
|
|
564
|
+
if other._patch_size is not None:
|
|
565
|
+
result._patch_size = other._patch_size
|
|
566
|
+
if other._overlap_pixels is not None:
|
|
567
|
+
result._overlap_pixels = other._overlap_pixels
|
|
568
|
+
if other._overlap_ratio is not None:
|
|
569
|
+
result._overlap_ratio = other._overlap_ratio
|
|
570
|
+
if other._load_all_crops is not None:
|
|
571
|
+
result._load_all_crops = other._load_all_crops
|
|
572
|
+
if other._load_all_patches is not None:
|
|
573
|
+
result._load_all_patches = other._load_all_patches
|
|
603
574
|
if other.skip_targets is not None:
|
|
604
575
|
result.skip_targets = other.skip_targets
|
|
605
576
|
if other.output_layer_name_skip_inference_if_exists is not None:
|
|
@@ -608,17 +579,90 @@ class SplitConfig:
|
|
|
608
579
|
)
|
|
609
580
|
return result
|
|
610
581
|
|
|
582
|
+
@staticmethod
|
|
583
|
+
def merge_and_validate(configs: list["SplitConfig"]) -> "SplitConfig":
|
|
584
|
+
"""Merge a list of SplitConfigs and validate the result.
|
|
585
|
+
|
|
586
|
+
Args:
|
|
587
|
+
configs: list of SplitConfig to merge. Later configs override earlier ones.
|
|
588
|
+
|
|
589
|
+
Returns:
|
|
590
|
+
the merged and validated SplitConfig.
|
|
591
|
+
"""
|
|
592
|
+
if not configs:
|
|
593
|
+
return SplitConfig()
|
|
594
|
+
|
|
595
|
+
result = configs[0]
|
|
596
|
+
for config in configs[1:]:
|
|
597
|
+
result = result._merge(config)
|
|
598
|
+
|
|
599
|
+
# Emit deprecation warnings
|
|
600
|
+
if result._patch_size is not None:
|
|
601
|
+
warnings.warn(
|
|
602
|
+
"patch_size is deprecated, use crop_size instead",
|
|
603
|
+
FutureWarning,
|
|
604
|
+
stacklevel=2,
|
|
605
|
+
)
|
|
606
|
+
if result._overlap_ratio is not None:
|
|
607
|
+
warnings.warn(
|
|
608
|
+
"overlap_ratio is deprecated, use overlap_pixels instead",
|
|
609
|
+
FutureWarning,
|
|
610
|
+
stacklevel=2,
|
|
611
|
+
)
|
|
612
|
+
if result._load_all_patches is not None:
|
|
613
|
+
warnings.warn(
|
|
614
|
+
"load_all_patches is deprecated, use load_all_crops instead",
|
|
615
|
+
FutureWarning,
|
|
616
|
+
stacklevel=2,
|
|
617
|
+
)
|
|
618
|
+
|
|
619
|
+
# Check for conflicting parameters
|
|
620
|
+
if result._crop_size is not None and result._patch_size is not None:
|
|
621
|
+
raise ValueError("Cannot specify both crop_size and patch_size")
|
|
622
|
+
if result._overlap_pixels is not None and result._overlap_ratio is not None:
|
|
623
|
+
raise ValueError("Cannot specify both overlap_pixels and overlap_ratio")
|
|
624
|
+
if result._load_all_crops is not None and result._load_all_patches is not None:
|
|
625
|
+
raise ValueError("Cannot specify both load_all_crops and load_all_patches")
|
|
626
|
+
|
|
627
|
+
# Validate overlap_pixels is non-negative
|
|
628
|
+
if result._overlap_pixels is not None and result._overlap_pixels < 0:
|
|
629
|
+
raise ValueError("overlap_pixels must be non-negative")
|
|
630
|
+
|
|
631
|
+
# overlap_pixels requires load_all_crops.
|
|
632
|
+
if result.get_overlap_pixels() > 0 and not result.get_load_all_crops():
|
|
633
|
+
raise ValueError(
|
|
634
|
+
"overlap_pixels requires load_all_crops to be True since (overlap is only used during sliding window inference"
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
return result
|
|
638
|
+
|
|
611
639
|
def get_crop_size(self) -> tuple[int, int] | None:
|
|
612
|
-
"""Get crop size as tuple."""
|
|
613
|
-
|
|
640
|
+
"""Get crop size as tuple, handling deprecated patch_size."""
|
|
641
|
+
size = self._crop_size if self._crop_size is not None else self._patch_size
|
|
642
|
+
if size is None:
|
|
643
|
+
return None
|
|
644
|
+
if isinstance(size, int):
|
|
645
|
+
return (size, size)
|
|
646
|
+
return size
|
|
614
647
|
|
|
615
648
|
def get_overlap_pixels(self) -> int:
|
|
616
|
-
"""Get the overlap pixels (default 0)."""
|
|
617
|
-
|
|
649
|
+
"""Get the overlap pixels (default 0), handling deprecated overlap_ratio."""
|
|
650
|
+
if self._overlap_pixels is not None:
|
|
651
|
+
return self._overlap_pixels
|
|
652
|
+
if self._overlap_ratio is not None:
|
|
653
|
+
crop_size = self.get_crop_size()
|
|
654
|
+
if crop_size is None:
|
|
655
|
+
raise ValueError("overlap_ratio requires crop_size to be set")
|
|
656
|
+
return round(crop_size[0] * self._overlap_ratio)
|
|
657
|
+
return 0
|
|
618
658
|
|
|
619
659
|
def get_load_all_crops(self) -> bool:
|
|
620
|
-
"""Returns whether loading all
|
|
621
|
-
|
|
660
|
+
"""Returns whether loading all crops is enabled (default False)."""
|
|
661
|
+
if self._load_all_crops is not None:
|
|
662
|
+
return self._load_all_crops
|
|
663
|
+
if self._load_all_patches is not None:
|
|
664
|
+
return self._load_all_patches
|
|
665
|
+
return False
|
|
622
666
|
|
|
623
667
|
def get_skip_targets(self) -> bool:
|
|
624
668
|
"""Returns whether skip_targets is enabled (default False)."""
|
|
@@ -697,7 +741,7 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
697
741
|
task: Task,
|
|
698
742
|
workers: int,
|
|
699
743
|
name: str | None = None,
|
|
700
|
-
|
|
744
|
+
fix_crop_pick: bool = False,
|
|
701
745
|
index_mode: IndexMode = IndexMode.OFF,
|
|
702
746
|
) -> None:
|
|
703
747
|
"""Instantiate a new ModelDataset.
|
|
@@ -709,7 +753,7 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
709
753
|
task: the task to train on
|
|
710
754
|
workers: number of workers to use for initializing the dataset
|
|
711
755
|
name: name of the dataset
|
|
712
|
-
|
|
756
|
+
fix_crop_pick: if True, fix the crop pick to be the same every time
|
|
713
757
|
for a given window. Useful for testing (default: False)
|
|
714
758
|
index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
|
|
715
759
|
"""
|
|
@@ -718,14 +762,14 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
718
762
|
self.inputs = inputs
|
|
719
763
|
self.task = task
|
|
720
764
|
self.name = name
|
|
721
|
-
self.
|
|
765
|
+
self.fix_crop_pick = fix_crop_pick
|
|
722
766
|
if split_config.transforms:
|
|
723
767
|
self.transforms = Sequential(*split_config.transforms)
|
|
724
768
|
else:
|
|
725
769
|
self.transforms = rslearn.train.transforms.transform.Identity()
|
|
726
770
|
|
|
727
771
|
# Get normalized crop size from the SplitConfig.
|
|
728
|
-
# But if
|
|
772
|
+
# But if load_all_crops is enabled, this is handled by AllCropsDataset, so
|
|
729
773
|
# here we instead load the entire windows.
|
|
730
774
|
if split_config.get_load_all_crops():
|
|
731
775
|
self.crop_size = None
|
|
@@ -952,7 +996,7 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
952
996
|
"""Get a list of examples in the dataset.
|
|
953
997
|
|
|
954
998
|
If load_all_crops is False, this is a list of Windows. Otherwise, this is a
|
|
955
|
-
list of (window, crop_bounds, (crop_idx, #
|
|
999
|
+
list of (window, crop_bounds, (crop_idx, # crops)) tuples.
|
|
956
1000
|
"""
|
|
957
1001
|
if self.dataset_examples is None:
|
|
958
1002
|
logger.debug(
|
|
@@ -985,7 +1029,7 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
985
1029
|
"""
|
|
986
1030
|
dataset_examples = self.get_dataset_examples()
|
|
987
1031
|
example = dataset_examples[idx]
|
|
988
|
-
rng = random.Random(idx if self.
|
|
1032
|
+
rng = random.Random(idx if self.fix_crop_pick else None)
|
|
989
1033
|
|
|
990
1034
|
# Select bounds to read.
|
|
991
1035
|
if self.crop_size:
|
|
@@ -6,12 +6,14 @@ from typing import Any
|
|
|
6
6
|
|
|
7
7
|
import lightning as L
|
|
8
8
|
import torch
|
|
9
|
+
import wandb
|
|
9
10
|
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
|
|
10
11
|
from PIL import Image
|
|
11
12
|
from upath import UPath
|
|
12
13
|
|
|
13
14
|
from rslearn.log_utils import get_logger
|
|
14
15
|
|
|
16
|
+
from .metrics import NonScalarMetricOutput
|
|
15
17
|
from .model_context import ModelContext, ModelOutput
|
|
16
18
|
from .optimizer import AdamW, OptimizerFactory
|
|
17
19
|
from .scheduler import PlateauScheduler, SchedulerFactory
|
|
@@ -210,15 +212,53 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
210
212
|
# Fail silently for single-dataset case, which is okay
|
|
211
213
|
pass
|
|
212
214
|
|
|
215
|
+
def _log_non_scalar_metric(self, name: str, value: NonScalarMetricOutput) -> None:
|
|
216
|
+
"""Log a non-scalar metric to wandb.
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
name: the metric name (e.g., "val_confusion_matrix")
|
|
220
|
+
value: the non-scalar metric output
|
|
221
|
+
"""
|
|
222
|
+
# The non-scalar metrics are logging directly without Lightning
|
|
223
|
+
# So we need to skip logging during sanity check.
|
|
224
|
+
if self.trainer.sanity_checking:
|
|
225
|
+
return
|
|
226
|
+
|
|
227
|
+
# Wandb is required for logging non-scalar metrics.
|
|
228
|
+
if not wandb.run:
|
|
229
|
+
logger.warning(
|
|
230
|
+
f"Weights & Biases is not initialized, skipping logging of {name}"
|
|
231
|
+
)
|
|
232
|
+
return
|
|
233
|
+
|
|
234
|
+
value.log_to_wandb(name)
|
|
235
|
+
|
|
213
236
|
def on_validation_epoch_end(self) -> None:
|
|
214
237
|
"""Compute and log validation metrics at epoch end.
|
|
215
238
|
|
|
216
239
|
We manually compute and log metrics here (instead of passing the MetricCollection
|
|
217
240
|
to log_dict) because MetricCollection.compute() properly flattens dict-returning
|
|
218
241
|
metrics, while log_dict expects each metric to return a scalar tensor.
|
|
242
|
+
|
|
243
|
+
Non-scalar metrics (like confusion matrices) are logged separately using
|
|
244
|
+
logger-specific APIs.
|
|
219
245
|
"""
|
|
220
246
|
metrics = self.val_metrics.compute()
|
|
221
|
-
|
|
247
|
+
|
|
248
|
+
# Separate scalar and non-scalar metrics
|
|
249
|
+
scalar_metrics = {}
|
|
250
|
+
for k, v in metrics.items():
|
|
251
|
+
if isinstance(v, NonScalarMetricOutput):
|
|
252
|
+
self._log_non_scalar_metric(k, v)
|
|
253
|
+
elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
|
|
254
|
+
raise ValueError(
|
|
255
|
+
f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
|
|
256
|
+
"Wrap it in a NonScalarMetricOutput subclass."
|
|
257
|
+
)
|
|
258
|
+
else:
|
|
259
|
+
scalar_metrics[k] = v
|
|
260
|
+
|
|
261
|
+
self.log_dict(scalar_metrics)
|
|
222
262
|
self.val_metrics.reset()
|
|
223
263
|
|
|
224
264
|
def on_test_epoch_end(self) -> None:
|
|
@@ -227,14 +267,30 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
227
267
|
We manually compute and log metrics here (instead of passing the MetricCollection
|
|
228
268
|
to log_dict) because MetricCollection.compute() properly flattens dict-returning
|
|
229
269
|
metrics, while log_dict expects each metric to return a scalar tensor.
|
|
270
|
+
|
|
271
|
+
Non-scalar metrics (like confusion matrices) are logged separately.
|
|
230
272
|
"""
|
|
231
273
|
metrics = self.test_metrics.compute()
|
|
232
|
-
|
|
274
|
+
|
|
275
|
+
# Separate scalar and non-scalar metrics
|
|
276
|
+
scalar_metrics = {}
|
|
277
|
+
for k, v in metrics.items():
|
|
278
|
+
if isinstance(v, NonScalarMetricOutput):
|
|
279
|
+
self._log_non_scalar_metric(k, v)
|
|
280
|
+
elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
|
|
281
|
+
raise ValueError(
|
|
282
|
+
f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
|
|
283
|
+
"Wrap it in a NonScalarMetricOutput subclass."
|
|
284
|
+
)
|
|
285
|
+
else:
|
|
286
|
+
scalar_metrics[k] = v
|
|
287
|
+
|
|
288
|
+
self.log_dict(scalar_metrics)
|
|
233
289
|
self.test_metrics.reset()
|
|
234
290
|
|
|
235
291
|
if self.metrics_file:
|
|
236
292
|
with open(self.metrics_file, "w") as f:
|
|
237
|
-
metrics_dict = {k: v.item() for k, v in
|
|
293
|
+
metrics_dict = {k: v.item() for k, v in scalar_metrics.items()}
|
|
238
294
|
json.dump(metrics_dict, f, indent=4)
|
|
239
295
|
logger.info(f"Saved metrics to {self.metrics_file}")
|
|
240
296
|
|
rslearn/train/metrics.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Metric output classes for non-scalar metrics."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import wandb
|
|
8
|
+
from torchmetrics import Metric
|
|
9
|
+
|
|
10
|
+
from rslearn.log_utils import get_logger
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class NonScalarMetricOutput(ABC):
|
|
17
|
+
"""Base class for non-scalar metric outputs that need special logging.
|
|
18
|
+
|
|
19
|
+
Subclasses should implement the log_to_wandb method to define how the metric
|
|
20
|
+
should be logged (only supports logging to Weights & Biases).
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
def log_to_wandb(self, name: str) -> None:
|
|
25
|
+
"""Log this metric to wandb.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
name: the metric name
|
|
29
|
+
"""
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
|
|
34
|
+
class ConfusionMatrixOutput(NonScalarMetricOutput):
|
|
35
|
+
"""Confusion matrix metric output.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
confusion_matrix: confusion matrix of shape (num_classes, num_classes)
|
|
39
|
+
where cm[i, j] is the count of samples with true label i and predicted
|
|
40
|
+
label j.
|
|
41
|
+
class_names: optional list of class names for axis labels
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
confusion_matrix: torch.Tensor
|
|
45
|
+
class_names: list[str] | None = None
|
|
46
|
+
|
|
47
|
+
def _expand_confusion_matrix(self) -> tuple[list[int], list[int]]:
|
|
48
|
+
"""Expand confusion matrix to (preds, labels) pairs for wandb.
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
Tuple of (preds, labels) as lists of integers.
|
|
52
|
+
"""
|
|
53
|
+
cm = self.confusion_matrix.detach().cpu()
|
|
54
|
+
|
|
55
|
+
# Handle extra dimensions from distributed reduction
|
|
56
|
+
if cm.dim() > 2:
|
|
57
|
+
cm = cm.sum(dim=0)
|
|
58
|
+
|
|
59
|
+
total = cm.sum().item()
|
|
60
|
+
if total == 0:
|
|
61
|
+
return [], []
|
|
62
|
+
|
|
63
|
+
preds = []
|
|
64
|
+
labels = []
|
|
65
|
+
for true_label in range(cm.shape[0]):
|
|
66
|
+
for pred_label in range(cm.shape[1]):
|
|
67
|
+
count = cm[true_label, pred_label].item()
|
|
68
|
+
if count > 0:
|
|
69
|
+
preds.extend([pred_label] * int(count))
|
|
70
|
+
labels.extend([true_label] * int(count))
|
|
71
|
+
|
|
72
|
+
return preds, labels
|
|
73
|
+
|
|
74
|
+
def log_to_wandb(self, name: str) -> None:
|
|
75
|
+
"""Log confusion matrix to wandb.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
name: the metric name (e.g., "val_confusion_matrix")
|
|
79
|
+
"""
|
|
80
|
+
preds, labels = self._expand_confusion_matrix()
|
|
81
|
+
|
|
82
|
+
if len(preds) == 0:
|
|
83
|
+
logger.warning(f"No samples to log for {name}")
|
|
84
|
+
return
|
|
85
|
+
|
|
86
|
+
num_classes = self.confusion_matrix.shape[0]
|
|
87
|
+
if self.class_names is None:
|
|
88
|
+
class_names = [str(i) for i in range(num_classes)]
|
|
89
|
+
else:
|
|
90
|
+
class_names = self.class_names
|
|
91
|
+
|
|
92
|
+
wandb.log(
|
|
93
|
+
{
|
|
94
|
+
name: wandb.plot.confusion_matrix(
|
|
95
|
+
preds=preds,
|
|
96
|
+
y_true=labels,
|
|
97
|
+
class_names=class_names,
|
|
98
|
+
title=name,
|
|
99
|
+
),
|
|
100
|
+
},
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ConfusionMatrixMetric(Metric):
|
|
105
|
+
"""Confusion matrix metric that works on flattened inputs.
|
|
106
|
+
|
|
107
|
+
Expects preds of shape (N, C) and labels of shape (N,).
|
|
108
|
+
Should be wrapped by ClassificationMetric or SegmentationMetric
|
|
109
|
+
which handle the task-specific preprocessing.
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
num_classes: number of classes
|
|
113
|
+
class_names: optional list of class names for labeling
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
def __init__(
|
|
117
|
+
self,
|
|
118
|
+
num_classes: int,
|
|
119
|
+
class_names: list[str] | None = None,
|
|
120
|
+
):
|
|
121
|
+
"""Initialize a new ConfusionMatrixMetric.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
num_classes: number of classes
|
|
125
|
+
class_names: optional list of class names for labeling
|
|
126
|
+
"""
|
|
127
|
+
super().__init__()
|
|
128
|
+
self.num_classes = num_classes
|
|
129
|
+
self.class_names = class_names
|
|
130
|
+
self.add_state(
|
|
131
|
+
"confusion_matrix",
|
|
132
|
+
default=torch.zeros(num_classes, num_classes, dtype=torch.long),
|
|
133
|
+
dist_reduce_fx="sum",
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
|
|
137
|
+
"""Update metric.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
preds: predictions of shape (N, C) - probabilities
|
|
141
|
+
labels: ground truth of shape (N,) - class indices
|
|
142
|
+
"""
|
|
143
|
+
if len(preds) == 0:
|
|
144
|
+
return
|
|
145
|
+
|
|
146
|
+
pred_classes = preds.argmax(dim=1) # (N,)
|
|
147
|
+
|
|
148
|
+
for true_label in range(self.num_classes):
|
|
149
|
+
for pred_label in range(self.num_classes):
|
|
150
|
+
count = ((labels == true_label) & (pred_classes == pred_label)).sum()
|
|
151
|
+
self.confusion_matrix[true_label, pred_label] += count
|
|
152
|
+
|
|
153
|
+
def compute(self) -> ConfusionMatrixOutput:
|
|
154
|
+
"""Returns the confusion matrix wrapped in ConfusionMatrixOutput."""
|
|
155
|
+
return ConfusionMatrixOutput(
|
|
156
|
+
confusion_matrix=self.confusion_matrix,
|
|
157
|
+
class_names=self.class_names,
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
def reset(self) -> None:
|
|
161
|
+
"""Reset metric."""
|
|
162
|
+
super().reset()
|
|
@@ -16,6 +16,7 @@ from torchmetrics.classification import (
|
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
from rslearn.models.component import FeatureVector, Predictor
|
|
19
|
+
from rslearn.train.metrics import ConfusionMatrixMetric
|
|
19
20
|
from rslearn.train.model_context import (
|
|
20
21
|
ModelContext,
|
|
21
22
|
ModelOutput,
|
|
@@ -44,6 +45,7 @@ class ClassificationTask(BasicTask):
|
|
|
44
45
|
f1_metric_kwargs: dict[str, Any] = {},
|
|
45
46
|
positive_class: str | None = None,
|
|
46
47
|
positive_class_threshold: float = 0.5,
|
|
48
|
+
enable_confusion_matrix: bool = False,
|
|
47
49
|
**kwargs: Any,
|
|
48
50
|
):
|
|
49
51
|
"""Initialize a new ClassificationTask.
|
|
@@ -69,6 +71,8 @@ class ClassificationTask(BasicTask):
|
|
|
69
71
|
positive_class: positive class name.
|
|
70
72
|
positive_class_threshold: threshold for classifying the positive class in
|
|
71
73
|
binary classification (default 0.5).
|
|
74
|
+
enable_confusion_matrix: whether to compute confusion matrix (default false).
|
|
75
|
+
If true, it requires wandb to be initialized for logging.
|
|
72
76
|
kwargs: other arguments to pass to BasicTask
|
|
73
77
|
"""
|
|
74
78
|
super().__init__(**kwargs)
|
|
@@ -84,6 +88,7 @@ class ClassificationTask(BasicTask):
|
|
|
84
88
|
self.f1_metric_kwargs = f1_metric_kwargs
|
|
85
89
|
self.positive_class = positive_class
|
|
86
90
|
self.positive_class_threshold = positive_class_threshold
|
|
91
|
+
self.enable_confusion_matrix = enable_confusion_matrix
|
|
87
92
|
|
|
88
93
|
if self.positive_class_threshold != 0.5:
|
|
89
94
|
# Must be binary classification
|
|
@@ -278,6 +283,14 @@ class ClassificationTask(BasicTask):
|
|
|
278
283
|
)
|
|
279
284
|
metrics["f1"] = ClassificationMetric(MulticlassF1Score(**kwargs))
|
|
280
285
|
|
|
286
|
+
if self.enable_confusion_matrix:
|
|
287
|
+
metrics["confusion_matrix"] = ClassificationMetric(
|
|
288
|
+
ConfusionMatrixMetric(
|
|
289
|
+
num_classes=len(self.classes),
|
|
290
|
+
class_names=self.classes,
|
|
291
|
+
),
|
|
292
|
+
)
|
|
293
|
+
|
|
281
294
|
return MetricCollection(metrics)
|
|
282
295
|
|
|
283
296
|
|
|
@@ -149,22 +149,28 @@ class PerPixelRegressionHead(Predictor):
|
|
|
149
149
|
"""Head for per-pixel regression task."""
|
|
150
150
|
|
|
151
151
|
def __init__(
|
|
152
|
-
self,
|
|
152
|
+
self,
|
|
153
|
+
loss_mode: Literal["mse", "l1", "huber"] = "mse",
|
|
154
|
+
use_sigmoid: bool = False,
|
|
155
|
+
huber_delta: float = 1.0,
|
|
153
156
|
):
|
|
154
|
-
"""Initialize a new
|
|
157
|
+
"""Initialize a new PerPixelRegressionHead.
|
|
155
158
|
|
|
156
159
|
Args:
|
|
157
|
-
loss_mode: the loss function to use
|
|
160
|
+
loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
|
|
158
161
|
use_sigmoid: whether to apply a sigmoid activation on the output. This
|
|
159
162
|
requires targets to be between 0-1.
|
|
163
|
+
huber_delta: delta parameter for Huber loss (only used when
|
|
164
|
+
loss_mode="huber").
|
|
160
165
|
"""
|
|
161
166
|
super().__init__()
|
|
162
167
|
|
|
163
|
-
if loss_mode not in ["mse", "l1"]:
|
|
164
|
-
raise ValueError("invalid loss mode")
|
|
168
|
+
if loss_mode not in ["mse", "l1", "huber"]:
|
|
169
|
+
raise ValueError(f"invalid loss mode {loss_mode}")
|
|
165
170
|
|
|
166
171
|
self.loss_mode = loss_mode
|
|
167
172
|
self.use_sigmoid = use_sigmoid
|
|
173
|
+
self.huber_delta = huber_delta
|
|
168
174
|
|
|
169
175
|
def forward(
|
|
170
176
|
self,
|
|
@@ -217,8 +223,15 @@ class PerPixelRegressionHead(Predictor):
|
|
|
217
223
|
scores = torch.square(outputs - labels)
|
|
218
224
|
elif self.loss_mode == "l1":
|
|
219
225
|
scores = torch.abs(outputs - labels)
|
|
226
|
+
elif self.loss_mode == "huber":
|
|
227
|
+
scores = torch.nn.functional.huber_loss(
|
|
228
|
+
outputs,
|
|
229
|
+
labels,
|
|
230
|
+
reduction="none",
|
|
231
|
+
delta=self.huber_delta,
|
|
232
|
+
)
|
|
220
233
|
else:
|
|
221
|
-
|
|
234
|
+
raise ValueError(f"unknown loss mode {self.loss_mode}")
|
|
222
235
|
|
|
223
236
|
# Compute average but only over valid pixels.
|
|
224
237
|
mask_total = mask.sum()
|
|
@@ -196,18 +196,24 @@ class RegressionHead(Predictor):
|
|
|
196
196
|
"""Head for regression task."""
|
|
197
197
|
|
|
198
198
|
def __init__(
|
|
199
|
-
self,
|
|
199
|
+
self,
|
|
200
|
+
loss_mode: Literal["mse", "l1", "huber"] = "mse",
|
|
201
|
+
use_sigmoid: bool = False,
|
|
202
|
+
huber_delta: float = 1.0,
|
|
200
203
|
):
|
|
201
204
|
"""Initialize a new RegressionHead.
|
|
202
205
|
|
|
203
206
|
Args:
|
|
204
|
-
loss_mode: the loss function to use
|
|
207
|
+
loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
|
|
205
208
|
use_sigmoid: whether to apply a sigmoid activation on the output. This
|
|
206
209
|
requires targets to be between 0-1.
|
|
210
|
+
huber_delta: delta parameter for Huber loss (only used when
|
|
211
|
+
loss_mode="huber").
|
|
207
212
|
"""
|
|
208
213
|
super().__init__()
|
|
209
214
|
self.loss_mode = loss_mode
|
|
210
215
|
self.use_sigmoid = use_sigmoid
|
|
216
|
+
self.huber_delta = huber_delta
|
|
211
217
|
|
|
212
218
|
def forward(
|
|
213
219
|
self,
|
|
@@ -251,6 +257,16 @@ class RegressionHead(Predictor):
|
|
|
251
257
|
losses["regress"] = torch.mean(torch.square(outputs - labels) * mask)
|
|
252
258
|
elif self.loss_mode == "l1":
|
|
253
259
|
losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
|
|
260
|
+
elif self.loss_mode == "huber":
|
|
261
|
+
losses["regress"] = torch.mean(
|
|
262
|
+
torch.nn.functional.huber_loss(
|
|
263
|
+
outputs,
|
|
264
|
+
labels,
|
|
265
|
+
reduction="none",
|
|
266
|
+
delta=self.huber_delta,
|
|
267
|
+
)
|
|
268
|
+
* mask
|
|
269
|
+
)
|
|
254
270
|
else:
|
|
255
271
|
raise ValueError(f"unknown loss mode {self.loss_mode}")
|
|
256
272
|
|
|
@@ -10,6 +10,7 @@ import torchmetrics.classification
|
|
|
10
10
|
from torchmetrics import Metric, MetricCollection
|
|
11
11
|
|
|
12
12
|
from rslearn.models.component import FeatureMaps, Predictor
|
|
13
|
+
from rslearn.train.metrics import ConfusionMatrixMetric
|
|
13
14
|
from rslearn.train.model_context import (
|
|
14
15
|
ModelContext,
|
|
15
16
|
ModelOutput,
|
|
@@ -43,6 +44,8 @@ class SegmentationTask(BasicTask):
|
|
|
43
44
|
other_metrics: dict[str, Metric] = {},
|
|
44
45
|
output_probs: bool = False,
|
|
45
46
|
output_class_idx: int | None = None,
|
|
47
|
+
enable_confusion_matrix: bool = False,
|
|
48
|
+
class_names: list[str] | None = None,
|
|
46
49
|
**kwargs: Any,
|
|
47
50
|
) -> None:
|
|
48
51
|
"""Initialize a new SegmentationTask.
|
|
@@ -80,6 +83,10 @@ class SegmentationTask(BasicTask):
|
|
|
80
83
|
during prediction.
|
|
81
84
|
output_class_idx: if set along with output_probs, only output the probability
|
|
82
85
|
for this specific class index (single-channel output).
|
|
86
|
+
enable_confusion_matrix: whether to compute confusion matrix (default false).
|
|
87
|
+
If true, it requires wandb to be initialized for logging.
|
|
88
|
+
class_names: optional list of class names for labeling confusion matrix axes.
|
|
89
|
+
If not provided, classes will be labeled as "class_0", "class_1", etc.
|
|
83
90
|
kwargs: additional arguments to pass to BasicTask
|
|
84
91
|
"""
|
|
85
92
|
super().__init__(**kwargs)
|
|
@@ -106,6 +113,8 @@ class SegmentationTask(BasicTask):
|
|
|
106
113
|
self.other_metrics = other_metrics
|
|
107
114
|
self.output_probs = output_probs
|
|
108
115
|
self.output_class_idx = output_class_idx
|
|
116
|
+
self.enable_confusion_matrix = enable_confusion_matrix
|
|
117
|
+
self.class_names = class_names
|
|
109
118
|
|
|
110
119
|
def process_inputs(
|
|
111
120
|
self,
|
|
@@ -285,6 +294,14 @@ class SegmentationTask(BasicTask):
|
|
|
285
294
|
if self.other_metrics:
|
|
286
295
|
metrics.update(self.other_metrics)
|
|
287
296
|
|
|
297
|
+
if self.enable_confusion_matrix:
|
|
298
|
+
metrics["confusion_matrix"] = SegmentationMetric(
|
|
299
|
+
ConfusionMatrixMetric(
|
|
300
|
+
num_classes=self.num_classes,
|
|
301
|
+
class_names=self.class_names,
|
|
302
|
+
),
|
|
303
|
+
)
|
|
304
|
+
|
|
288
305
|
return MetricCollection(metrics)
|
|
289
306
|
|
|
290
307
|
|
rslearn/utils/fsspec.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
import tempfile
|
|
5
|
-
from collections.abc import Generator
|
|
5
|
+
from collections.abc import Generator, Iterator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
@@ -16,6 +16,56 @@ from rslearn.log_utils import get_logger
|
|
|
16
16
|
logger = get_logger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
def iter_nonhidden(path: UPath) -> Iterator[UPath]:
|
|
20
|
+
"""Iterate over non-hidden entries in a directory.
|
|
21
|
+
|
|
22
|
+
Hidden entries are those whose basename starts with "." (e.g. ".DS_Store").
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
path: the directory to iterate.
|
|
26
|
+
|
|
27
|
+
Yields:
|
|
28
|
+
non-hidden UPath entries in the directory.
|
|
29
|
+
"""
|
|
30
|
+
try:
|
|
31
|
+
it = path.iterdir()
|
|
32
|
+
except (FileNotFoundError, NotADirectoryError):
|
|
33
|
+
return
|
|
34
|
+
|
|
35
|
+
for p in it:
|
|
36
|
+
if p.name.startswith("."):
|
|
37
|
+
continue
|
|
38
|
+
yield p
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def iter_nonhidden_subdirs(path: UPath) -> Iterator[UPath]:
|
|
42
|
+
"""Iterate over non-hidden subdirectories in a directory.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
path: the directory to iterate.
|
|
46
|
+
|
|
47
|
+
Yields:
|
|
48
|
+
non-hidden subdirectories in the directory.
|
|
49
|
+
"""
|
|
50
|
+
for p in iter_nonhidden(path):
|
|
51
|
+
if p.is_dir():
|
|
52
|
+
yield p
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def iter_nonhidden_files(path: UPath) -> Iterator[UPath]:
|
|
56
|
+
"""Iterate over non-hidden files in a directory.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
path: the directory to iterate.
|
|
60
|
+
|
|
61
|
+
Yields:
|
|
62
|
+
non-hidden files in the directory.
|
|
63
|
+
"""
|
|
64
|
+
for p in iter_nonhidden(path):
|
|
65
|
+
if p.is_file():
|
|
66
|
+
yield p
|
|
67
|
+
|
|
68
|
+
|
|
19
69
|
@contextmanager
|
|
20
70
|
def get_upath_local(
|
|
21
71
|
path: UPath, extra_paths: list[UPath] = []
|
|
@@ -47,7 +47,7 @@ rslearn/dataset/materialize.py,sha256=VoL5Qf5pGcQV4QMlO5vrcu7w0Sl1NdIRLUVk0kSCMO
|
|
|
47
47
|
rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
|
|
48
48
|
rslearn/dataset/window.py,sha256=X4q8YzcSOTtwKxCPf71QLMoyKUtYMSnZu0kPnmVSUx4,10644
|
|
49
49
|
rslearn/dataset/storage/__init__.py,sha256=R50AVV5LH2g7ol0-jyvGcB390VsclXGbJXz4fmkn9as,52
|
|
50
|
-
rslearn/dataset/storage/file.py,sha256=
|
|
50
|
+
rslearn/dataset/storage/file.py,sha256=GJJgH_eHknLbMQwoC3mOoXJKl6Ha3oNXzz62FIEMWlg,7130
|
|
51
51
|
rslearn/dataset/storage/storage.py,sha256=DxZ7iwV938PiLwdQzb5EXSb4Mj8bRGmOTmA9fzq_Ge8,4840
|
|
52
52
|
rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
|
|
53
53
|
rslearn/models/anysat.py,sha256=nzk6hB83ltNFNXYRNA1rTvq2AQcAhwyvgBaZui1M37o,8107
|
|
@@ -112,14 +112,15 @@ rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRy
|
|
|
112
112
|
rslearn/models/presto/presto.py,sha256=fkyHB85Hfx5L-4yejSFAFv83gk9VFqAR1GTgggtq0EA,11049
|
|
113
113
|
rslearn/models/presto/single_file_presto.py,sha256=-P00xjhj9dx3O6HqWpQmG9dPk_i6bT_t8vhX4uQm5tA,30242
|
|
114
114
|
rslearn/tile_stores/__init__.py,sha256=-cW1J7So60SEP5ZLHCPdaFBV5CxvV3QlOhaFnUkhTJ0,1675
|
|
115
|
-
rslearn/tile_stores/default.py,sha256=
|
|
115
|
+
rslearn/tile_stores/default.py,sha256=AG2j0FCNi_4cnXqLjRIef5wMqMJ5_YtSkTIhk7qJQVQ,15134
|
|
116
116
|
rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
|
|
117
117
|
rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
|
|
118
118
|
rslearn/train/all_crops_dataset.py,sha256=CWnqbSjRXJZQsudljvpA07oldiP4fZTmjwrT0sjVnq4,21399
|
|
119
|
-
rslearn/train/data_module.py,sha256=
|
|
120
|
-
rslearn/train/dataset.py,sha256=
|
|
119
|
+
rslearn/train/data_module.py,sha256=G1TRhXg8SPewYy0BTZN5KpeLPK72qIaH15ePfUwrxgM,23865
|
|
120
|
+
rslearn/train/dataset.py,sha256=vCmm6yrW2bAc5A94aBwQe-SOIGdVcZYMM2oBYRq2_sw,45253
|
|
121
121
|
rslearn/train/dataset_index.py,sha256=S5iXhQga5gnnkDqThXXlyjIwkJBPVWiUfDPx3iVs-pw,5306
|
|
122
|
-
rslearn/train/lightning_module.py,sha256=
|
|
122
|
+
rslearn/train/lightning_module.py,sha256=n4hasJBVlAmMhvf2yaFo0gy1vGz5haQkJpZdCSKlJ8A,17482
|
|
123
|
+
rslearn/train/metrics.py,sha256=RknMf2n09D5XBCf0YM4Zmm0XI-pFbRtsmbY51ipVMPk,4799
|
|
123
124
|
rslearn/train/model_context.py,sha256=8DMWGj5xCRmRDo_38lkhkUMHfK_yg3XZrUJQIz5a1vA,3200
|
|
124
125
|
rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
|
|
125
126
|
rslearn/train/prediction_writer.py,sha256=cRFehEtr0iBuVqzE69a0B4Lvb8ywxLeyon34KWI86H0,16961
|
|
@@ -130,13 +131,13 @@ rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjT
|
|
|
130
131
|
rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
|
|
131
132
|
rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
|
|
132
133
|
rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
|
|
133
|
-
rslearn/train/tasks/classification.py,sha256=
|
|
134
|
+
rslearn/train/tasks/classification.py,sha256=_3cRa8ojd9sG2ELRW_BvByZh2YFdCBaklR8Kv9LAgOY,14864
|
|
134
135
|
rslearn/train/tasks/detection.py,sha256=uDMGtsCMSk9OGXn-vpFKBAyHyVN0ji2NCfqBgg1BQyw,21725
|
|
135
136
|
rslearn/train/tasks/embedding.py,sha256=NdJEAaDWlWYzvOBVf7eIHfFOzqTgavfFH1J1gMbAMVo,3891
|
|
136
137
|
rslearn/train/tasks/multi_task.py,sha256=32hvwyVsHqt7N_M3zXsTErK1K7-0-BPHzt7iGNehyaI,6314
|
|
137
|
-
rslearn/train/tasks/per_pixel_regression.py,sha256=
|
|
138
|
-
rslearn/train/tasks/regression.py,sha256=
|
|
139
|
-
rslearn/train/tasks/segmentation.py,sha256=
|
|
138
|
+
rslearn/train/tasks/per_pixel_regression.py,sha256=3m_BTP2akadYe3IuAlCG2bd_alfNyom55-pFrI2q4PE,10928
|
|
139
|
+
rslearn/train/tasks/regression.py,sha256=TzHL42gm3aIdev0R7_uz_TSYbAwSvQPjCD42y1p9_7Y,13269
|
|
140
|
+
rslearn/train/tasks/segmentation.py,sha256=b9XS09EQvum89eoW3vWqFMKuCRtznODteKIr1hFnIz4,30531
|
|
140
141
|
rslearn/train/tasks/task.py,sha256=nMPunl9OlnOimr48saeTnwKMQ7Du4syGrwNKVQq4FL4,4110
|
|
141
142
|
rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
|
|
142
143
|
rslearn/train/transforms/concatenate.py,sha256=S8f1svzwb5UmeAgzXe4Af_hFvt5o0tQctIE6t3QYuPI,2625
|
|
@@ -153,7 +154,7 @@ rslearn/utils/__init__.py,sha256=GZc1erpEfXTc32yjEDbt5rnMrnXEBY7WVm3v4NlwwWY,620
|
|
|
153
154
|
rslearn/utils/array.py,sha256=RC7ygtPnQwU6Lb9kwORvNxatJcaJ76JPsykQvndAfes,2444
|
|
154
155
|
rslearn/utils/colors.py,sha256=ELY9_buH06TOVPLrDAyf2S0G--ZiOxnnP8Ujim6_3ig,369
|
|
155
156
|
rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
|
|
156
|
-
rslearn/utils/fsspec.py,sha256=
|
|
157
|
+
rslearn/utils/fsspec.py,sha256=TcEUgXKvsmtKHv5JVOI2Vp4WNfVNeTok0x4JgZaD1iw,7052
|
|
157
158
|
rslearn/utils/geometry.py,sha256=VzLoxtwdV3uC3szowT-bGuCFF6ge8eK0m01lq8q-01Q,22423
|
|
158
159
|
rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF9Ps,3599
|
|
159
160
|
rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
|
|
@@ -175,10 +176,10 @@ rslearn/vis/render_sensor_image.py,sha256=D0ynK6ABPV046970lIKwF98klpSCtrsUvZTwtZ
|
|
|
175
176
|
rslearn/vis/render_vector_label.py,sha256=ncwgRKCYCJCK1-wTpjgksOiDDebku37LpAyq6wsg4jg,14939
|
|
176
177
|
rslearn/vis/utils.py,sha256=Zop3dEmyaXUYhPiGdYzrTO8BRXWscP2dEZy2myQUnNk,2765
|
|
177
178
|
rslearn/vis/vis_server.py,sha256=kIGnhTy-yfu5lBOVCoo8VVG259i974JPszudCePbzfI,20157
|
|
178
|
-
rslearn-0.0.
|
|
179
|
-
rslearn-0.0.
|
|
180
|
-
rslearn-0.0.
|
|
181
|
-
rslearn-0.0.
|
|
182
|
-
rslearn-0.0.
|
|
183
|
-
rslearn-0.0.
|
|
184
|
-
rslearn-0.0.
|
|
179
|
+
rslearn-0.0.28.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
|
|
180
|
+
rslearn-0.0.28.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
|
|
181
|
+
rslearn-0.0.28.dist-info/METADATA,sha256=OSGXg3yVyndUAZYL9EwvCl095O_Vfknpxlhx8O5dLmQ,38714
|
|
182
|
+
rslearn-0.0.28.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
183
|
+
rslearn-0.0.28.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
|
|
184
|
+
rslearn-0.0.28.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
|
|
185
|
+
rslearn-0.0.28.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|