rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff compares the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py CHANGED
@@ -8,7 +8,9 @@ import random
 import tempfile
 import time
 import uuid
+import warnings
 from datetime import datetime
+from enum import StrEnum
 from typing import Any

 import torch
@@ -29,6 +31,7 @@ from rslearn.dataset.window import (
     get_layer_and_group_from_dir_name,
 )
 from rslearn.log_utils import get_logger
+from rslearn.train.dataset_index import DatasetIndex
 from rslearn.train.model_context import RasterImage
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds, ResolutionFactor
@@ -41,6 +44,19 @@ from .transforms import Sequential
 logger = get_logger(__name__)


+class IndexMode(StrEnum):
+    """Controls dataset index caching behavior."""
+
+    OFF = "off"
+    """No caching - always load windows from dataset."""
+
+    USE = "use"
+    """Use cached index if available, create if not."""
+
+    REFRESH = "refresh"
+    """Ignore existing cache and rebuild."""
+
+
 def get_torch_dtype(dtype: DType) -> torch.dtype:
     """Convert rslearn DType to torch dtype."""
     if dtype == DType.INT32:
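Since IndexMode subclasses StrEnum (Python 3.11+), each member is interchangeable with its string value, so YAML or CLI configs can pass plain "off"/"use"/"refresh" strings to the new index_mode parameter of ModelDataset (added further below). A minimal sketch of that round-trip, using nothing beyond the enum itself:

```python
from rslearn.train.dataset import IndexMode

# A config file can supply the raw string; it parses to the enum member...
mode = IndexMode("use")
assert mode is IndexMode.USE

# ...and, being a StrEnum, the member still behaves as a plain string.
assert mode == "use"
assert f"index_mode={mode}" == "index_mode=use"
```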
@@ -441,11 +457,15 @@ class SplitConfig:
         num_patches: int | None = None,
         transforms: list[torch.nn.Module] | None = None,
         sampler: SamplerFactory | None = None,
+        crop_size: int | tuple[int, int] | None = None,
+        overlap_pixels: int | None = None,
+        load_all_crops: bool | None = None,
+        skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
+        # Deprecated parameters (for backwards compatibility)
         patch_size: int | tuple[int, int] | None = None,
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
-        skip_targets: bool | None = None,
-        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.

@@ -460,19 +480,69 @@ class SplitConfig:
             num_patches: limit this split to this many patches
             transforms: transforms to apply
             sampler: SamplerFactory for this split
-            patch_size: an optional square size or (width, height) tuple. If set, read
+            crop_size: an optional square size or (width, height) tuple. If set, read
                 crops of this size rather than entire windows.
-            overlap_ratio: an optional float between 0 and 1. If set, read patches with
-                this ratio of overlap.
-            load_all_patches: with patch_size set, rather than sampling a random patch
-                for each window, read all patches as separate sequential items in the
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference.
+            load_all_crops: with crop_size set, rather than sampling a random crop
+                for each window, read all crops as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
             output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
                 If set, windows that already
                 have this layer completed will be skipped (useful for resuming
                 partial inference runs).
+            patch_size: deprecated, use crop_size instead
+            overlap_ratio: deprecated, use overlap_pixels instead
+            load_all_patches: deprecated, use load_all_crops instead
         """
+        # Handle deprecated load_all_patches parameter
+        if load_all_patches is not None:
+            warnings.warn(
+                "load_all_patches is deprecated, use load_all_crops instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if load_all_crops is not None:
+                raise ValueError(
+                    "Cannot specify both load_all_patches and load_all_crops"
+                )
+            load_all_crops = load_all_patches
+        # Handle deprecated patch_size parameter
+        if patch_size is not None:
+            warnings.warn(
+                "patch_size is deprecated, use crop_size instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if crop_size is not None:
+                raise ValueError("Cannot specify both patch_size and crop_size")
+            crop_size = patch_size
+
+        # Normalize crop_size to tuple[int, int] | None
+        self.crop_size: tuple[int, int] | None = None
+        if crop_size is not None:
+            if isinstance(crop_size, int):
+                self.crop_size = (crop_size, crop_size)
+            else:
+                self.crop_size = crop_size
+
+        # Handle deprecated overlap_ratio parameter
+        if overlap_ratio is not None:
+            warnings.warn(
+                "overlap_ratio is deprecated, use overlap_pixels instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
+            if self.crop_size is None:
+                raise ValueError("overlap_ratio requires crop_size to be set")
+            overlap_pixels = round(self.crop_size[0] * overlap_ratio)
+
+        if overlap_pixels is not None and overlap_pixels < 0:
+            raise ValueError("overlap_pixels must be non-negative")
+
         self.groups = groups
         self.names = names
         self.tags = tags
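The net effect of the shims above: the legacy spellings still work but are rewritten to the new ones with a FutureWarning, mixing old and new spellings of the same setting raises ValueError, and overlap_ratio is converted to pixels via round(crop_size[0] * overlap_ratio). A sketch of the expected caller-visible behavior (assuming the remaining constructor arguments keep their defaults):

```python
import warnings

from rslearn.train.dataset import SplitConfig

# New-style configuration: 256x256 crops with 32 overlapping pixels.
cfg = SplitConfig(crop_size=256, overlap_pixels=32)
assert cfg.crop_size == (256, 256)

# Legacy configuration still works but warns, and the ratio is converted:
# round(128 * 0.25) == 32 overlap pixels.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = SplitConfig(patch_size=128, overlap_ratio=0.25)
assert legacy.crop_size == (128, 128)
assert legacy.get_overlap_pixels() == 32
assert any(issubclass(w.category, FutureWarning) for w in caught)
```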
@@ -480,19 +550,15 @@ class SplitConfig:
         self.num_patches = num_patches
         self.transforms = transforms
         self.sampler = sampler
-        self.patch_size = patch_size
         self.skip_targets = skip_targets
         self.output_layer_name_skip_inference_if_exists = (
             output_layer_name_skip_inference_if_exists
         )

-        # Note that load_all_patches are handled by the RslearnDataModule rather than
-        # the ModelDataset.
-        self.load_all_patches = load_all_patches
-        self.overlap_ratio = overlap_ratio
-
-        if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
-            raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
+        # Note that load_all_crops is handled by the RslearnDataModule rather than the
+        # ModelDataset.
+        self.load_all_crops = load_all_crops
+        self.overlap_pixels = overlap_pixels

     def update(self, other: "SplitConfig") -> "SplitConfig":
         """Override settings in this SplitConfig with those in another.
@@ -508,9 +574,9 @@ class SplitConfig:
             num_patches=self.num_patches,
             transforms=self.transforms,
             sampler=self.sampler,
-            patch_size=self.patch_size,
-            overlap_ratio=self.overlap_ratio,
-            load_all_patches=self.load_all_patches,
+            crop_size=self.crop_size,
+            overlap_pixels=self.overlap_pixels,
+            load_all_crops=self.load_all_crops,
             skip_targets=self.skip_targets,
             output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
@@ -528,12 +594,12 @@ class SplitConfig:
             result.transforms = other.transforms
         if other.sampler:
             result.sampler = other.sampler
-        if other.patch_size:
-            result.patch_size = other.patch_size
-        if other.overlap_ratio is not None:
-            result.overlap_ratio = other.overlap_ratio
-        if other.load_all_patches is not None:
-            result.load_all_patches = other.load_all_patches
+        if other.crop_size:
+            result.crop_size = other.crop_size
+        if other.overlap_pixels is not None:
+            result.overlap_pixels = other.overlap_pixels
+        if other.load_all_crops is not None:
+            result.load_all_crops = other.load_all_crops
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
         if other.output_layer_name_skip_inference_if_exists is not None:
@@ -542,21 +608,17 @@ class SplitConfig:
             )
         return result

-    def get_patch_size(self) -> tuple[int, int] | None:
-        """Get patch size normalized to int tuple."""
-        if self.patch_size is None:
-            return None
-        if isinstance(self.patch_size, int):
-            return (self.patch_size, self.patch_size)
-        return self.patch_size
+    def get_crop_size(self) -> tuple[int, int] | None:
+        """Get crop size as tuple."""
+        return self.crop_size

-    def get_overlap_ratio(self) -> float:
-        """Get the overlap ratio (default 0)."""
-        return self.overlap_ratio if self.overlap_ratio is not None else 0.0
+    def get_overlap_pixels(self) -> int:
+        """Get the overlap pixels (default 0)."""
+        return self.overlap_pixels if self.overlap_pixels is not None else 0

-    def get_load_all_patches(self) -> bool:
+    def get_load_all_crops(self) -> bool:
         """Returns whether loading all patches is enabled (default False)."""
-        return True if self.load_all_patches is True else False
+        return True if self.load_all_crops is True else False

     def get_skip_targets(self) -> bool:
         """Returns whether skip_targets is enabled (default False)."""
@@ -636,6 +698,7 @@ class ModelDataset(torch.utils.data.Dataset):
         workers: int,
         name: str | None = None,
         fix_patch_pick: bool = False,
+        index_mode: IndexMode = IndexMode.OFF,
     ) -> None:
         """Instantiate a new ModelDataset.

@@ -645,9 +708,10 @@ class ModelDataset(torch.utils.data.Dataset):
             inputs: data to read from the dataset for training
             task: the task to train on
             workers: number of workers to use for initializing the dataset
-            name: name of the dataset (default: None)
+            name: name of the dataset
             fix_patch_pick: if True, fix the patch pick to be the same every time
                 for a given window. Useful for testing (default: False)
+            index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
         """
         self.dataset = dataset
         self.split_config = split_config
@@ -660,15 +724,13 @@ class ModelDataset(torch.utils.data.Dataset):
         else:
             self.transforms = rslearn.train.transforms.transform.Identity()

-        # Get normalized patch size from the SplitConfig.
-        # But if load all patches is enabled, this is handled by AllPatchesDataset, so
+        # Get normalized crop size from the SplitConfig.
+        # But if load all patches is enabled, this is handled by AllCropsDataset, so
         # here we instead load the entire windows.
-        if split_config.get_load_all_patches():
-            self.patch_size = None
+        if split_config.get_load_all_crops():
+            self.crop_size = None
         else:
-            self.patch_size = split_config.get_patch_size()
-
-        windows = self._get_initial_windows(split_config, workers)
+            self.crop_size = split_config.get_crop_size()

         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -676,58 +738,8 @@ class ModelDataset(torch.utils.data.Dataset):
                 if self.inputs[k].is_target:
                     del self.inputs[k]

-        # Eliminate windows that are missing either a requisite input layer, or missing
-        # all target layers.
-        new_windows = []
-        if workers == 0:
-            for window in windows:
-                if (
-                    check_window(
-                        self.inputs,
-                        window,
-                        output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
-                    )
-                    is None
-                ):
-                    continue
-                new_windows.append(window)
-        else:
-            p = multiprocessing.Pool(workers)
-            outputs = star_imap_unordered(
-                p,
-                check_window,
-                [
-                    dict(
-                        inputs=self.inputs,
-                        window=window,
-                        output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
-                    )
-                    for window in windows
-                ],
-            )
-            for window in tqdm.tqdm(
-                outputs, total=len(windows), desc="Checking available layers in windows"
-            ):
-                if window is None:
-                    continue
-                new_windows.append(window)
-            p.close()
-        windows = new_windows
-
-        # Sort the windows to ensure that the dataset is consistent across GPUs.
-        # Inconsistent ordering can lead to a subset of windows being processed during
-        # "model test" / "model predict" when using multiple GPUs.
-        # We use a hash so that functionality like num_samples limit gets a random
-        # subset of windows (with respect to the hash function choice).
-        windows.sort(
-            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
-        )
-
-        # Limit windows to num_samples if requested.
-        if split_config.num_samples:
-            # The windows are sorted by hash of window name so this distribution should
-            # be representative of the population.
-            windows = windows[0 : split_config.num_samples]
+        # Load windows (from index if available, otherwise from dataset)
+        windows = self._load_windows(split_config, workers, index_mode)

         # Write dataset_examples to a file so that we can load it lazily in the worker
         # processes. Otherwise it takes a long time to transmit it when spawning each
@@ -796,6 +808,137 @@ class ModelDataset(torch.utils.data.Dataset):

         return windows

+    def _load_windows(
+        self,
+        split_config: SplitConfig,
+        workers: int,
+        index_mode: IndexMode,
+    ) -> list[Window]:
+        """Load windows, using index if available.
+
+        This method handles:
+        1. Loading from index if index_mode is USE and index exists
+        2. Otherwise, loading from dataset, filtering, sorting, limiting
+        3. Saving to index if index_mode is USE or REFRESH
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+            index_mode: controls caching behavior.
+
+        Returns:
+            list of processed windows ready for training.
+        """
+        # Try to load from index
+        index: DatasetIndex | None = None
+
+        if index_mode != IndexMode.OFF:
+            logger.info(f"Checking index for dataset {self.dataset.path}")
+            index = DatasetIndex(
+                storage=self.dataset.storage,
+                dataset_path=self.dataset.path,
+                groups=split_config.groups,
+                names=split_config.names,
+                tags=split_config.tags,
+                num_samples=split_config.num_samples,
+                skip_targets=split_config.get_skip_targets(),
+                inputs=self.inputs,
+            )
+            refresh = index_mode == IndexMode.REFRESH
+            indexed_windows = index.load_windows(refresh)
+
+            if indexed_windows is not None:
+                logger.info(f"Loaded {len(indexed_windows)} windows from index")
+                return indexed_windows
+
+        # No index available, load and process windows from dataset
+        logger.debug("Loading windows from dataset...")
+        windows = self._get_initial_windows(split_config, workers)
+        windows = self._filter_windows_by_layers(windows, workers)
+        windows = self._sort_and_limit_windows(windows, split_config)
+
+        # Save to index if enabled
+        if index is not None:
+            index.save_windows(windows)
+
+        return windows
+
+    def _filter_windows_by_layers(
+        self, windows: list[Window], workers: int
+    ) -> list[Window]:
+        """Filter windows to only include those with required layers.
+
+        Args:
+            windows: list of windows to filter.
+            workers: number of worker processes for parallel filtering.
+
+        Returns:
+            list of windows that have all required input layers.
+        """
+        output_layer_skip = (
+            self.split_config.get_output_layer_name_skip_inference_if_exists()
+        )
+
+        if workers == 0:
+            return [
+                w
+                for w in windows
+                if check_window(
+                    self.inputs,
+                    w,
+                    output_layer_name_skip_inference_if_exists=output_layer_skip,
+                )
+                is not None
+            ]
+
+        p = multiprocessing.Pool(workers)
+        outputs = star_imap_unordered(
+            p,
+            check_window,
+            [
+                dict(
+                    inputs=self.inputs,
+                    window=window,
+                    output_layer_name_skip_inference_if_exists=output_layer_skip,
+                )
+                for window in windows
+            ],
+        )
+        filtered = []
+        for window in tqdm.tqdm(
+            outputs,
+            total=len(windows),
+            desc="Checking available layers in windows",
+        ):
+            if window is not None:
+                filtered.append(window)
+        p.close()
+        return filtered
+
+    def _sort_and_limit_windows(
+        self, windows: list[Window], split_config: SplitConfig
+    ) -> list[Window]:
+        """Sort windows by hash and apply num_samples limit.
+
+        Sorting ensures consistent ordering across GPUs. Using hash gives a
+        pseudo-random but deterministic order for sampling.
+
+        Args:
+            windows: list of windows to sort and limit.
+            split_config: the split configuration with num_samples.
+
+        Returns:
+            sorted and optionally limited list of windows.
+        """
+        windows.sort(
+            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
+        )
+
+        if split_config.num_samples:
+            windows = windows[: split_config.num_samples]
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
         return example.get_metadata()

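The sort key in _sort_and_limit_windows hashes each window name, giving an order that is identical in every process (so multi-GPU test/predict runs see the same window list) yet decorrelated from the original naming, so the num_samples prefix behaves like a fixed random subset. A standalone illustration with hypothetical window names:

```python
import hashlib

names = ["window_003", "window_001", "window_002"]
ordered = sorted(names, key=lambda n: hashlib.sha256(n.encode()).hexdigest())

# Every process computes the identical order, so the first num_samples items
# form the same pseudo-random subset everywhere.
num_samples = 2
print(ordered[:num_samples])
```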
@@ -808,8 +951,8 @@ class ModelDataset(torch.utils.data.Dataset):
     def get_dataset_examples(self) -> list[Window]:
         """Get a list of examples in the dataset.

-        If load_all_patches is False, this is a list of Windows. Otherwise, this is a
-        list of (window, patch_bounds, (patch_idx, # patches)) tuples.
+        If load_all_crops is False, this is a list of Windows. Otherwise, this is a
+        list of (window, crop_bounds, (crop_idx, # patches)) tuples.
         """
         if self.dataset_examples is None:
             logger.debug(
@@ -845,34 +988,34 @@ class ModelDataset(torch.utils.data.Dataset):
         rng = random.Random(idx if self.fix_patch_pick else None)

         # Select bounds to read.
-        if self.patch_size:
+        if self.crop_size:
             window = example

-            def get_patch_range(n_patch: int, n_window: int) -> list[int]:
-                if n_patch > n_window:
+            def get_crop_range(n_crop: int, n_window: int) -> list[int]:
+                if n_crop > n_window:
                     # Select arbitrary range containing the entire window.
-                    # Basically arbitrarily padding the window to get to patch size.
-                    start = rng.randint(n_window - n_patch, 0)
-                    return [start, start + n_patch]
+                    # Basically arbitrarily padding the window to get to crop size.
+                    start = rng.randint(n_window - n_crop, 0)
+                    return [start, start + n_crop]

                 else:
-                    # Select arbitrary patch within the window.
-                    start = rng.randint(0, n_window - n_patch)
-                    return [start, start + n_patch]
+                    # Select arbitrary crop within the window.
+                    start = rng.randint(0, n_window - n_crop)
+                    return [start, start + n_crop]

             window_size = (
                 window.bounds[2] - window.bounds[0],
                 window.bounds[3] - window.bounds[1],
             )
-            patch_ranges = [
-                get_patch_range(self.patch_size[0], window_size[0]),
-                get_patch_range(self.patch_size[1], window_size[1]),
+            crop_ranges = [
+                get_crop_range(self.crop_size[0], window_size[0]),
+                get_crop_range(self.crop_size[1], window_size[1]),
             ]
             bounds = (
-                window.bounds[0] + patch_ranges[0][0],
-                window.bounds[1] + patch_ranges[1][0],
-                window.bounds[0] + patch_ranges[0][1],
-                window.bounds[1] + patch_ranges[1][1],
+                window.bounds[0] + crop_ranges[0][0],
+                window.bounds[1] + crop_ranges[1][0],
+                window.bounds[0] + crop_ranges[0][1],
+                window.bounds[1] + crop_ranges[1][1],
             )

         else:
@@ -894,9 +1037,9 @@ class ModelDataset(torch.utils.data.Dataset):
             window_group=window.group,
             window_name=window.name,
             window_bounds=window.bounds,
-            patch_bounds=bounds,
-            patch_idx=0,
-            num_patches_in_window=1,
+            crop_bounds=bounds,
+            crop_idx=0,
+            num_crops_in_window=1,
             time_range=window.time_range,
             projection=window.projection,
             dataset_source=self.name,
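The crop selection above handles both size mismatches: a window larger than the crop yields a random interior crop, while a window smaller than the crop gets a random (possibly negative) start so the crop pads around it. A standalone restatement of get_crop_range:

```python
import random

def get_crop_range(n_crop: int, n_window: int, rng: random.Random) -> list[int]:
    if n_crop > n_window:
        # Window smaller than crop: pad it at an arbitrary offset, so the
        # start may be negative (the crop extends past the window).
        start = rng.randint(n_window - n_crop, 0)
    else:
        # Window at least crop-sized: pick an arbitrary crop inside it.
        start = rng.randint(0, n_window - n_crop)
    return [start, start + n_crop]

rng = random.Random(0)
print(get_crop_range(64, 256, rng))  # random 64-pixel crop in a 256-pixel window
print(get_crop_range(64, 48, rng))   # 48-pixel window padded up to 64 pixels
```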
rslearn/train/dataset_index.py ADDED
@@ -0,0 +1,156 @@
+"""Dataset index for caching window lists to speed up ModelDataset initialization."""
+
+import hashlib
+import json
+from datetime import datetime
+from typing import TYPE_CHECKING, Any
+
+from upath import UPath
+
+from rslearn.dataset.window import Window
+from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
+
+if TYPE_CHECKING:
+    from rslearn.dataset.storage.storage import WindowStorage
+
+logger = get_logger(__name__)
+
+# Increment this when the index format changes to force rebuild
+INDEX_VERSION = 1
+
+# Directory name for storing index files
+INDEX_DIR_NAME = ".rslearn_dataset_index"
+
+
+class DatasetIndex:
+    """Manages indexed window lists for faster ModelDataset initialization.
+
+    Note: The index does NOT automatically detect when windows are added or removed
+    from the dataset. Use refresh=True after modifying dataset windows.
+    """
+
+    def __init__(
+        self,
+        storage: "WindowStorage",
+        dataset_path: UPath,
+        groups: list[str] | None,
+        names: list[str] | None,
+        tags: dict[str, Any] | None,
+        num_samples: int | None,
+        skip_targets: bool,
+        inputs: dict[str, Any],
+    ) -> None:
+        """Initialize DatasetIndex with specific configuration.
+
+        Args:
+            storage: WindowStorage for deserializing windows.
+            dataset_path: Path to the dataset directory.
+            groups: list of window groups to include.
+            names: list of window names to include.
+            tags: tags to filter windows by.
+            num_samples: limit on number of samples.
+            skip_targets: whether targets are skipped.
+            inputs: dict mapping input names to DataInput objects.
+        """
+        self.storage = storage
+        self.dataset_path = dataset_path
+        self.index_dir = dataset_path / INDEX_DIR_NAME
+
+        # Compute index key from configuration
+        inputs_data = {}
+        for name, inp in inputs.items():
+            inputs_data[name] = {
+                "layers": inp.layers,
+                "required": inp.required,
+                "load_all_layers": inp.load_all_layers,
+                "is_target": inp.is_target,
+            }
+
+        key_data = {
+            "groups": groups,
+            "names": names,
+            "tags": tags,
+            "num_samples": num_samples,
+            "skip_targets": skip_targets,
+            "inputs": inputs_data,
+        }
+        self.index_key = hashlib.sha256(
+            json.dumps(key_data, sort_keys=True).encode()
+        ).hexdigest()
+
+    def _get_config_hash(self) -> str:
+        """Get hash of config.json for quick validation.
+
+        Returns:
+            A 16-character hex string hash of the config, or empty string if no config.
+        """
+        config_path = self.dataset_path / "config.json"
+        if config_path.exists():
+            with config_path.open() as f:
+                return hashlib.sha256(f.read().encode()).hexdigest()[:16]
+        return ""
+
+    def load_windows(self, refresh: bool = False) -> list[Window] | None:
+        """Load indexed window list if valid, else return None.
+
+        Args:
+            refresh: If True, ignore existing index and return None.
+
+        Returns:
+            List of Window objects if index is valid, None otherwise.
+        """
+        if refresh:
+            logger.info("refresh=True, rebuilding index")
+            return None
+
+        index_file = self.index_dir / f"{self.index_key}.json"
+        if not index_file.exists():
+            logger.info(f"No index found at {index_file}, will build")
+            return None
+
+        try:
+            with index_file.open() as f:
+                index_data = json.load(f)
+        except (OSError, json.JSONDecodeError):
+            logger.warning(f"Corrupted index file at {index_file}, will rebuild")
+            return None
+
+        # Check index version
+        if index_data.get("version") != INDEX_VERSION:
+            logger.info(
+                f"Index version mismatch (got {index_data.get('version')}, "
+                f"expected {INDEX_VERSION}), will rebuild"
+            )
+            return None
+
+        # Quick validation: check config hash
+        if index_data.get("config_hash") != self._get_config_hash():
+            logger.info("Config hash mismatch, index invalidated")
+            return None
+
+        # Deserialize windows
+        return [Window.from_metadata(self.storage, w) for w in index_data["windows"]]
+
+    def save_windows(self, windows: list[Window]) -> None:
+        """Save processed windows to index with atomic write.
+
+        Args:
+            windows: List of Window objects to index.
+        """
+        self.index_dir.mkdir(parents=True, exist_ok=True)
+        index_file = self.index_dir / f"{self.index_key}.json"
+
+        # Serialize windows
+        serialized_windows = [w.get_metadata() for w in windows]
+
+        index_data = {
+            "version": INDEX_VERSION,
+            "config_hash": self._get_config_hash(),
+            "created_at": datetime.now().isoformat(),
+            "num_windows": len(windows),
+            "windows": serialized_windows,
+        }
+        with open_atomic(index_file, "w") as f:
+            json.dump(index_data, f)
+        logger.info(f"Saved {len(windows)} windows to index at {index_file}")