rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
@@ -8,7 +8,9 @@ import random
 import tempfile
 import time
 import uuid
+import warnings
 from datetime import datetime
+from enum import StrEnum
 from typing import Any
 
 import torch
@@ -29,6 +31,7 @@ from rslearn.dataset.window import (
     get_layer_and_group_from_dir_name,
 )
 from rslearn.log_utils import get_logger
+from rslearn.train.dataset_index import DatasetIndex
 from rslearn.train.model_context import RasterImage
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds, ResolutionFactor
@@ -41,6 +44,19 @@ from .transforms import Sequential
 logger = get_logger(__name__)
 
 
+class IndexMode(StrEnum):
+    """Controls dataset index caching behavior."""
+
+    OFF = "off"
+    """No caching - always load windows from dataset."""
+
+    USE = "use"
+    """Use cached index if available, create if not."""
+
+    REFRESH = "refresh"
+    """Ignore existing cache and rebuild."""
+
+
 def get_torch_dtype(dtype: DType) -> torch.dtype:
     """Convert rslearn DType to torch dtype."""
     if dtype == DType.INT32:
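Because IndexMode subclasses StrEnum (available since Python 3.11), each member compares equal to its string value, so "off"/"use"/"refresh" taken from a YAML or CLI config can be used directly. A minimal standalone sketch of that behavior (the enum is re-declared here so the snippet runs on its own):

```python
from enum import StrEnum

class IndexMode(StrEnum):
    OFF = "off"
    USE = "use"
    REFRESH = "refresh"

# Members compare equal to plain strings, so config values map directly.
mode = IndexMode("refresh")
assert mode == "refresh"
assert mode is IndexMode.REFRESH
```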
@@ -441,11 +457,15 @@ class SplitConfig:
         num_patches: int | None = None,
         transforms: list[torch.nn.Module] | None = None,
         sampler: SamplerFactory | None = None,
+        crop_size: int | tuple[int, int] | None = None,
+        overlap_pixels: int | None = None,
+        load_all_crops: bool | None = None,
+        skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
+        # Deprecated parameters (for backwards compatibility)
         patch_size: int | tuple[int, int] | None = None,
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
-        skip_targets: bool | None = None,
-        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.
 
@@ -460,19 +480,69 @@
             num_patches: limit this split to this many patches
             transforms: transforms to apply
             sampler: SamplerFactory for this split
-            patch_size: an optional square size or (width, height) tuple. If set, read
+            crop_size: an optional square size or (width, height) tuple. If set, read
                 crops of this size rather than entire windows.
-            overlap_ratio: the ratio of overlap between adjacent patches during
-                sliding window inference.
-            load_all_patches: with patch_size set, rather than sampling a random patch
-                for each window, read all patches as separate sequential items in the
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference.
+            load_all_crops: with crop_size set, rather than sampling a random crop
+                for each window, read all crops as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
             output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
                 If set, windows that already
                 have this layer completed will be skipped (useful for resuming
                 partial inference runs).
+            patch_size: deprecated, use crop_size instead
+            overlap_ratio: deprecated, use overlap_pixels instead
+            load_all_patches: deprecated, use load_all_crops instead
         """
+        # Handle deprecated load_all_patches parameter
+        if load_all_patches is not None:
+            warnings.warn(
+                "load_all_patches is deprecated, use load_all_crops instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if load_all_crops is not None:
+                raise ValueError(
+                    "Cannot specify both load_all_patches and load_all_crops"
+                )
+            load_all_crops = load_all_patches
+        # Handle deprecated patch_size parameter
+        if patch_size is not None:
+            warnings.warn(
+                "patch_size is deprecated, use crop_size instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if crop_size is not None:
+                raise ValueError("Cannot specify both patch_size and crop_size")
+            crop_size = patch_size
+
+        # Normalize crop_size to tuple[int, int] | None
+        self.crop_size: tuple[int, int] | None = None
+        if crop_size is not None:
+            if isinstance(crop_size, int):
+                self.crop_size = (crop_size, crop_size)
+            else:
+                self.crop_size = crop_size
+
+        # Handle deprecated overlap_ratio parameter
+        if overlap_ratio is not None:
+            warnings.warn(
+                "overlap_ratio is deprecated, use overlap_pixels instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
+            if self.crop_size is None:
+                raise ValueError("overlap_ratio requires crop_size to be set")
+            overlap_pixels = round(self.crop_size[0] * overlap_ratio)
+
+        if overlap_pixels is not None and overlap_pixels < 0:
+            raise ValueError("overlap_pixels must be non-negative")
+
         self.groups = groups
         self.names = names
         self.tags = tags
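The deprecation shims above are mechanical: each old keyword maps onto its replacement, a FutureWarning is emitted on use, a ValueError is raised when both spellings are given, and overlap_ratio is converted to pixels against the crop width. A hedged illustration, assuming SplitConfig's remaining parameters keep their defaults as in the signature above:

```python
import warnings

from rslearn.train.dataset import SplitConfig

# New-style arguments: an int crop_size is normalized to a (w, h) tuple.
cfg = SplitConfig(crop_size=256, overlap_pixels=64)
assert cfg.crop_size == (256, 256)

# Old-style arguments still work but warn, and overlap_ratio is converted
# via round(crop_size[0] * overlap_ratio).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = SplitConfig(patch_size=256, overlap_ratio=0.25)
assert old.crop_size == (256, 256)
assert old.overlap_pixels == 64  # round(256 * 0.25)
assert any(w.category is FutureWarning for w in caught)

# Mixing old and new spellings of the same option is rejected.
try:
    SplitConfig(patch_size=128, crop_size=128)
except ValueError as err:
    print(err)  # Cannot specify both patch_size and crop_size
```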
@@ -480,19 +550,15 @@
         self.num_patches = num_patches
         self.transforms = transforms
         self.sampler = sampler
-        self.patch_size = patch_size
         self.skip_targets = skip_targets
         self.output_layer_name_skip_inference_if_exists = (
             output_layer_name_skip_inference_if_exists
         )
 
-        # Note that load_all_patches is handled by the RslearnDataModule rather than the
-        # ModelDataset.
-        self.load_all_patches = load_all_patches
-        self.overlap_ratio = overlap_ratio
-
-        if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
-            raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
+        # Note that load_all_crops is handled by the RslearnDataModule rather than the
+        # ModelDataset.
+        self.load_all_crops = load_all_crops
+        self.overlap_pixels = overlap_pixels
 
     def update(self, other: "SplitConfig") -> "SplitConfig":
         """Override settings in this SplitConfig with those in another.
@@ -508,9 +574,9 @@
             num_patches=self.num_patches,
             transforms=self.transforms,
             sampler=self.sampler,
-            patch_size=self.patch_size,
-            overlap_ratio=self.overlap_ratio,
-            load_all_patches=self.load_all_patches,
+            crop_size=self.crop_size,
+            overlap_pixels=self.overlap_pixels,
+            load_all_crops=self.load_all_crops,
             skip_targets=self.skip_targets,
             output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
@@ -528,12 +594,12 @@
             result.transforms = other.transforms
         if other.sampler:
             result.sampler = other.sampler
-        if other.patch_size:
-            result.patch_size = other.patch_size
-        if other.overlap_ratio is not None:
-            result.overlap_ratio = other.overlap_ratio
-        if other.load_all_patches is not None:
-            result.load_all_patches = other.load_all_patches
+        if other.crop_size:
+            result.crop_size = other.crop_size
+        if other.overlap_pixels is not None:
+            result.overlap_pixels = other.overlap_pixels
+        if other.load_all_crops is not None:
+            result.load_all_crops = other.load_all_crops
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
         if other.output_layer_name_skip_inference_if_exists is not None:
@@ -542,21 +608,17 @@
         )
         return result
 
-    def get_patch_size(self) -> tuple[int, int] | None:
-        """Get patch size as tuple."""
-        if self.patch_size is None:
-            return None
-        if isinstance(self.patch_size, int):
-            return (self.patch_size, self.patch_size)
-        return self.patch_size
+    def get_crop_size(self) -> tuple[int, int] | None:
+        """Get crop size as tuple."""
+        return self.crop_size
 
-    def get_overlap_ratio(self) -> float:
-        """Get the overlap ratio."""
-        return self.overlap_ratio
+    def get_overlap_pixels(self) -> int:
+        """Get the overlap pixels (default 0)."""
+        return self.overlap_pixels if self.overlap_pixels is not None else 0
 
-    def get_load_all_patches(self) -> bool:
+    def get_load_all_crops(self) -> bool:
         """Returns whether loading all patches is enabled (default False)."""
-        return True if self.load_all_patches is True else False
+        return True if self.load_all_crops is True else False
 
     def get_skip_targets(self) -> bool:
         """Returns whether skip_targets is enabled (default False)."""
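update() gives field-level precedence to the other config: anything set on the override wins, anything left unset falls through to the base. A small hedged sketch (import path assumed, following the copy-and-override logic above):

```python
from rslearn.train.dataset import SplitConfig

base = SplitConfig(crop_size=128, overlap_pixels=16)
override = SplitConfig(crop_size=256)  # overlap_pixels stays None

merged = base.update(override)
assert merged.crop_size == (256, 256)  # set on override, so it wins
assert merged.overlap_pixels == 16     # unset on override, so base value kept
assert merged.get_overlap_pixels() == 16
```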
@@ -636,6 +698,7 @@ class ModelDataset(torch.utils.data.Dataset):
         workers: int,
         name: str | None = None,
         fix_patch_pick: bool = False,
+        index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Instantiate a new ModelDataset.
 
@@ -645,9 +708,10 @@
             inputs: data to read from the dataset for training
             task: the task to train on
             workers: number of workers to use for initializing the dataset
-            name: name of the dataset
+            name: name of the dataset
             fix_patch_pick: if True, fix the patch pick to be the same every time
                 for a given window. Useful for testing (default: False)
+            index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
         self.dataset = dataset
         self.split_config = split_config
@@ -660,15 +724,13 @@
         else:
             self.transforms = rslearn.train.transforms.transform.Identity()
 
-        # Get normalized patch size from the SplitConfig.
-        # But if load all patches is enabled, this is handled by AllPatchesDataset, so
+        # Get normalized crop size from the SplitConfig.
+        # But if load all patches is enabled, this is handled by AllCropsDataset, so
         # here we instead load the entire windows.
-        if split_config.get_load_all_patches():
-            self.patch_size = None
+        if split_config.get_load_all_crops():
+            self.crop_size = None
         else:
-            self.patch_size = split_config.get_patch_size()
-
-        windows = self._get_initial_windows(split_config, workers)
+            self.crop_size = split_config.get_crop_size()
 
         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -676,58 +738,8 @@
                 if self.inputs[k].is_target:
                     del self.inputs[k]
 
-        #
-
-        new_windows = []
-        if workers == 0:
-            for window in windows:
-                if (
-                    check_window(
-                        self.inputs,
-                        window,
-                        output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
-                    )
-                    is None
-                ):
-                    continue
-                new_windows.append(window)
-        else:
-            p = multiprocessing.Pool(workers)
-            outputs = star_imap_unordered(
-                p,
-                check_window,
-                [
-                    dict(
-                        inputs=self.inputs,
-                        window=window,
-                        output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
-                    )
-                    for window in windows
-                ],
-            )
-            for window in tqdm.tqdm(
-                outputs, total=len(windows), desc="Checking available layers in windows"
-            ):
-                if window is None:
-                    continue
-                new_windows.append(window)
-            p.close()
-        windows = new_windows
-
-        # Sort the windows to ensure that the dataset is consistent across GPUs.
-        # Inconsistent ordering can lead to a subset of windows being processed during
-        # "model test" / "model predict" when using multiple GPUs.
-        # We use a hash so that functionality like num_samples limit gets a random
-        # subset of windows (with respect to the hash function choice).
-        windows.sort(
-            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
-        )
-
-        # Limit windows to num_samples if requested.
-        if split_config.num_samples:
-            # The windows are sorted by hash of window name so this distribution should
-            # be representative of the population.
-            windows = windows[0 : split_config.num_samples]
+        # Load windows (from index if available, otherwise from dataset)
+        windows = self._load_windows(split_config, workers, index_mode)
 
         # Write dataset_examples to a file so that we can load it lazily in the worker
         # processes. Otherwise it takes a long time to transmit it when spawning each
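The constructor now delegates window discovery to _load_windows (shown in the next hunk), gated by index_mode. A hypothetical wiring sketch; dataset, inputs, and task stand in for objects built elsewhere in an rslearn pipeline, so this is illustrative rather than a verbatim recipe:

```python
from rslearn.train.dataset import IndexMode, ModelDataset, SplitConfig

def build_eval_dataset(dataset, inputs, task) -> ModelDataset:
    # USE reads the cached window list when it is valid and writes one when it
    # is missing; REFRESH forces a rebuild; OFF keeps the pre-0.0.27 behavior.
    return ModelDataset(
        dataset=dataset,
        split_config=SplitConfig(crop_size=256),
        inputs=inputs,
        task=task,
        workers=4,
        index_mode=IndexMode.USE,
    )
```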
@@ -796,6 +808,137 @@ class ModelDataset(torch.utils.data.Dataset):
 
         return windows
 
+    def _load_windows(
+        self,
+        split_config: SplitConfig,
+        workers: int,
+        index_mode: IndexMode,
+    ) -> list[Window]:
+        """Load windows, using index if available.
+
+        This method handles:
+        1. Loading from index if index_mode is USE and index exists
+        2. Otherwise, loading from dataset, filtering, sorting, limiting
+        3. Saving to index if index_mode is USE or REFRESH
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+            index_mode: controls caching behavior.
+
+        Returns:
+            list of processed windows ready for training.
+        """
+        # Try to load from index
+        index: DatasetIndex | None = None
+
+        if index_mode != IndexMode.OFF:
+            logger.info(f"Checking index for dataset {self.dataset.path}")
+            index = DatasetIndex(
+                storage=self.dataset.storage,
+                dataset_path=self.dataset.path,
+                groups=split_config.groups,
+                names=split_config.names,
+                tags=split_config.tags,
+                num_samples=split_config.num_samples,
+                skip_targets=split_config.get_skip_targets(),
+                inputs=self.inputs,
+            )
+            refresh = index_mode == IndexMode.REFRESH
+            indexed_windows = index.load_windows(refresh)
+
+            if indexed_windows is not None:
+                logger.info(f"Loaded {len(indexed_windows)} windows from index")
+                return indexed_windows
+
+        # No index available, load and process windows from dataset
+        logger.debug("Loading windows from dataset...")
+        windows = self._get_initial_windows(split_config, workers)
+        windows = self._filter_windows_by_layers(windows, workers)
+        windows = self._sort_and_limit_windows(windows, split_config)
+
+        # Save to index if enabled
+        if index is not None:
+            index.save_windows(windows)
+
+        return windows
+
+    def _filter_windows_by_layers(
+        self, windows: list[Window], workers: int
+    ) -> list[Window]:
+        """Filter windows to only include those with required layers.
+
+        Args:
+            windows: list of windows to filter.
+            workers: number of worker processes for parallel filtering.
+
+        Returns:
+            list of windows that have all required input layers.
+        """
+        output_layer_skip = (
+            self.split_config.get_output_layer_name_skip_inference_if_exists()
+        )
+
+        if workers == 0:
+            return [
+                w
+                for w in windows
+                if check_window(
+                    self.inputs,
+                    w,
+                    output_layer_name_skip_inference_if_exists=output_layer_skip,
+                )
+                is not None
+            ]
+
+        p = multiprocessing.Pool(workers)
+        outputs = star_imap_unordered(
+            p,
+            check_window,
+            [
+                dict(
+                    inputs=self.inputs,
+                    window=window,
+                    output_layer_name_skip_inference_if_exists=output_layer_skip,
+                )
+                for window in windows
+            ],
+        )
+        filtered = []
+        for window in tqdm.tqdm(
+            outputs,
+            total=len(windows),
+            desc="Checking available layers in windows",
+        ):
+            if window is not None:
+                filtered.append(window)
+        p.close()
+        return filtered
+
+    def _sort_and_limit_windows(
+        self, windows: list[Window], split_config: SplitConfig
+    ) -> list[Window]:
+        """Sort windows by hash and apply num_samples limit.
+
+        Sorting ensures consistent ordering across GPUs. Using hash gives a
+        pseudo-random but deterministic order for sampling.
+
+        Args:
+            windows: list of windows to sort and limit.
+            split_config: the split configuration with num_samples.
+
+        Returns:
+            sorted and optionally limited list of windows.
+        """
+        windows.sort(
+            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
+        )
+
+        if split_config.num_samples:
+            windows = windows[: split_config.num_samples]
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
         return example.get_metadata()
 
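_sort_and_limit_windows orders windows by the sha256 of their names, so every rank computes the same order and a num_samples prefix is a deterministic, roughly uniform sample. A standalone illustration:

```python
import hashlib

names = ["window_a", "window_b", "window_c", "window_d"]
ordered = sorted(names, key=lambda n: hashlib.sha256(n.encode()).hexdigest())

# Every process computes the same order, so the num_samples prefix below is
# the same subset on every GPU.
num_samples = 2
print(ordered[:num_samples])
```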
@@ -808,8 +951,8 @@ class ModelDataset(torch.utils.data.Dataset):
     def get_dataset_examples(self) -> list[Window]:
         """Get a list of examples in the dataset.
 
-        If load_all_patches is False, this is a list of Windows. Otherwise, this is a
-        list of (window, patch_bounds, (patch_idx, # patches)) tuples.
+        If load_all_crops is False, this is a list of Windows. Otherwise, this is a
+        list of (window, crop_bounds, (crop_idx, # patches)) tuples.
         """
         if self.dataset_examples is None:
             logger.debug(
@@ -845,34 +988,34 @@ class ModelDataset(torch.utils.data.Dataset):
         rng = random.Random(idx if self.fix_patch_pick else None)
 
         # Select bounds to read.
-        if self.patch_size:
+        if self.crop_size:
             window = example
 
-            def get_patch_range(n_patch: int, n_window: int) -> list[int]:
-                if n_patch > n_window:
+            def get_crop_range(n_crop: int, n_window: int) -> list[int]:
+                if n_crop > n_window:
                     # Select arbitrary range containing the entire window.
-                    # Basically arbitrarily padding the window to get to patch size.
-                    start = rng.randint(n_window - n_patch, 0)
-                    return [start, start + n_patch]
+                    # Basically arbitrarily padding the window to get to crop size.
+                    start = rng.randint(n_window - n_crop, 0)
+                    return [start, start + n_crop]
 
                 else:
-                    # Select arbitrary patch within the window.
-                    start = rng.randint(0, n_window - n_patch)
-                    return [start, start + n_patch]
+                    # Select arbitrary crop within the window.
+                    start = rng.randint(0, n_window - n_crop)
+                    return [start, start + n_crop]
 
             window_size = (
                 window.bounds[2] - window.bounds[0],
                 window.bounds[3] - window.bounds[1],
             )
-            patch_ranges = [
-                get_patch_range(self.patch_size[0], window_size[0]),
-                get_patch_range(self.patch_size[1], window_size[1]),
+            crop_ranges = [
+                get_crop_range(self.crop_size[0], window_size[0]),
+                get_crop_range(self.crop_size[1], window_size[1]),
             ]
             bounds = (
-                window.bounds[0] + patch_ranges[0][0],
-                window.bounds[1] + patch_ranges[1][0],
-                window.bounds[0] + patch_ranges[0][1],
-                window.bounds[1] + patch_ranges[1][1],
+                window.bounds[0] + crop_ranges[0][0],
+                window.bounds[1] + crop_ranges[1][0],
+                window.bounds[0] + crop_ranges[0][1],
+                window.bounds[1] + crop_ranges[1][1],
            )
 
         else:
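get_crop_range covers both geometries per axis: a crop smaller than the window is placed at a random offset inside it, while a crop larger than the window gets a start in [n_window - n_crop, 0] so the crop covers the whole window plus arbitrary padding. A standalone re-statement of the helper with both cases exercised:

```python
import random

rng = random.Random(0)

def get_crop_range(n_crop: int, n_window: int) -> list[int]:
    if n_crop > n_window:
        # Window smaller than crop: start in [n_window - n_crop, 0],
        # e.g. window 50, crop 64 -> start in [-14, 0].
        start = rng.randint(n_window - n_crop, 0)
    else:
        # Crop fits: start in [0, n_window - n_crop],
        # e.g. window 100, crop 64 -> start in [0, 36].
        start = rng.randint(0, n_window - n_crop)
    return [start, start + n_crop]

x0, x1 = get_crop_range(64, 100)
assert 0 <= x0 and x1 <= 100 and x1 - x0 == 64  # crop inside window

y0, y1 = get_crop_range(64, 50)
assert y0 <= 0 <= 50 <= y1 and y1 - y0 == 64    # crop covers and pads window
```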
@@ -894,9 +1037,9 @@ class ModelDataset(torch.utils.data.Dataset):
             window_group=window.group,
             window_name=window.name,
             window_bounds=window.bounds,
-            patch_bounds=bounds,
-            patch_idx=0,
-            num_patches_in_window=1,
+            crop_bounds=bounds,
+            crop_idx=0,
+            num_crops_in_window=1,
             time_range=window.time_range,
             projection=window.projection,
             dataset_source=self.name,
|
|
|
1
|
+
"""Dataset index for caching window lists to speed up ModelDataset initialization."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.dataset.window import Window
|
|
11
|
+
from rslearn.log_utils import get_logger
|
|
12
|
+
from rslearn.utils.fsspec import open_atomic
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
# Increment this when the index format changes to force rebuild
|
|
20
|
+
INDEX_VERSION = 1
|
|
21
|
+
|
|
22
|
+
# Directory name for storing index files
|
|
23
|
+
INDEX_DIR_NAME = ".rslearn_dataset_index"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DatasetIndex:
|
|
27
|
+
"""Manages indexed window lists for faster ModelDataset initialization.
|
|
28
|
+
|
|
29
|
+
Note: The index does NOT automatically detect when windows are added or removed
|
|
30
|
+
from the dataset. Use refresh=True after modifying dataset windows.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
storage: "WindowStorage",
|
|
36
|
+
dataset_path: UPath,
|
|
37
|
+
groups: list[str] | None,
|
|
38
|
+
names: list[str] | None,
|
|
39
|
+
tags: dict[str, Any] | None,
|
|
40
|
+
num_samples: int | None,
|
|
41
|
+
skip_targets: bool,
|
|
42
|
+
inputs: dict[str, Any],
|
|
43
|
+
) -> None:
|
|
44
|
+
"""Initialize DatasetIndex with specific configuration.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
storage: WindowStorage for deserializing windows.
|
|
48
|
+
dataset_path: Path to the dataset directory.
|
|
49
|
+
groups: list of window groups to include.
|
|
50
|
+
names: list of window names to include.
|
|
51
|
+
tags: tags to filter windows by.
|
|
52
|
+
num_samples: limit on number of samples.
|
|
53
|
+
skip_targets: whether targets are skipped.
|
|
54
|
+
inputs: dict mapping input names to DataInput objects.
|
|
55
|
+
"""
|
|
56
|
+
self.storage = storage
|
|
57
|
+
self.dataset_path = dataset_path
|
|
58
|
+
self.index_dir = dataset_path / INDEX_DIR_NAME
|
|
59
|
+
|
|
60
|
+
# Compute index key from configuration
|
|
61
|
+
inputs_data = {}
|
|
62
|
+
for name, inp in inputs.items():
|
|
63
|
+
inputs_data[name] = {
|
|
64
|
+
"layers": inp.layers,
|
|
65
|
+
"required": inp.required,
|
|
66
|
+
"load_all_layers": inp.load_all_layers,
|
|
67
|
+
"is_target": inp.is_target,
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
key_data = {
|
|
71
|
+
"groups": groups,
|
|
72
|
+
"names": names,
|
|
73
|
+
"tags": tags,
|
|
74
|
+
"num_samples": num_samples,
|
|
75
|
+
"skip_targets": skip_targets,
|
|
76
|
+
"inputs": inputs_data,
|
|
77
|
+
}
|
|
78
|
+
self.index_key = hashlib.sha256(
|
|
79
|
+
json.dumps(key_data, sort_keys=True).encode()
|
|
80
|
+
).hexdigest()
|
|
81
|
+
|
|
82
|
+
def _get_config_hash(self) -> str:
|
|
83
|
+
"""Get hash of config.json for quick validation.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
A 16-character hex string hash of the config, or empty string if no config.
|
|
87
|
+
"""
|
|
88
|
+
config_path = self.dataset_path / "config.json"
|
|
89
|
+
if config_path.exists():
|
|
90
|
+
with config_path.open() as f:
|
|
91
|
+
return hashlib.sha256(f.read().encode()).hexdigest()[:16]
|
|
92
|
+
return ""
|
|
93
|
+
|
|
94
|
+
def load_windows(self, refresh: bool = False) -> list[Window] | None:
|
|
95
|
+
"""Load indexed window list if valid, else return None.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
refresh: If True, ignore existing index and return None.
|
|
99
|
+
|
|
100
|
+
Returns:
|
|
101
|
+
List of Window objects if index is valid, None otherwise.
|
|
102
|
+
"""
|
|
103
|
+
if refresh:
|
|
104
|
+
logger.info("refresh=True, rebuilding index")
|
|
105
|
+
return None
|
|
106
|
+
|
|
107
|
+
index_file = self.index_dir / f"{self.index_key}.json"
|
|
108
|
+
if not index_file.exists():
|
|
109
|
+
logger.info(f"No index found at {index_file}, will build")
|
|
110
|
+
return None
|
|
111
|
+
|
|
112
|
+
try:
|
|
113
|
+
with index_file.open() as f:
|
|
114
|
+
index_data = json.load(f)
|
|
115
|
+
except (OSError, json.JSONDecodeError):
|
|
116
|
+
logger.warning(f"Corrupted index file at {index_file}, will rebuild")
|
|
117
|
+
return None
|
|
118
|
+
|
|
119
|
+
# Check index version
|
|
120
|
+
if index_data.get("version") != INDEX_VERSION:
|
|
121
|
+
logger.info(
|
|
122
|
+
f"Index version mismatch (got {index_data.get('version')}, "
|
|
123
|
+
f"expected {INDEX_VERSION}), will rebuild"
|
|
124
|
+
)
|
|
125
|
+
return None
|
|
126
|
+
|
|
127
|
+
# Quick validation: check config hash
|
|
128
|
+
if index_data.get("config_hash") != self._get_config_hash():
|
|
129
|
+
logger.info("Config hash mismatch, index invalidated")
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
# Deserialize windows
|
|
133
|
+
return [Window.from_metadata(self.storage, w) for w in index_data["windows"]]
|
|
134
|
+
|
|
135
|
+
def save_windows(self, windows: list[Window]) -> None:
|
|
136
|
+
"""Save processed windows to index with atomic write.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
windows: List of Window objects to index.
|
|
140
|
+
"""
|
|
141
|
+
self.index_dir.mkdir(parents=True, exist_ok=True)
|
|
142
|
+
index_file = self.index_dir / f"{self.index_key}.json"
|
|
143
|
+
|
|
144
|
+
# Serialize windows
|
|
145
|
+
serialized_windows = [w.get_metadata() for w in windows]
|
|
146
|
+
|
|
147
|
+
index_data = {
|
|
148
|
+
"version": INDEX_VERSION,
|
|
149
|
+
"config_hash": self._get_config_hash(),
|
|
150
|
+
"created_at": datetime.now().isoformat(),
|
|
151
|
+
"num_windows": len(windows),
|
|
152
|
+
"windows": serialized_windows,
|
|
153
|
+
}
|
|
154
|
+
with open_atomic(index_file, "w") as f:
|
|
155
|
+
json.dump(index_data, f)
|
|
156
|
+
logger.info(f"Saved {len(windows)} windows to index at {index_file}")
|
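The cache key construction is worth noting: because the query configuration is serialized with sort_keys=True before hashing, equivalent configurations share one index file regardless of dict ordering, while any change to groups, names, tags, num_samples, skip_targets, or the inputs spec lands in a fresh file. A self-contained illustration of just that keying scheme (index_key here is a local stand-in, not an rslearn API):

```python
import hashlib
import json

def index_key(key_data: dict) -> str:
    # Same construction as DatasetIndex.__init__: canonical JSON, then sha256.
    return hashlib.sha256(json.dumps(key_data, sort_keys=True).encode()).hexdigest()

a = index_key({"groups": ["train"], "names": None, "num_samples": None})
b = index_key({"num_samples": None, "names": None, "groups": ["train"]})
c = index_key({"groups": ["val"], "names": None, "num_samples": None})

assert a == b  # sort_keys=True makes the key order-insensitive
assert a != c  # a different window selection maps to a different index file
```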