rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/ssl4eo_s12.py CHANGED
@@ -1,12 +1,14 @@
 """SSL4EO-S12 models."""
 
-from typing import Any
-
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class Ssl4eoS12(torch.nn.Module):
+class Ssl4eoS12(FeatureExtractor):
     """The SSL4EO-S12 family of pretrained models."""
 
     def __init__(
@@ -74,19 +76,22 @@ class Ssl4eoS12(torch.nn.Module):
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the images to process.
+
+        Returns:
+            feature maps computed by the pre-trained model.
         """
-        x = torch.stack([inp["image"] for inp in inputs], dim=0)
+        x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
@@ -97,4 +102,4 @@ class Ssl4eoS12(torch.nn.Module):
         layer3 = self.model.layer3(layer2)
         layer4 = self.model.layer4(layer3)
         all_features = [layer1, layer2, layer3, layer4]
-        return [all_features[idx] for idx in self.output_layers]
+        return FeatureMaps([all_features[idx] for idx in self.output_layers])
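Across these backbone files, 0.0.19 replaces the raw list of input dicts with a single ModelContext argument and wraps backbone outputs in FeatureMaps. The sketch below is illustrative only: it assumes that context.inputs is a list of per-sample dicts and that FeatureMaps wraps a list of tensors (as the changes above suggest), the absolute import path rslearn.models.component is inferred from the relative import in the diff, and ToyBackbone is a made-up class, not part of rslearn.

# Hedged sketch of the new forward(context) -> FeatureMaps convention; not rslearn code.
import torch

from rslearn.models.component import FeatureExtractor, FeatureMaps
from rslearn.train.model_context import ModelContext


class ToyBackbone(FeatureExtractor):
    """Hypothetical backbone following the 0.0.19 component interface."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Stack the per-sample "image" tensors into a batch, as Ssl4eoS12 does.
        x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        # Return a single-scale feature map wrapped in FeatureMaps.
        return FeatureMaps([self.conv(x)])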
rslearn/models/swin.py CHANGED
@@ -1,7 +1,5 @@
 """Swin Transformer."""
 
-from typing import Any
-
 import torch
 import torchvision
 from torchvision.models.swin_transformer import (
@@ -13,8 +11,12 @@ from torchvision.models.swin_transformer import (
     Swin_V2_T_Weights,
 )
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps, FeatureVector
+
 
-class Swin(torch.nn.Module):
+class Swin(FeatureExtractor):
     """A Swin Transformer model.
 
     It can either be used stand-alone for classification, or as a feature extractor in
@@ -34,9 +36,12 @@ class Swin(torch.nn.Module):
         Args:
             arch: the architecture, e.g. "swin_v2_b" (default) or "swin_t"
             pretrained: set True to use ImageNet pre-trained weights
-            input_channels: number of input channels (default 3)
+            input_channels: number of input channels (default 3). If not 3, the first
+                layer is updated and will be randomly initialized even if pretrained is
+                set.
             output_layers: list of layers to output, default use as classification
-                model. For feature extraction, [1, 3, 5, 7] is recommended.
+                model (output FeatureVector). For feature extraction, [1, 3, 5, 7] is
+                recommended.
             num_outputs: number of output logits, defaults to 1000 which matches the
                 pretrained models.
         """
@@ -130,19 +135,23 @@ class Swin(torch.nn.Module):
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+        context: ModelContext,
+    ) -> FeatureVector | FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process.
+
+        Returns:
+            a FeatureVector if the configured output_layers is None, or a FeatureMaps
+            otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
 
         if self.output_layers:
             layer_features = []
@@ -150,7 +159,7 @@ class Swin(torch.nn.Module):
             for layer in self.model.features:
                 x = layer(x)
                 layer_features.append(x.permute(0, 3, 1, 2))
-            return [layer_features[idx] for idx in self.output_layers]
+            return FeatureMaps([layer_features[idx] for idx in self.output_layers])
 
         else:
-            return self.model(images)
+            return FeatureVector(self.model(images))
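The updated docstring above spells out the constructor surface; a hedged usage sketch follows (argument values are examples, and any argument not shown is assumed to keep its default):

from rslearn.models.swin import Swin

# Feature-extractor configuration: forward returns a FeatureMaps with one map per
# listed layer index.
backbone = Swin(arch="swin_v2_b", pretrained=True, output_layers=[1, 3, 5, 7])

# Stand-alone classifier: with output_layers left unset, forward returns a
# FeatureVector of logits (1000 by default, matching the pretrained weights).
classifier = Swin(arch="swin_v2_b", pretrained=True)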
rslearn/models/terramind.py CHANGED
@@ -8,8 +8,11 @@ import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 # TerraMind v1 provides two sizes: base and large
 class TerramindSize(str, Enum):
@@ -85,7 +88,7 @@ PRETRAINED_BANDS = {
 }
 
 
-class Terramind(torch.nn.Module):
+class Terramind(FeatureExtractor):
     """Terramind backbones."""
 
     def __init__(
@@ -123,21 +126,25 @@ class Terramind(torch.nn.Module):
         self.modalities = modalities
         self.do_resizing = do_resizing
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass for the Terramind model.
 
         Args:
-            inputs: input dicts that must include modalities as keys which are defined in the self.modalities list
+            context: the model context. Input dicts must include modalities as keys
+                which are defined in the self.modalities list.
 
         Returns:
-            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+            a FeatureMaps with one feature map from the encoder, at 1/16 of the input
+            resolution.
         """
         model_inputs = {}
         for modality in self.modalities:
             # We assume the all the inputs include the same modalities
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                 continue
-            cur = torch.stack([inp[modality] for inp in inputs], dim=0)  # (B, C, H, W)
+            cur = torch.stack(
+                [inp[modality] for inp in context.inputs], dim=0
+            )  # (B, C, H, W)
             if self.do_resizing and (
                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
             ):
@@ -159,7 +166,17 @@ class Terramind(torch.nn.Module):
         image_features = self.model(model_inputs)[-1]
         batch_size, num_patches, _ = image_features.shape
         height, width = int(num_patches**0.5), int(num_patches**0.5)
-        return [rearrange(image_features, "b (h w) d -> b d h w", h=height, w=width)]
+        return FeatureMaps(
+            [
+                rearrange(
+                    image_features,
+                    "b (h w) d -> b d h w",
+                    b=batch_size,
+                    h=height,
+                    w=width,
+                )
+            ]
+        )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/trunk.py CHANGED
@@ -7,6 +7,7 @@ import torch
 
 from rslearn.log_utils import get_logger
 from rslearn.models.task_embedding import BaseTaskEmbedding
+from rslearn.train.model_context import ModelOutput
 
 logger = get_logger(__name__)
 
@@ -32,10 +33,11 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
             dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
             and optionally other keys.
         """
+        raise NotImplementedError
 
     @abstractmethod
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs: dict[str, Any]
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
         """Apply auxiliary losses in-place.
 
@@ -43,6 +45,7 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
             trunk_out: The output of the trunk.
             outs: The output of the decoders, with key "loss_dict" containing the losses.
         """
+        raise NotImplementedError
 
 
 class DecoderTrunk(torch.nn.Module):
@@ -122,7 +125,7 @@
         return out
 
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs: dict[str, Any]
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
         """Apply auxiliary losses in-place.
 
@@ -130,7 +133,7 @@
 
         Args:
             trunk_out: The output of the trunk.
-            outs: The output of the decoders, with key "loss_dict" containing the losses.
+            outs: The output of the decoders.
         """
         for layer in self.layers:
             layer.apply_auxiliary_losses(trunk_out, outs)
rslearn/models/unet.py CHANGED
@@ -5,8 +5,15 @@ from typing import Any
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
 
-class UNetDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class UNetDecoder(IntermediateComponent):
     """UNet-style decoder.
 
     It inputs multi-scale features. Starting from last (lowest resolution) feature map,
@@ -143,23 +150,25 @@ class UNetDecoder(torch.nn.Module):
             align_corners=False,
         )
 
-    def forward(
-        self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute output from multi-scale feature map.
 
         Args:
-            in_features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            output image
+            output FeatureMaps consisting of one map. The embedding size is equal to the
+                configured out_channels.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to UNetDecoder must be a FeatureMaps")
+
         # Reverse the features since we will pass them in from lowest resolution to highest.
-        in_features = list(reversed(in_features))
+        in_features = list(reversed(intermediates.feature_maps))
         cur_features = self.layers[0](in_features[0])
         for in_feat, layer in zip(in_features[1:], self.layers[1:]):
             cur_features = layer(torch.cat([cur_features, in_feat], dim=1))
         if self.original_size_to_interpolate is not None:
             cur_features = self._resize(cur_features)
-        return cur_features
+        return FeatureMaps([cur_features])
rslearn/models/upsample.py CHANGED
@@ -1,9 +1,18 @@
 """An upsampling layer."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class Upsample(torch.nn.Module):
+
+class Upsample(IntermediateComponent):
     """Upsamples each input feature map by the same factor."""
 
     def __init__(
@@ -20,16 +29,20 @@ class Upsample(torch.nn.Module):
         super().__init__()
         self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
-        """Upsample each feature map.
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Upsample each feature map by scale_factor.
 
         Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            upsampled feature maps
+            upsampled feature maps.
         """
-        return [self.layer(feat_map) for feat_map in features]
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Upsample must be a FeatureMaps")
+
+        upsampled_feat_maps = [
+            self.layer(feat_map) for feat_map in intermediates.feature_maps
+        ]
+        return FeatureMaps(upsampled_feat_maps)
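UNetDecoder and Upsample now follow the same IntermediateComponent contract: they receive the previous component's output plus the ModelContext and return a FeatureMaps. The sketch below exercises Upsample on dummy tensors; it assumes the scale_factor/mode constructor parameters implied by the diff, and passes None in place of a real ModelContext because Upsample's forward never reads it (how a ModelContext is actually constructed is not shown in this diff).

import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.upsample import Upsample

feature_maps = FeatureMaps([torch.randn(1, 32, 16, 16), torch.randn(1, 64, 8, 8)])
up = Upsample(scale_factor=2, mode="nearest")
out = up(feature_maps, None)  # placeholder in place of a ModelContext
print([f.shape for f in out.feature_maps])
# [torch.Size([1, 32, 32, 32]), torch.Size([1, 64, 16, 16])]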
rslearn/train/all_patches_dataset.py CHANGED
@@ -2,13 +2,15 @@
 
 import itertools
 from collections.abc import Iterable, Iterator
+from dataclasses import replace
 from typing import Any
 
 import shapely
 import torch
 
 from rslearn.dataset import Window
-from rslearn.train.dataset import ModelDataset
+from rslearn.train.dataset import DataInput, ModelDataset
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
@@ -32,22 +34,28 @@ def get_window_patch_options(
        bottommost patches may extend beyond the provided bounds.
     """
     # We stride the patches by patch_size - overlap_size until the last patch.
+    # We handle the first patch with a special case to ensure it is always used.
     # We handle the last patch with a special case to ensure it does not exceed the
     # window bounds. Instead, it may overlap the previous patch.
-    cols = list(
+    cols = [bounds[0]] + list(
         range(
-            bounds[0],
+            bounds[0] + patch_size[0],
             bounds[2] - patch_size[0],
             patch_size[0] - overlap_size[0],
         )
-    ) + [bounds[2] - patch_size[0]]
-    rows = list(
+    )
+    rows = [bounds[1]] + list(
         range(
-            bounds[1],
+            bounds[1] + patch_size[1],
             bounds[3] - patch_size[1],
             patch_size[1] - overlap_size[1],
        )
-    ) + [bounds[3] - patch_size[1]]
+    )
+    # Add last patches only if the input is larger than one patch.
+    if bounds[2] - patch_size[0] > bounds[0]:
+        cols.append(bounds[2] - patch_size[0])
+    if bounds[3] - patch_size[1] > bounds[1]:
+        rows.append(bounds[3] - patch_size[1])
 
     patch_bounds: list[PixelBounds] = []
     for col in cols:
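The revised tiling always emits the top-left patch first and appends the right/bottom edge patches only when the window is strictly larger than one patch, so a window no bigger than the patch size yields exactly one patch. A standalone rendition of the column/row computation for a concrete case (plain Python mirroring the logic added above; the bounds and sizes are example values):

bounds = (0, 0, 100, 100)  # (col_min, row_min, col_max, row_max)
patch_size = (64, 64)
overlap_size = (0, 0)

cols = [bounds[0]] + list(
    range(bounds[0] + patch_size[0], bounds[2] - patch_size[0], patch_size[0] - overlap_size[0])
)
rows = [bounds[1]] + list(
    range(bounds[1] + patch_size[1], bounds[3] - patch_size[1], patch_size[1] - overlap_size[1])
)
# Edge patches are added only when the window exceeds one patch in that dimension.
if bounds[2] - patch_size[0] > bounds[0]:
    cols.append(bounds[2] - patch_size[0])
if bounds[3] - patch_size[1] > bounds[1]:
    rows.append(bounds[3] - patch_size[1])

print(cols, rows)  # [0, 36] [0, 36] -> four 64x64 patches, the last ones flush with the window edge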
@@ -60,13 +68,17 @@ def pad_slice_protect(
     raw_inputs: dict[str, Any],
     passthrough_inputs: dict[str, Any],
     patch_size: tuple[int, int],
+    inputs: dict[str, DataInput],
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
 
+    The padding is scaled based on each input's resolution_factor.
+
     Args:
         raw_inputs: the raw inputs to pad.
         passthrough_inputs: the passthrough inputs to pad.
-        patch_size: the size of the patches to extract.
+        patch_size: the size of the patches to extract (at window resolution).
+        inputs: the DataInput definitions, used to get resolution_factor per input.
 
     Returns:
         a tuple of (raw_inputs, passthrough_inputs).
@@ -75,8 +87,14 @@ def pad_slice_protect(
         for input_name, value in list(d.items()):
            if not isinstance(value, torch.Tensor):
                 continue
+            # Get resolution scale for this input
+            rf = inputs[input_name].resolution_factor
+            scale = rf.numerator / rf.denominator
+            # Scale the padding amount
+            scaled_pad_x = int(patch_size[0] * scale)
+            scaled_pad_y = int(patch_size[1] * scale)
             d[input_name] = torch.nn.functional.pad(
-                value, pad=(0, patch_size[0], 0, patch_size[1])
+                value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
             )
     return raw_inputs, passthrough_inputs
 
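Padding (and, further below, crop offsets) are now scaled by each input's resolution_factor, which the code reads as a numerator/denominator pair. A small arithmetic sketch of that scaling, assuming resolution_factor behaves like fractions.Fraction (the diff does not show its actual type):

from fractions import Fraction

patch_size = (128, 128)              # patch size at window resolution
resolution_factor = Fraction(1, 2)   # e.g. an input stored at half the window resolution

scale = resolution_factor.numerator / resolution_factor.denominator
scaled_pad = (int(patch_size[0] * scale), int(patch_size[1] * scale))
print(scaled_pad)  # (64, 64): the half-resolution tensor is padded by 64 pixels, not 128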
@@ -121,6 +139,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         self.rank = rank
         self.world_size = world_size
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs
 
     def set_name(self, name: str) -> None:
         """Sets dataset name.
@@ -218,7 +237,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
 
     def __iter__(
         self,
-    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]]:
+    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
         """Iterate over all patches in each element of the underlying ModelDataset."""
         # Iterate over the window IDs until we have returned enough samples.
         window_ids, num_samples_needed = self._get_worker_iteration_data()
@@ -229,12 +248,14 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                 window_id
             )
-            bounds = metadata["bounds"]
+            bounds = metadata.patch_bounds
 
             # For simplicity, pad tensors by patch size to ensure that any patch bounds
             # extending outside the window bounds will not have issues when we slice
-            # the tensors later.
-            pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+            # the tensors later. Padding is scaled per-input based on resolution_factor.
+            pad_slice_protect(
+                raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+            )
 
             # Now iterate over the patches and extract/yield the crops.
             # Note that, in case user is leveraging RslearnWriter, it is important that
@@ -244,7 +265,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             )
             for patch_idx, patch_bounds in enumerate(patches):
                 cur_geom = STGeometry(
-                    metadata["projection"], shapely.box(*patch_bounds), None
+                    metadata.projection, shapely.box(*patch_bounds), None
                 )
                 start_offset = (
                     patch_bounds[0] - bounds[0],
@@ -256,15 +277,28 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 )
 
                 # Define a helper function to handle each input dict.
+                # Crop coordinates are scaled based on each input's resolution_factor.
                 def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
                     cropped = {}
                     for input_name, value in d.items():
                         if isinstance(value, torch.Tensor):
-                            # Crop the CHW tensor.
+                            # Get resolution scale for this input
+                            rf = self.inputs[input_name].resolution_factor
+                            scale = rf.numerator / rf.denominator
+                            # Scale the crop coordinates
+                            scaled_start = (
+                                int(start_offset[0] * scale),
+                                int(start_offset[1] * scale),
+                            )
+                            scaled_end = (
+                                int(end_offset[0] * scale),
+                                int(end_offset[1] * scale),
+                            )
+                            # Crop the CHW tensor with scaled coordinates.
                             cropped[input_name] = value[
                                 :,
-                                start_offset[1] : end_offset[1],
-                                start_offset[0] : end_offset[0],
+                                scaled_start[1] : scaled_end[1],
+                                scaled_start[0] : scaled_end[0],
                             ].clone()
                         elif isinstance(value, list):
                             cropped[input_name] = [
@@ -282,10 +316,12 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 cur_passthrough_inputs = crop_input_dict(passthrough_inputs)
 
                 # Adjust the metadata as well.
-                cur_metadata = metadata.copy()
-                cur_metadata["bounds"] = patch_bounds
-                cur_metadata["patch_idx"] = patch_idx
-                cur_metadata["num_patches"] = len(patches)
+                cur_metadata = replace(
+                    metadata,
+                    patch_bounds=patch_bounds,
+                    patch_idx=patch_idx,
+                    num_patches_in_window=len(patches),
+                )
 
                 # Now we can compute input and target dicts via the task.
                 input_dict, target_dict = self.dataset.task.process_inputs(
@@ -297,7 +333,6 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 input_dict, target_dict = self.dataset.transforms(
                     input_dict, target_dict
                 )
-                input_dict["dataset_source"] = self.dataset.name
 
                 if num_samples_returned < num_samples_needed:
                     yield input_dict, target_dict, cur_metadata
@@ -345,8 +380,9 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             round(self.patch_size[1] * overlap_ratio),
         )
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs
         self.window_cache: dict[
-            int, tuple[dict[str, Any], dict[str, Any], dict[str, Any]]
+            int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
         ] = {}
 
         # Precompute the batch boundaries for each window
@@ -360,7 +396,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def get_raw_inputs(
         self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Get the raw inputs for a single patch. Retrieve from cache if possible.
 
         Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
@@ -375,26 +411,41 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             return self.window_cache[index]
 
         raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
-        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
 
         self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
         return self.window_cache[index]
 
-    @staticmethod
     def _crop_input_dict(
+        self,
         d: dict[str, Any],
         start_offset: tuple[int, int],
         end_offset: tuple[int, int],
         cur_geom: STGeometry,
     ) -> dict[str, Any]:
-        """Crop a dictionary of inputs to the given bounds."""
+        """Crop a dictionary of inputs to the given bounds.
+
+        Crop coordinates are scaled based on each input's resolution_factor.
+        """
         cropped = {}
         for input_name, value in d.items():
            if isinstance(value, torch.Tensor):
+                # Get resolution scale for this input
+                rf = self.inputs[input_name].resolution_factor
+                scale = rf.numerator / rf.denominator
+                # Scale the crop coordinates
+                scaled_start = (
+                    int(start_offset[0] * scale),
+                    int(start_offset[1] * scale),
+                )
+                scaled_end = (
+                    int(end_offset[0] * scale),
+                    int(end_offset[1] * scale),
+                )
                 cropped[input_name] = value[
                     :,
-                    start_offset[1] : end_offset[1],
-                    start_offset[0] : end_offset[0],
+                    scaled_start[1] : scaled_end[1],
+                    scaled_start[0] : scaled_end[0],
                 ].clone()
             elif isinstance(value, list):
                 cropped[input_name] = [
@@ -410,13 +461,13 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def __getitem__(
         self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Return (input_dict, target_dict, metadata) for a single flattened patch."""
         (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
-        bounds = metadata["bounds"]
+        bounds = metadata.patch_bounds
 
-        cur_geom = STGeometry(metadata["projection"], shapely.box(*patch_bounds), None)
+        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
         start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
         end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
 
@@ -428,10 +479,12 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
 
         # Adjust the metadata as well.
-        cur_metadata = metadata.copy()
-        cur_metadata["bounds"] = patch_bounds
-        cur_metadata["patch_idx"] = patch_idx
-        cur_metadata["num_patches"] = num_patches
+        cur_metadata = replace(
+            metadata,
+            patch_bounds=patch_bounds,
+            patch_idx=patch_idx,
+            num_patches_in_window=num_patches,
+        )
 
         # Now we can compute input and target dicts via the task.
         input_dict, target_dict = self.dataset.task.process_inputs(
@@ -441,7 +494,6 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
         input_dict.update(cur_passthrough_inputs)
         input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.dataset.name
 
         return input_dict, target_dict, cur_metadata
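The per-window metadata dict is now a SampleMetadata dataclass, and per-patch copies are produced with dataclasses.replace rather than dict mutation. A minimal illustration of that pattern follows; PatchMeta is a hypothetical stand-in carrying only the fields referenced in this diff, while the real SampleMetadata lives in rslearn.train.model_context and may have more:

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PatchMeta:
    patch_bounds: tuple[int, int, int, int]
    patch_idx: int
    num_patches_in_window: int


window_meta = PatchMeta(patch_bounds=(0, 0, 100, 100), patch_idx=0, num_patches_in_window=1)
# replace() returns a new instance with only the named fields overridden.
patch_meta = replace(
    window_meta, patch_bounds=(0, 0, 64, 64), patch_idx=2, num_patches_in_window=4
)
print(patch_meta)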