rslearn 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/local_files.py +2 -2
  3. rslearn/data_sources/utils.py +204 -64
  4. rslearn/dataset/materialize.py +5 -1
  5. rslearn/models/clay/clay.py +3 -3
  6. rslearn/models/detr/detr.py +4 -1
  7. rslearn/models/dinov3.py +0 -1
  8. rslearn/models/olmoearth_pretrain/model.py +3 -1
  9. rslearn/models/pooling_decoder.py +1 -1
  10. rslearn/models/prithvi.py +0 -1
  11. rslearn/models/simple_time_series.py +97 -35
  12. rslearn/train/data_module.py +5 -0
  13. rslearn/train/dataset.py +151 -55
  14. rslearn/train/dataset_index.py +156 -0
  15. rslearn/train/model_context.py +16 -0
  16. rslearn/train/tasks/per_pixel_regression.py +13 -13
  17. rslearn/train/tasks/segmentation.py +26 -13
  18. rslearn/train/transforms/concatenate.py +17 -27
  19. rslearn/train/transforms/crop.py +8 -19
  20. rslearn/train/transforms/flip.py +4 -10
  21. rslearn/train/transforms/mask.py +9 -15
  22. rslearn/train/transforms/normalize.py +31 -82
  23. rslearn/train/transforms/pad.py +7 -13
  24. rslearn/train/transforms/resize.py +5 -22
  25. rslearn/train/transforms/select_bands.py +16 -36
  26. rslearn/train/transforms/sentinel1.py +4 -16
  27. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/METADATA +1 -1
  28. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/RECORD +33 -32
  29. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/WHEEL +0 -0
  30. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
  31. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
  32. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
  33. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  """SimpleTimeSeries encoder."""
2
2
 
3
+ import warnings
3
4
  from typing import Any
4
5
 
5
6
  import torch
@@ -25,13 +26,14 @@ class SimpleTimeSeries(FeatureExtractor):
25
26
  def __init__(
26
27
  self,
27
28
  encoder: FeatureExtractor,
28
- image_channels: int | None = None,
29
+ num_timesteps_per_forward_pass: int = 1,
29
30
  op: str = "max",
30
31
  groups: list[list[int]] | None = None,
31
32
  num_layers: int | None = None,
32
33
  image_key: str = "image",
33
34
  backbone_channels: list[tuple[int, int]] | None = None,
34
- image_keys: dict[str, int] | None = None,
35
+ image_keys: list[str] | dict[str, int] | None = None,
36
+ image_channels: int | None = None,
35
37
  ) -> None:
36
38
  """Create a new SimpleTimeSeries.
37
39
 
@@ -39,9 +41,11 @@ class SimpleTimeSeries(FeatureExtractor):
39
41
  encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
40
42
  function that returns the output channels, or backbone_channels must be set.
41
43
  It must output a FeatureMaps.
42
- image_channels: the number of channels per image of the time series. The
43
- input should have multiple images concatenated on the channel axis, so
44
- this parameter is used to distinguish the different images.
44
+ num_timesteps_per_forward_pass: how many timesteps to pass to the encoder
45
+ in each forward pass. Defaults to 1 (one timestep per forward pass).
46
+ Set to a higher value to batch multiple timesteps together, e.g. for
47
+ pre/post change detection where you want 4 pre and 4 post images
48
+ processed together.
45
49
  op: one of max, mean, convrnn, conv3d, or conv1d
46
50
  groups: sets of images for which to combine features. Within each set,
47
51
  features are combined using the specified operation; then, across sets,
@@ -51,28 +55,53 @@ class SimpleTimeSeries(FeatureExtractor):
51
55
  combined before features and the combined after features. groups is a
52
56
  list of sets, and each set is a list of image indices.
53
57
  num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
54
- image_key: the key to access the images.
58
+ image_key: the key to access the images (used when image_keys is not set).
55
59
  backbone_channels: manually specify the backbone channels. Can be set if
56
60
  the encoder does not provide get_backbone_channels function.
57
- image_keys: as an alternative to setting image_channels, map from the key
58
- in input dict to the number of channels per timestep for that modality.
59
- This way SimpleTimeSeries can be used with multimodal inputs. One of
60
- image_channels or image_keys must be specified.
61
+ image_keys: list of keys in input dict to process as multimodal inputs.
62
+ All keys use the same num_timesteps_per_forward_pass. If not set,
63
+ only the single image_key is used. Passing a dict[str, int] is
64
+ deprecated and will be removed on 2026-04-01.
65
+ image_channels: Deprecated, use num_timesteps_per_forward_pass instead.
66
+ Will be removed on 2026-04-01.
61
67
  """
62
- if (image_channels is None and image_keys is None) or (
63
- image_channels is not None and image_keys is not None
64
- ):
65
- raise ValueError(
66
- "exactly one of image_channels and image_keys must be specified"
68
+ # Handle deprecated image_channels parameter
69
+ if image_channels is not None:
70
+ warnings.warn(
71
+ "image_channels is deprecated and will be removed on 2026-04-01. "
72
+ "Use num_timesteps_per_forward_pass instead. The new parameter directly "
73
+ "specifies the number of timesteps per forward pass rather than requiring "
74
+ "image_channels // actual_channels.",
75
+ FutureWarning,
76
+ stacklevel=2,
67
77
  )
68
78
 
79
+ # Handle deprecated dict form of image_keys
80
+ deprecated_image_keys_dict: dict[str, int] | None = None
81
+ if isinstance(image_keys, dict):
82
+ warnings.warn(
83
+ "Passing image_keys as a dict is deprecated and will be removed on "
84
+ "2026-04-01. Use image_keys as a list[str] and set "
85
+ "num_timesteps_per_forward_pass instead.",
86
+ FutureWarning,
87
+ stacklevel=2,
88
+ )
89
+ deprecated_image_keys_dict = image_keys
90
+ image_keys = None # Will use deprecated path in forward
91
+
69
92
  super().__init__()
70
93
  self.encoder = encoder
71
- self.image_channels = image_channels
94
+ self.num_timesteps_per_forward_pass = num_timesteps_per_forward_pass
95
+ # Store deprecated parameters for runtime conversion
96
+ self._deprecated_image_channels = image_channels
97
+ self._deprecated_image_keys_dict = deprecated_image_keys_dict
72
98
  self.op = op
73
99
  self.groups = groups
74
- self.image_key = image_key
75
- self.image_keys = image_keys
100
+ # Normalize image_key to image_keys list form
101
+ if image_keys is not None:
102
+ self.image_keys = image_keys
103
+ else:
104
+ self.image_keys = [image_key]
76
105
 
77
106
  if backbone_channels is not None:
78
107
  out_channels = backbone_channels
@@ -163,24 +192,25 @@ class SimpleTimeSeries(FeatureExtractor):
163
192
  return out_channels
164
193
 
165
194
  def _get_batched_images(
166
- self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
195
+ self, input_dicts: list[dict[str, Any]], image_key: str, num_timesteps: int
167
196
  ) -> list[RasterImage]:
168
197
  """Collect and reshape images across input dicts.
169
198
 
170
199
  The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
171
200
  the forward pass of a per-image (unitemporal) model.
201
+
202
+ Args:
203
+ input_dicts: list of input dictionaries containing RasterImage objects.
204
+ image_key: the key to access the RasterImage in each input dict.
205
+ num_timesteps: how many timesteps to batch together per forward pass.
172
206
  """
173
207
  images = torch.stack(
174
208
  [input_dict[image_key].image for input_dict in input_dicts], dim=0
175
209
  ) # B, C, T, H, W
176
210
  timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
177
- # if image channels is not equal to the actual number of channels, then
178
- # then every N images should be batched together. For example, if the
179
- # number of input channels c == 2, and image_channels == 4, then we
180
- # want to pass 2 timesteps to the model.
181
- # TODO is probably to make this behaviour clearer but lets leave it like
182
- # this for now to not break things.
183
- num_timesteps = image_channels // images.shape[1]
211
+ # num_timesteps specifies how many timesteps to batch together per forward pass.
212
+ # For example, if the input has 8 timesteps and num_timesteps=4, we do 2
213
+ # forward passes, each with 4 timesteps batched together.
184
214
  batched_timesteps = images.shape[2] // num_timesteps
185
215
  images = rearrange(
186
216
  images,
@@ -222,10 +252,22 @@ class SimpleTimeSeries(FeatureExtractor):
222
252
  n_batch = len(context.inputs)
223
253
  n_images: int | None = None
224
254
 
225
- if self.image_keys is not None:
226
- for image_key, image_channels in self.image_keys.items():
255
+ if self._deprecated_image_keys_dict is not None:
256
+ # Deprecated dict form: each key has its own channels_per_timestep.
257
+ # The channels_per_timestep could be used to group multiple timesteps
258
+ # together, so we need to divide by the actual image channel count to get
259
+ # the number of timesteps to be grouped.
260
+ for (
261
+ image_key,
262
+ channels_per_timestep,
263
+ ) in self._deprecated_image_keys_dict.items():
264
+ # For deprecated image_keys dict, the value is channels per timestep,
265
+ # so we need to compute num_timesteps from the actual image channels
266
+ sample_image = context.inputs[0][image_key].image
267
+ actual_channels = sample_image.shape[0] # C in CTHW
268
+ num_timesteps = channels_per_timestep // actual_channels
227
269
  batched_images = self._get_batched_images(
228
- context.inputs, image_key, image_channels
270
+ context.inputs, image_key, num_timesteps
229
271
  )
230
272
 
231
273
  if batched_inputs is None:
@@ -240,12 +282,32 @@ class SimpleTimeSeries(FeatureExtractor):
240
282
  batched_inputs[i][image_key] = image
241
283
 
242
284
  else:
243
- assert self.image_channels is not None
244
- batched_images = self._get_batched_images(
245
- context.inputs, self.image_key, self.image_channels
246
- )
247
- batched_inputs = [{self.image_key: image} for image in batched_images]
248
- n_images = len(batched_images) // n_batch
285
+ # Determine num_timesteps - either from deprecated image_channels or
286
+ # directly from num_timesteps_per_forward_pass
287
+ if self._deprecated_image_channels is not None:
288
+ # Backwards compatibility: compute num_timesteps from image_channels
289
+ # (which should be a multiple of the actual per-timestep channels).
290
+ sample_image = context.inputs[0][self.image_keys[0]].image
291
+ actual_channels = sample_image.shape[0] # C in CTHW
292
+ num_timesteps = self._deprecated_image_channels // actual_channels
293
+ else:
294
+ num_timesteps = self.num_timesteps_per_forward_pass
295
+
296
+ for image_key in self.image_keys:
297
+ batched_images = self._get_batched_images(
298
+ context.inputs, image_key, num_timesteps
299
+ )
300
+
301
+ if batched_inputs is None:
302
+ batched_inputs = [{} for _ in batched_images]
303
+ n_images = len(batched_images) // n_batch
304
+ elif n_images != len(batched_images) // n_batch:
305
+ raise ValueError(
306
+ "expected all modalities to have the same number of timesteps"
307
+ )
308
+
309
+ for i, image in enumerate(batched_images):
310
+ batched_inputs[i][image_key] = image
249
311
 
250
312
  assert n_images is not None
251
313
  # Now we can apply the underlying FeatureExtractor.
@@ -21,6 +21,7 @@ from .all_patches_dataset import (
21
21
  )
22
22
  from .dataset import (
23
23
  DataInput,
24
+ IndexMode,
24
25
  ModelDataset,
25
26
  MultiDataset,
26
27
  RetryDataset,
@@ -69,6 +70,7 @@ class RslearnDataModule(L.LightningDataModule):
69
70
  name: str | None = None,
70
71
  retries: int = 0,
71
72
  use_in_memory_all_patches_dataset: bool = False,
73
+ index_mode: IndexMode = IndexMode.OFF,
72
74
  ) -> None:
73
75
  """Initialize a new RslearnDataModule.
74
76
 
@@ -92,6 +94,7 @@ class RslearnDataModule(L.LightningDataModule):
92
94
  retries: number of retries to attempt for getitem calls
93
95
  use_in_memory_all_patches_dataset: whether to use InMemoryAllPatchesDataset
94
96
  instead of IterableAllPatchesDataset if load_all_patches is set to true.
97
+ index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
95
98
  """
96
99
  super().__init__()
97
100
  self.inputs = inputs
@@ -103,6 +106,7 @@ class RslearnDataModule(L.LightningDataModule):
103
106
  self.name = name
104
107
  self.retries = retries
105
108
  self.use_in_memory_all_patches_dataset = use_in_memory_all_patches_dataset
109
+ self.index_mode = index_mode
106
110
  self.split_configs = {
107
111
  "train": default_config.update(train_config),
108
112
  "val": default_config.update(val_config),
@@ -138,6 +142,7 @@ class RslearnDataModule(L.LightningDataModule):
138
142
  workers=self.init_workers,
139
143
  name=self.name,
140
144
  fix_patch_pick=(split != "train"),
145
+ index_mode=self.index_mode,
141
146
  )
142
147
  logger.info(f"got {len(dataset)} examples in split {split}")
143
148
  if split_config.get_load_all_patches():
rslearn/train/dataset.py CHANGED
@@ -9,6 +9,7 @@ import tempfile
9
9
  import time
10
10
  import uuid
11
11
  from datetime import datetime
12
+ from enum import StrEnum
12
13
  from typing import Any
13
14
 
14
15
  import torch
@@ -29,6 +30,7 @@ from rslearn.dataset.window import (
29
30
  get_layer_and_group_from_dir_name,
30
31
  )
31
32
  from rslearn.log_utils import get_logger
33
+ from rslearn.train.dataset_index import DatasetIndex
32
34
  from rslearn.train.model_context import RasterImage
33
35
  from rslearn.utils.feature import Feature
34
36
  from rslearn.utils.geometry import PixelBounds, ResolutionFactor
@@ -41,6 +43,19 @@ from .transforms import Sequential
41
43
  logger = get_logger(__name__)
42
44
 
43
45
 
46
+ class IndexMode(StrEnum):
47
+ """Controls dataset index caching behavior."""
48
+
49
+ OFF = "off"
50
+ """No caching - always load windows from dataset."""
51
+
52
+ USE = "use"
53
+ """Use cached index if available, create if not."""
54
+
55
+ REFRESH = "refresh"
56
+ """Ignore existing cache and rebuild."""
57
+
58
+
44
59
  def get_torch_dtype(dtype: DType) -> torch.dtype:
45
60
  """Convert rslearn DType to torch dtype."""
46
61
  if dtype == DType.INT32:
@@ -636,6 +651,7 @@ class ModelDataset(torch.utils.data.Dataset):
636
651
  workers: int,
637
652
  name: str | None = None,
638
653
  fix_patch_pick: bool = False,
654
+ index_mode: IndexMode = IndexMode.OFF,
639
655
  ) -> None:
640
656
  """Instantiate a new ModelDataset.
641
657
 
@@ -645,9 +661,10 @@ class ModelDataset(torch.utils.data.Dataset):
645
661
  inputs: data to read from the dataset for training
646
662
  task: the task to train on
647
663
  workers: number of workers to use for initializing the dataset
648
- name: name of the dataset (default: None)
664
+ name: name of the dataset
649
665
  fix_patch_pick: if True, fix the patch pick to be the same every time
650
666
  for a given window. Useful for testing (default: False)
667
+ index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
651
668
  """
652
669
  self.dataset = dataset
653
670
  self.split_config = split_config
@@ -668,66 +685,14 @@ class ModelDataset(torch.utils.data.Dataset):
668
685
  else:
669
686
  self.patch_size = split_config.get_patch_size()
670
687
 
671
- windows = self._get_initial_windows(split_config, workers)
672
-
673
688
  # If targets are not needed, remove them from the inputs.
674
689
  if split_config.get_skip_targets():
675
690
  for k in list(self.inputs.keys()):
676
691
  if self.inputs[k].is_target:
677
692
  del self.inputs[k]
678
693
 
679
- # Eliminate windows that are missing either a requisite input layer, or missing
680
- # all target layers.
681
- new_windows = []
682
- if workers == 0:
683
- for window in windows:
684
- if (
685
- check_window(
686
- self.inputs,
687
- window,
688
- output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
689
- )
690
- is None
691
- ):
692
- continue
693
- new_windows.append(window)
694
- else:
695
- p = multiprocessing.Pool(workers)
696
- outputs = star_imap_unordered(
697
- p,
698
- check_window,
699
- [
700
- dict(
701
- inputs=self.inputs,
702
- window=window,
703
- output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
704
- )
705
- for window in windows
706
- ],
707
- )
708
- for window in tqdm.tqdm(
709
- outputs, total=len(windows), desc="Checking available layers in windows"
710
- ):
711
- if window is None:
712
- continue
713
- new_windows.append(window)
714
- p.close()
715
- windows = new_windows
716
-
717
- # Sort the windows to ensure that the dataset is consistent across GPUs.
718
- # Inconsistent ordering can lead to a subset of windows being processed during
719
- # "model test" / "model predict" when using multiple GPUs.
720
- # We use a hash so that functionality like num_samples limit gets a random
721
- # subset of windows (with respect to the hash function choice).
722
- windows.sort(
723
- key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
724
- )
725
-
726
- # Limit windows to num_samples if requested.
727
- if split_config.num_samples:
728
- # The windows are sorted by hash of window name so this distribution should
729
- # be representative of the population.
730
- windows = windows[0 : split_config.num_samples]
694
+ # Load windows (from index if available, otherwise from dataset)
695
+ windows = self._load_windows(split_config, workers, index_mode)
731
696
 
732
697
  # Write dataset_examples to a file so that we can load it lazily in the worker
733
698
  # processes. Otherwise it takes a long time to transmit it when spawning each
@@ -796,6 +761,137 @@ class ModelDataset(torch.utils.data.Dataset):
796
761
 
797
762
  return windows
798
763
 
764
+ def _load_windows(
765
+ self,
766
+ split_config: SplitConfig,
767
+ workers: int,
768
+ index_mode: IndexMode,
769
+ ) -> list[Window]:
770
+ """Load windows, using index if available.
771
+
772
+ This method handles:
773
+ 1. Loading from index if index_mode is USE and index exists
774
+ 2. Otherwise, loading from dataset, filtering, sorting, limiting
775
+ 3. Saving to index if index_mode is USE or REFRESH
776
+
777
+ Args:
778
+ split_config: the split configuration.
779
+ workers: number of worker processes.
780
+ index_mode: controls caching behavior.
781
+
782
+ Returns:
783
+ list of processed windows ready for training.
784
+ """
785
+ # Try to load from index
786
+ index: DatasetIndex | None = None
787
+
788
+ if index_mode != IndexMode.OFF:
789
+ logger.info(f"Checking index for dataset {self.dataset.path}")
790
+ index = DatasetIndex(
791
+ storage=self.dataset.storage,
792
+ dataset_path=self.dataset.path,
793
+ groups=split_config.groups,
794
+ names=split_config.names,
795
+ tags=split_config.tags,
796
+ num_samples=split_config.num_samples,
797
+ skip_targets=split_config.get_skip_targets(),
798
+ inputs=self.inputs,
799
+ )
800
+ refresh = index_mode == IndexMode.REFRESH
801
+ indexed_windows = index.load_windows(refresh)
802
+
803
+ if indexed_windows is not None:
804
+ logger.info(f"Loaded {len(indexed_windows)} windows from index")
805
+ return indexed_windows
806
+
807
+ # No index available, load and process windows from dataset
808
+ logger.debug("Loading windows from dataset...")
809
+ windows = self._get_initial_windows(split_config, workers)
810
+ windows = self._filter_windows_by_layers(windows, workers)
811
+ windows = self._sort_and_limit_windows(windows, split_config)
812
+
813
+ # Save to index if enabled
814
+ if index is not None:
815
+ index.save_windows(windows)
816
+
817
+ return windows
818
+
819
+ def _filter_windows_by_layers(
820
+ self, windows: list[Window], workers: int
821
+ ) -> list[Window]:
822
+ """Filter windows to only include those with required layers.
823
+
824
+ Args:
825
+ windows: list of windows to filter.
826
+ workers: number of worker processes for parallel filtering.
827
+
828
+ Returns:
829
+ list of windows that have all required input layers.
830
+ """
831
+ output_layer_skip = (
832
+ self.split_config.get_output_layer_name_skip_inference_if_exists()
833
+ )
834
+
835
+ if workers == 0:
836
+ return [
837
+ w
838
+ for w in windows
839
+ if check_window(
840
+ self.inputs,
841
+ w,
842
+ output_layer_name_skip_inference_if_exists=output_layer_skip,
843
+ )
844
+ is not None
845
+ ]
846
+
847
+ p = multiprocessing.Pool(workers)
848
+ outputs = star_imap_unordered(
849
+ p,
850
+ check_window,
851
+ [
852
+ dict(
853
+ inputs=self.inputs,
854
+ window=window,
855
+ output_layer_name_skip_inference_if_exists=output_layer_skip,
856
+ )
857
+ for window in windows
858
+ ],
859
+ )
860
+ filtered = []
861
+ for window in tqdm.tqdm(
862
+ outputs,
863
+ total=len(windows),
864
+ desc="Checking available layers in windows",
865
+ ):
866
+ if window is not None:
867
+ filtered.append(window)
868
+ p.close()
869
+ return filtered
870
+
871
+ def _sort_and_limit_windows(
872
+ self, windows: list[Window], split_config: SplitConfig
873
+ ) -> list[Window]:
874
+ """Sort windows by hash and apply num_samples limit.
875
+
876
+ Sorting ensures consistent ordering across GPUs. Using hash gives a
877
+ pseudo-random but deterministic order for sampling.
878
+
879
+ Args:
880
+ windows: list of windows to sort and limit.
881
+ split_config: the split configuration with num_samples.
882
+
883
+ Returns:
884
+ sorted and optionally limited list of windows.
885
+ """
886
+ windows.sort(
887
+ key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
888
+ )
889
+
890
+ if split_config.num_samples:
891
+ windows = windows[: split_config.num_samples]
892
+
893
+ return windows
894
+
799
895
  def _serialize_item(self, example: Window) -> dict[str, Any]:
800
896
  return example.get_metadata()
801
897
 
@@ -0,0 +1,156 @@
1
+ """Dataset index for caching window lists to speed up ModelDataset initialization."""
2
+
3
+ import hashlib
4
+ import json
5
+ from datetime import datetime
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from upath import UPath
9
+
10
+ from rslearn.dataset.window import Window
11
+ from rslearn.log_utils import get_logger
12
+ from rslearn.utils.fsspec import open_atomic
13
+
14
+ if TYPE_CHECKING:
15
+ from rslearn.dataset.storage.storage import WindowStorage
16
+
17
+ logger = get_logger(__name__)
18
+
19
+ # Increment this when the index format changes to force rebuild
20
+ INDEX_VERSION = 1
21
+
22
+ # Directory name for storing index files
23
+ INDEX_DIR_NAME = ".rslearn_dataset_index"
24
+
25
+
26
+ class DatasetIndex:
27
+ """Manages indexed window lists for faster ModelDataset initialization.
28
+
29
+ Note: The index does NOT automatically detect when windows are added or removed
30
+ from the dataset. Use refresh=True after modifying dataset windows.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ storage: "WindowStorage",
36
+ dataset_path: UPath,
37
+ groups: list[str] | None,
38
+ names: list[str] | None,
39
+ tags: dict[str, Any] | None,
40
+ num_samples: int | None,
41
+ skip_targets: bool,
42
+ inputs: dict[str, Any],
43
+ ) -> None:
44
+ """Initialize DatasetIndex with specific configuration.
45
+
46
+ Args:
47
+ storage: WindowStorage for deserializing windows.
48
+ dataset_path: Path to the dataset directory.
49
+ groups: list of window groups to include.
50
+ names: list of window names to include.
51
+ tags: tags to filter windows by.
52
+ num_samples: limit on number of samples.
53
+ skip_targets: whether targets are skipped.
54
+ inputs: dict mapping input names to DataInput objects.
55
+ """
56
+ self.storage = storage
57
+ self.dataset_path = dataset_path
58
+ self.index_dir = dataset_path / INDEX_DIR_NAME
59
+
60
+ # Compute index key from configuration
61
+ inputs_data = {}
62
+ for name, inp in inputs.items():
63
+ inputs_data[name] = {
64
+ "layers": inp.layers,
65
+ "required": inp.required,
66
+ "load_all_layers": inp.load_all_layers,
67
+ "is_target": inp.is_target,
68
+ }
69
+
70
+ key_data = {
71
+ "groups": groups,
72
+ "names": names,
73
+ "tags": tags,
74
+ "num_samples": num_samples,
75
+ "skip_targets": skip_targets,
76
+ "inputs": inputs_data,
77
+ }
78
+ self.index_key = hashlib.sha256(
79
+ json.dumps(key_data, sort_keys=True).encode()
80
+ ).hexdigest()
81
+
82
+ def _get_config_hash(self) -> str:
83
+ """Get hash of config.json for quick validation.
84
+
85
+ Returns:
86
+ A 16-character hex string hash of the config, or empty string if no config.
87
+ """
88
+ config_path = self.dataset_path / "config.json"
89
+ if config_path.exists():
90
+ with config_path.open() as f:
91
+ return hashlib.sha256(f.read().encode()).hexdigest()[:16]
92
+ return ""
93
+
94
+ def load_windows(self, refresh: bool = False) -> list[Window] | None:
95
+ """Load indexed window list if valid, else return None.
96
+
97
+ Args:
98
+ refresh: If True, ignore existing index and return None.
99
+
100
+ Returns:
101
+ List of Window objects if index is valid, None otherwise.
102
+ """
103
+ if refresh:
104
+ logger.info("refresh=True, rebuilding index")
105
+ return None
106
+
107
+ index_file = self.index_dir / f"{self.index_key}.json"
108
+ if not index_file.exists():
109
+ logger.info(f"No index found at {index_file}, will build")
110
+ return None
111
+
112
+ try:
113
+ with index_file.open() as f:
114
+ index_data = json.load(f)
115
+ except (OSError, json.JSONDecodeError):
116
+ logger.warning(f"Corrupted index file at {index_file}, will rebuild")
117
+ return None
118
+
119
+ # Check index version
120
+ if index_data.get("version") != INDEX_VERSION:
121
+ logger.info(
122
+ f"Index version mismatch (got {index_data.get('version')}, "
123
+ f"expected {INDEX_VERSION}), will rebuild"
124
+ )
125
+ return None
126
+
127
+ # Quick validation: check config hash
128
+ if index_data.get("config_hash") != self._get_config_hash():
129
+ logger.info("Config hash mismatch, index invalidated")
130
+ return None
131
+
132
+ # Deserialize windows
133
+ return [Window.from_metadata(self.storage, w) for w in index_data["windows"]]
134
+
135
+ def save_windows(self, windows: list[Window]) -> None:
136
+ """Save processed windows to index with atomic write.
137
+
138
+ Args:
139
+ windows: List of Window objects to index.
140
+ """
141
+ self.index_dir.mkdir(parents=True, exist_ok=True)
142
+ index_file = self.index_dir / f"{self.index_key}.json"
143
+
144
+ # Serialize windows
145
+ serialized_windows = [w.get_metadata() for w in windows]
146
+
147
+ index_data = {
148
+ "version": INDEX_VERSION,
149
+ "config_hash": self._get_config_hash(),
150
+ "created_at": datetime.now().isoformat(),
151
+ "num_windows": len(windows),
152
+ "windows": serialized_windows,
153
+ }
154
+ with open_atomic(index_file, "w") as f:
155
+ json.dump(index_data, f)
156
+ logger.info(f"Saved {len(windows)} windows to index at {index_file}")
@@ -43,6 +43,22 @@ class RasterImage:
43
43
  raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
44
44
  return self.image[:, 0]
45
45
 
46
+ def get_hw_tensor(self) -> torch.Tensor:
47
+ """Get a 2D HW tensor from a single-channel, single-timestep RasterImage.
48
+
49
+ This function checks that C=1 and T=1, then returns the HW tensor.
50
+ Useful for per-pixel labels like segmentation masks.
51
+ """
52
+ if self.image.shape[0] != 1:
53
+ raise ValueError(
54
+ f"Expected single channel (C=1), got {self.image.shape[0]}"
55
+ )
56
+ if self.image.shape[1] != 1:
57
+ raise ValueError(
58
+ f"Expected single timestep (T=1), got {self.image.shape[1]}"
59
+ )
60
+ return self.image[0, 0]
61
+
46
62
 
47
63
  @dataclass
48
64
  class SampleMetadata: