rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +49 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/main.py +11 -36
  11. rslearn/models/anysat.py +11 -9
  12. rslearn/models/clay/clay.py +8 -9
  13. rslearn/models/clip.py +18 -15
  14. rslearn/models/component.py +99 -0
  15. rslearn/models/concatenate_features.py +21 -11
  16. rslearn/models/conv.py +15 -8
  17. rslearn/models/croma.py +13 -8
  18. rslearn/models/detr/detr.py +25 -14
  19. rslearn/models/dinov3.py +11 -6
  20. rslearn/models/faster_rcnn.py +19 -9
  21. rslearn/models/feature_center_crop.py +12 -9
  22. rslearn/models/fpn.py +19 -8
  23. rslearn/models/galileo/galileo.py +23 -18
  24. rslearn/models/module_wrapper.py +26 -57
  25. rslearn/models/molmo.py +16 -14
  26. rslearn/models/multitask.py +102 -73
  27. rslearn/models/olmoearth_pretrain/model.py +18 -12
  28. rslearn/models/panopticon.py +8 -7
  29. rslearn/models/pick_features.py +18 -24
  30. rslearn/models/pooling_decoder.py +22 -14
  31. rslearn/models/presto/presto.py +16 -10
  32. rslearn/models/presto/single_file_presto.py +4 -10
  33. rslearn/models/prithvi.py +12 -8
  34. rslearn/models/resize_features.py +21 -7
  35. rslearn/models/sam2_enc.py +11 -9
  36. rslearn/models/satlaspretrain.py +15 -9
  37. rslearn/models/simple_time_series.py +31 -17
  38. rslearn/models/singletask.py +24 -17
  39. rslearn/models/ssl4eo_s12.py +15 -10
  40. rslearn/models/swin.py +22 -13
  41. rslearn/models/terramind.py +24 -7
  42. rslearn/models/trunk.py +6 -3
  43. rslearn/models/unet.py +18 -9
  44. rslearn/models/upsample.py +22 -9
  45. rslearn/train/all_patches_dataset.py +22 -18
  46. rslearn/train/dataset.py +69 -54
  47. rslearn/train/lightning_module.py +51 -32
  48. rslearn/train/model_context.py +54 -0
  49. rslearn/train/prediction_writer.py +111 -41
  50. rslearn/train/tasks/classification.py +34 -15
  51. rslearn/train/tasks/detection.py +24 -31
  52. rslearn/train/tasks/embedding.py +33 -29
  53. rslearn/train/tasks/multi_task.py +7 -7
  54. rslearn/train/tasks/per_pixel_regression.py +41 -19
  55. rslearn/train/tasks/regression.py +38 -21
  56. rslearn/train/tasks/segmentation.py +33 -15
  57. rslearn/train/tasks/task.py +3 -2
  58. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
  59. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
  60. rslearn/dataset/index.py +0 -173
  61. rslearn/models/registry.py +0 -22
  62. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  63. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  64. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/models/terramind.py CHANGED
@@ -8,8 +8,11 @@ import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 # TerraMind v1 provides two sizes: base and large
 class TerramindSize(str, Enum):
@@ -85,7 +88,7 @@ PRETRAINED_BANDS = {
 }
 
 
-class Terramind(torch.nn.Module):
+class Terramind(FeatureExtractor):
     """Terramind backbones."""
 
     def __init__(
@@ -123,21 +126,25 @@ class Terramind(torch.nn.Module):
         self.modalities = modalities
         self.do_resizing = do_resizing
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass for the Terramind model.
 
        Args:
-            inputs: input dicts that must include modalities as keys which are defined in the self.modalities list
+            context: the model context. Input dicts must include modalities as keys
+                which are defined in the self.modalities list.
 
        Returns:
-            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+            a FeatureMaps with one feature map from the encoder, at 1/16 of the input
+                resolution.
        """
        model_inputs = {}
        for modality in self.modalities:
            # We assume all the inputs include the same modalities
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                continue
-            cur = torch.stack([inp[modality] for inp in inputs], dim=0)  # (B, C, H, W)
+            cur = torch.stack(
+                [inp[modality] for inp in context.inputs], dim=0
+            )  # (B, C, H, W)
            if self.do_resizing and (
                cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
            ):
@@ -159,7 +166,17 @@ class Terramind(torch.nn.Module):
        image_features = self.model(model_inputs)[-1]
        batch_size, num_patches, _ = image_features.shape
        height, width = int(num_patches**0.5), int(num_patches**0.5)
-        return [rearrange(image_features, "b (h w) d -> b d h w", h=height, w=width)]
+        return FeatureMaps(
+            [
+                rearrange(
+                    image_features,
+                    "b (h w) d -> b d h w",
+                    b=batch_size,
+                    h=height,
+                    w=width,
+                )
+            ]
+        )
 
     def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.
rslearn/models/trunk.py CHANGED
@@ -7,6 +7,7 @@ import torch
 
 from rslearn.log_utils import get_logger
 from rslearn.models.task_embedding import BaseTaskEmbedding
+from rslearn.train.model_context import ModelOutput
 
 logger = get_logger(__name__)
 
@@ -32,10 +33,11 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
            dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
            and optionally other keys.
        """
+        raise NotImplementedError
 
     @abstractmethod
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs: dict[str, Any]
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
        """Apply auxiliary losses in-place.
 
@@ -43,6 +45,7 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
            trunk_out: The output of the trunk.
            outs: The output of the decoders, with key "loss_dict" containing the losses.
        """
+        raise NotImplementedError
 
 
 class DecoderTrunk(torch.nn.Module):
@@ -122,7 +125,7 @@ class DecoderTrunk(torch.nn.Module):
        return out
 
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs: dict[str, Any]
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
        """Apply auxiliary losses in-place.
 
@@ -130,7 +133,7 @@ class DecoderTrunk(torch.nn.Module):
 
        Args:
            trunk_out: The output of the trunk.
-            outs: The output of the decoders, with key "loss_dict" containing the losses.
+            outs: The output of the decoders.
        """
        for layer in self.layers:
            layer.apply_auxiliary_losses(trunk_out, outs)
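
A note on the raise NotImplementedError additions above: @abstractmethod already prevents instantiating DecoderTrunkLayer directly, but an empty abstract body would silently return None if a subclass ever delegated to it via super(). The explicit raise fails loudly instead. A small self-contained illustration with generic names (not rslearn code):

from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def run(self) -> int:
        raise NotImplementedError


class Child(Base):
    def run(self) -> int:
        # Uncommenting the next line would now raise NotImplementedError
        # instead of silently returning None:
        # return super().run()
        return 1


print(Child().run())  # prints 1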
rslearn/models/unet.py CHANGED
@@ -5,8 +5,15 @@ from typing import Any
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
 
-class UNetDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class UNetDecoder(IntermediateComponent):
     """UNet-style decoder.
 
     It inputs multi-scale features. Starting from last (lowest resolution) feature map,
@@ -143,23 +150,25 @@ class UNetDecoder(torch.nn.Module):
            align_corners=False,
        )
 
-    def forward(
-        self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Compute output from multi-scale feature map.
 
        Args:
-            in_features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
        Returns:
-            output image
+            output FeatureMaps consisting of one map. The embedding size is equal to the
+                configured out_channels.
        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to UNetDecoder must be a FeatureMaps")
+
        # Reverse the features since we will pass them in from lowest resolution to highest.
-        in_features = list(reversed(in_features))
+        in_features = list(reversed(intermediates.feature_maps))
        cur_features = self.layers[0](in_features[0])
        for in_feat, layer in zip(in_features[1:], self.layers[1:]):
            cur_features = layer(torch.cat([cur_features, in_feat], dim=1))
        if self.original_size_to_interpolate is not None:
            cur_features = self._resize(cur_features)
-        return cur_features
+        return FeatureMaps([cur_features])
rslearn/models/upsample.py CHANGED
@@ -1,9 +1,18 @@
 """An upsampling layer."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class Upsample(torch.nn.Module):
+
+class Upsample(IntermediateComponent):
     """Upsamples each input feature map by the same factor."""
 
     def __init__(
@@ -20,16 +29,20 @@ class Upsample(torch.nn.Module):
        super().__init__()
        self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
-        """Upsample each feature map.
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Upsample each feature map by scale_factor.
 
        Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
        Returns:
-            upsampled feature maps
+            upsampled feature maps.
        """
-        return [self.layer(feat_map) for feat_map in features]
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Upsample must be a FeatureMaps")
+
+        upsampled_feat_maps = [
+            self.layer(feat_map) for feat_map in intermediates.feature_maps
+        ]
+        return FeatureMaps(upsampled_feat_maps)
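
A runnable sketch of the new calling convention for Upsample; FeatureMapsSketch stands in for rslearn.models.component.FeatureMaps (only the .feature_maps attribute is taken from the diff, the rest is assumption):

from dataclasses import dataclass

import torch


@dataclass
class FeatureMapsSketch:
    feature_maps: list[torch.Tensor]


# Mirrors Upsample with scale_factor=2: each map doubles in height/width.
layer = torch.nn.Upsample(scale_factor=2, mode="nearest")
feats = FeatureMapsSketch([torch.randn(1, 8, 16, 16), torch.randn(1, 8, 8, 8)])
out = FeatureMapsSketch([layer(f) for f in feats.feature_maps])
print([tuple(f.shape) for f in out.feature_maps])
# [(1, 8, 32, 32), (1, 8, 16, 16)]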
rslearn/train/all_patches_dataset.py CHANGED
@@ -2,6 +2,7 @@
 
 import itertools
 from collections.abc import Iterable, Iterator
+from dataclasses import replace
 from typing import Any
 
 import shapely
@@ -9,6 +10,7 @@ import torch
 
 from rslearn.dataset import Window
 from rslearn.train.dataset import ModelDataset
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
@@ -218,7 +220,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
 
     def __iter__(
        self,
-    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]]:
+    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
        """Iterate over all patches in each element of the underlying ModelDataset."""
        # Iterate over the window IDs until we have returned enough samples.
        window_ids, num_samples_needed = self._get_worker_iteration_data()
@@ -229,7 +231,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
            raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                window_id
            )
-            bounds = metadata["bounds"]
+            bounds = metadata.patch_bounds
 
            # For simplicity, pad tensors by patch size to ensure that any patch bounds
            # extending outside the window bounds will not have issues when we slice
@@ -244,7 +246,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
            )
            for patch_idx, patch_bounds in enumerate(patches):
                cur_geom = STGeometry(
-                    metadata["projection"], shapely.box(*patch_bounds), None
+                    metadata.projection, shapely.box(*patch_bounds), None
                )
                start_offset = (
                    patch_bounds[0] - bounds[0],
@@ -282,10 +284,12 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                cur_passthrough_inputs = crop_input_dict(passthrough_inputs)
 
                # Adjust the metadata as well.
-                cur_metadata = metadata.copy()
-                cur_metadata["bounds"] = patch_bounds
-                cur_metadata["patch_idx"] = patch_idx
-                cur_metadata["num_patches"] = len(patches)
+                cur_metadata = replace(
+                    metadata,
+                    patch_bounds=patch_bounds,
+                    patch_idx=patch_idx,
+                    num_patches_in_window=len(patches),
+                )
 
                # Now we can compute input and target dicts via the task.
                input_dict, target_dict = self.dataset.task.process_inputs(
@@ -297,7 +301,6 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                input_dict, target_dict = self.dataset.transforms(
                    input_dict, target_dict
                )
-                input_dict["dataset_source"] = self.dataset.name
 
 
                if num_samples_returned < num_samples_needed:
@@ -346,7 +349,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
        )
        self.windows = self.dataset.get_dataset_examples()
        self.window_cache: dict[
-            int, tuple[dict[str, Any], dict[str, Any], dict[str, Any]]
+            int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
        ] = {}
 
        # Precompute the batch boundaries for each window
@@ -360,7 +363,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def get_raw_inputs(
        self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
        """Get the raw inputs for a single patch. Retrieve from cache if possible.
 
        Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
@@ -410,13 +413,13 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def __getitem__(
        self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
        """Return (input_dict, target_dict, metadata) for a single flattened patch."""
        (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
        raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
-        bounds = metadata["bounds"]
+        bounds = metadata.patch_bounds
 
-        cur_geom = STGeometry(metadata["projection"], shapely.box(*patch_bounds), None)
+        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
        start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
        end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
 
@@ -428,10 +431,12 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
        )
 
        # Adjust the metadata as well.
-        cur_metadata = metadata.copy()
-        cur_metadata["bounds"] = patch_bounds
-        cur_metadata["patch_idx"] = patch_idx
-        cur_metadata["num_patches"] = num_patches
+        cur_metadata = replace(
+            metadata,
+            patch_bounds=patch_bounds,
+            patch_idx=patch_idx,
+            num_patches_in_window=num_patches,
+        )
 
        # Now we can compute input and target dicts via the task.
        input_dict, target_dict = self.dataset.task.process_inputs(
@@ -441,7 +446,6 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
        )
        input_dict.update(cur_passthrough_inputs)
        input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.dataset.name
 
        return input_dict, target_dict, cur_metadata
 
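The switch from dict copies to dataclasses.replace above works because SampleMetadata is a dataclass: replace() returns a per-patch copy and leaves the cached window-level metadata untouched. A minimal sketch with an assumed subset of the fields (whether the real class is frozen is also an assumption):

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SampleMetadataSketch:
    window_name: str
    patch_bounds: tuple[int, int, int, int]
    patch_idx: int
    num_patches_in_window: int


meta = SampleMetadataSketch("w1", (0, 0, 256, 256), 0, 1)
patch_meta = replace(
    meta, patch_bounds=(0, 0, 128, 128), patch_idx=3, num_patches_in_window=4
)
assert meta.patch_idx == 0 and patch_meta.patch_idx == 3
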
rslearn/train/dataset.py CHANGED
@@ -20,13 +20,15 @@ from rslearn.config import (
     LayerConfig,
 )
 from rslearn.dataset.dataset import Dataset
+from rslearn.dataset.storage.file import FileWindowStorage
 from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
 from rslearn.log_utils import get_logger
-from rslearn.train.tasks import Task
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds
 from rslearn.utils.mp import star_imap_unordered
 
+from .model_context import SampleMetadata
+from .tasks import Task
 from .transforms import Sequential
 
 logger = get_logger(__name__)
@@ -575,37 +577,7 @@ class ModelDataset(torch.utils.data.Dataset):
        else:
            self.patch_size = split_config.get_patch_size()
 
-        if split_config.names:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups,
-                names=split_config.names,
-                show_progress=True,
-                workers=workers,
-            )
-        elif split_config.groups:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups, show_progress=True, workers=workers
-            )
-        else:
-            windows = self.dataset.load_windows(show_progress=True, workers=workers)
-
-        if split_config.tags:
-            # Filter the window.options.
-            new_windows = []
-            num_removed: dict[str, int] = {}
-            for window in windows:
-                for k, v in split_config.tags.items():
-                    if k not in window.options or (v and window.options[k] != v):
-                        num_removed[k] = num_removed.get(k, 0) + 1
-                        break
-                else:
-                    new_windows.append(window)
-            logger.info(
-                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
-            )
-            for k, v in num_removed.items():
-                logger.info(f"Removed {v} windows due to tag {k}")
-            windows = new_windows
+        windows = self._get_initial_windows(split_config, workers)
 
        # If targets are not needed, remove them from the inputs.
        if split_config.get_skip_targets():
@@ -615,17 +587,11 @@ class ModelDataset(torch.utils.data.Dataset):
 
        # Eliminate windows that are missing either a requisite input layer, or missing
        # all target layers.
-        # We use only main thread if the index is set, since that can take a long time
-        # to send to the worker threads, it may get serialized for each window.
        new_windows = []
-        if workers == 0 or (len(windows) >= 1 and windows[0].index is not None):
+        if workers == 0:
            for window in windows:
                if check_window(self.inputs, window) is None:
                    continue
-                # The index may be set, but now that this check is done, from here on
-                # we no longer need it. We set it None so that we don't end up passing
-                # it later to the dataloader workers.
-                window.index = None
                new_windows.append(window)
        else:
            p = multiprocessing.Pool(workers)
@@ -681,12 +647,62 @@ class ModelDataset(torch.utils.data.Dataset):
        with open(self.dataset_examples_fname, "w") as f:
            json.dump([self._serialize_item(example) for example in windows], f)
 
+    def _get_initial_windows(
+        self, split_config: SplitConfig, workers: int
+    ) -> list[Window]:
+        """Get the initial windows before input layer filtering.
+
+        The windows are filtered based on configured window names, groups, and tags.
+
+        This is a helper for the init function.
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+
+        Returns:
+            list of windows from the dataset after applying the aforementioned filters.
+        """
+        # Load windows from dataset.
+        # If the window storage is FileWindowStorage, we pass the workers/show_progress arguments.
+        kwargs: dict[str, Any] = {}
+        if isinstance(self.dataset.storage, FileWindowStorage):
+            kwargs["workers"] = workers
+            kwargs["show_progress"] = True
+        # We also add the name/group filters to the kwargs.
+        if split_config.names:
+            kwargs["names"] = split_config.names
+        if split_config.groups:
+            kwargs["groups"] = split_config.groups
+
+        windows = self.dataset.load_windows(**kwargs)
+
+        # Filter by tags (if provided) using the window.options.
+        if split_config.tags:
+            new_windows = []
+            num_removed: dict[str, int] = {}
+            for window in windows:
+                for k, v in split_config.tags.items():
+                    if k not in window.options or (v and window.options[k] != v):
+                        num_removed[k] = num_removed.get(k, 0) + 1
+                        break
+                else:
+                    new_windows.append(window)
+            logger.info(
+                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
+            )
+            for k, v in num_removed.items():
+                logger.info(f"Removed {v} windows due to tag {k}")
+            windows = new_windows
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
        return example.get_metadata()
 
     def _deserialize_item(self, d: dict[str, Any]) -> Window:
        return Window.from_metadata(
-            Window.get_window_root(self.dataset.path, d["group"], d["name"]),
+            self.dataset.storage,
            d,
        )
 
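The dispatch inside the new _get_initial_windows helper only forwards workers/show_progress when the storage is file-based. A hedged stand-alone sketch of that pattern; the stub classes here are hypothetical, not the API of the new rslearn/dataset/storage/ package:

from typing import Any


class WindowStorageStub:
    def load_windows(self, **kwargs: Any) -> list[str]:
        # Stand-in that just echoes which kwargs were forwarded.
        return sorted(kwargs)


class FileWindowStorageStub(WindowStorageStub):
    pass


def get_initial_windows(storage: WindowStorageStub, workers: int) -> list[str]:
    kwargs: dict[str, Any] = {}
    if isinstance(storage, FileWindowStorageStub):
        # Only the file-backed storage accepts these tuning arguments.
        kwargs["workers"] = workers
        kwargs["show_progress"] = True
    return storage.load_windows(**kwargs)


print(get_initial_windows(FileWindowStorageStub(), workers=4))
# ['show_progress', 'workers']
print(get_initial_windows(WindowStorageStub(), workers=4))
# []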
@@ -713,7 +729,7 @@ class ModelDataset(torch.utils.data.Dataset):
 
     def get_raw_inputs(
        self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
        """Get the raw inputs and base metadata for this example.
 
        This is the raster or vector data before being processed by the Task. So it
@@ -775,21 +791,23 @@ class ModelDataset(torch.utils.data.Dataset):
            if data_input.passthrough:
                passthrough_inputs[name] = raw_inputs[name]
 
-        metadata = {
-            "group": window.group,
-            "window_name": window.name,
-            "window_bounds": window.bounds,
-            "bounds": bounds,
-            "time_range": window.time_range,
-            "projection": window.projection,
-            "dataset_source": self.name,
-        }
+        metadata = SampleMetadata(
+            window_group=window.group,
+            window_name=window.name,
+            window_bounds=window.bounds,
+            patch_bounds=bounds,
+            patch_idx=0,
+            num_patches_in_window=1,
+            time_range=window.time_range,
+            projection=window.projection,
+            dataset_source=self.name,
+        )
 
        return raw_inputs, passthrough_inputs, metadata
 
     def __getitem__(
        self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
        """Read one training example.
 
        Args:
@@ -801,8 +819,6 @@ class ModelDataset(torch.utils.data.Dataset):
        logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)
 
        raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(idx)
-        metadata["patch_idx"] = 0
-        metadata["num_patches"] = 1
 
        input_dict, target_dict = self.task.process_inputs(
            raw_inputs,
@@ -811,7 +827,6 @@ class ModelDataset(torch.utils.data.Dataset):
        )
        input_dict.update(passthrough_inputs)
        input_dict, target_dict = self.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.name
 
        logger.debug("__getitem__ finish pid=%d item_idx=%d", os.getpid(), idx)