rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/dataset.py +15 -16
  3. rslearn/dataset/dataset.py +28 -22
  4. rslearn/lightning_cli.py +22 -11
  5. rslearn/main.py +1 -1
  6. rslearn/models/anysat.py +35 -33
  7. rslearn/models/attention_pooling.py +177 -0
  8. rslearn/models/clip.py +5 -2
  9. rslearn/models/component.py +12 -0
  10. rslearn/models/croma.py +11 -3
  11. rslearn/models/dinov3.py +2 -1
  12. rslearn/models/faster_rcnn.py +2 -1
  13. rslearn/models/galileo/galileo.py +58 -31
  14. rslearn/models/module_wrapper.py +6 -1
  15. rslearn/models/molmo.py +4 -2
  16. rslearn/models/olmoearth_pretrain/model.py +206 -51
  17. rslearn/models/olmoearth_pretrain/norm.py +5 -3
  18. rslearn/models/panopticon.py +3 -1
  19. rslearn/models/presto/presto.py +45 -15
  20. rslearn/models/prithvi.py +9 -7
  21. rslearn/models/sam2_enc.py +3 -1
  22. rslearn/models/satlaspretrain.py +4 -1
  23. rslearn/models/simple_time_series.py +43 -17
  24. rslearn/models/ssl4eo_s12.py +19 -14
  25. rslearn/models/swin.py +3 -1
  26. rslearn/models/terramind.py +5 -4
  27. rslearn/train/all_patches_dataset.py +96 -28
  28. rslearn/train/dataset.py +102 -53
  29. rslearn/train/model_context.py +35 -1
  30. rslearn/train/scheduler.py +15 -0
  31. rslearn/train/tasks/classification.py +8 -2
  32. rslearn/train/tasks/detection.py +3 -2
  33. rslearn/train/tasks/multi_task.py +2 -3
  34. rslearn/train/tasks/per_pixel_regression.py +14 -5
  35. rslearn/train/tasks/regression.py +8 -2
  36. rslearn/train/tasks/segmentation.py +13 -4
  37. rslearn/train/tasks/task.py +2 -2
  38. rslearn/train/transforms/concatenate.py +45 -5
  39. rslearn/train/transforms/crop.py +22 -8
  40. rslearn/train/transforms/flip.py +13 -5
  41. rslearn/train/transforms/mask.py +11 -2
  42. rslearn/train/transforms/normalize.py +46 -15
  43. rslearn/train/transforms/pad.py +15 -3
  44. rslearn/train/transforms/resize.py +83 -0
  45. rslearn/train/transforms/select_bands.py +11 -2
  46. rslearn/train/transforms/sentinel1.py +18 -3
  47. rslearn/utils/geometry.py +73 -0
  48. rslearn/utils/jsonargparse.py +66 -0
  49. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
  50. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
  51. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
  52. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
  53. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
  54. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
  55. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/models/presto/presto.py CHANGED
@@ -2,6 +2,7 @@
  
  import logging
  import tempfile
+ from datetime import datetime
  
  import torch
  from einops import rearrange, repeat
@@ -118,21 +119,21 @@ class Presto(FeatureExtractor):
  of each timestep for that pixel
  """
  bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
- hs = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
- ws = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+ ts = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
+ hs = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+ ws = [x.shape[4] for x in [s1, s2, era5, srtm] if x is not None]
  devices = [x.device for x in [s1, s2, era5, srtm] if x is not None]
  
  assert len(set(bs)) == 1
  assert len(set(hs)) == 1
  assert len(set(ws)) == 1
  assert len(set(devices)) == 1
- b, h, w, device = bs[0], hs[0], ws[0], devices[0]
-
+ assert len(set(ts)) == 1
+ b, h, w, t, device = bs[0], hs[0], ws[0], ts[0], devices[0]
  # these values will be initialized as
  # we iterate through the data
  x: torch.Tensor | None = None
  mask: torch.Tensor | None = None
- t: int | None = None
  
  for band_group in [
  (s1, s1_bands),
@@ -146,14 +147,7 @@ class Presto(FeatureExtractor):
  else:
  continue
  
- m_t = data.shape[1] // len(input_bands)
- if t is None:
- t = m_t
- else:
- if t != m_t:
- raise ValueError("inconsistent values for t")
-
- data = rearrange(data, "b (t c) h w -> b t h w c", t=m_t)
+ data = rearrange(data, "b c t h w -> b t h w c")
  if x is None:
  x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)
  if mask is None:
@@ -184,6 +178,23 @@
  x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
  return x, mask, dynamic_world.long(), months.long()
  
+ @staticmethod
+ def time_ranges_to_timestamps(
+ time_ranges: list[tuple[datetime, datetime]],
+ device: torch.device,
+ ) -> torch.Tensor:
+ """Turn the time ranges stored in a RasterImage to timestamps accepted by Presto.
+
+ Presto only uses the month associated with each timestamp, so we take the midpoint
+ the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+ time so that start_time == end_time == mid_time.
+ """
+ mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+ # months are indexed 0-11
+ return torch.tensor(
+ [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+ )
+
  def forward(self, context: ModelContext) -> FeatureMaps:
  """Compute feature maps from the Presto backbone.
  
@@ -194,17 +205,36 @@
  a FeatureMaps with one feature map that is at the same resolution as the
  input (since Presto operates per-pixel).
  """
+ time_modalities = ["s1", "s2", "era5"]
  stacked_inputs = {}
  latlons: torch.Tensor | None = None
+ months: torch.Tensor | None = None
  for key in context.inputs[0].keys():
  # assume all the keys in an input are consistent
  if key in self.input_keys:
  if key == "latlon":
- latlons = torch.stack([inp[key] for inp in context.inputs], dim=0)
+ latlons = torch.stack(
+ [inp[key].image for inp in context.inputs], dim=0
+ )
  else:
  stacked_inputs[key] = torch.stack(
- [inp[key] for inp in context.inputs], dim=0
+ [inp[key].image for inp in context.inputs], dim=0
  )
+ if key in time_modalities:
+ if months is None:
+ if context.inputs[0][key].timestamps is not None:
+ months = torch.stack(
+ [
+ self.time_ranges_to_timestamps(
+ inp[key].timestamps, # type: ignore
+ device=stacked_inputs[key].device,
+ )
+ for inp in context.inputs
+ ],
+ dim=0,
+ )
+ if months is not None:
+ stacked_inputs["months"] = months
  
  (
  x,
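A quick standalone check of the month conversion added in time_ranges_to_timestamps above (the dates are made up for illustration): Presto only consumes the 0-indexed month of each time range's midpoint, so a single-instant Sentinel-2 acquisition and a month-long ERA5 range each reduce to one integer.

    from datetime import datetime

    import torch

    # Hypothetical time ranges: an instantaneous Sentinel-2 image and a
    # month-long ERA5 composite.
    time_ranges = [
        (datetime(2022, 7, 15), datetime(2022, 7, 15)),
        (datetime(2022, 3, 1), datetime(2022, 3, 31)),
    ]
    # Midpoint of each range, then 0-indexed month, mirroring the added method.
    mid_ranges = [start + (end - start) / 2 for start, end in time_ranges]
    months = torch.tensor([d.month - 1 for d in mid_ranges], dtype=torch.int32)
    print(months)  # tensor([6, 2], dtype=torch.int32)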
rslearn/models/prithvi.py CHANGED
@@ -144,13 +144,15 @@ class PrithviV2(FeatureExtractor):
  """Process individual modality data.
  
  Args:
- data: Input tensor of shape [B, C, H, W]
+ data: Input tensor of shape [B, C, T, H, W]
  
  Returns:
- list of tensors of shape [B, C, H, W]
+ list of tensors of shape [B, C, T, H, W]
  """
  # Get original dimensions
- original_height = data.shape[2]
+ B, C, T, H, W = data.shape
+ data = rearrange(data, "b c t h w -> b (c t) h w")
+ original_height = H
  new_height = self.patch_size if original_height == 1 else self.image_resolution
  data = F.interpolate(
  data,
@@ -158,6 +160,7 @@
  mode="bilinear",
  align_corners=False,
  )
+ data = rearrange(data, "b (c t) h w -> b c t h w", c=C, t=T)
  return data
  
  def forward(self, context: ModelContext) -> FeatureMaps:
@@ -171,17 +174,16 @@
  a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
  feature maps across the 11 transformer blocks.
  """
- x = torch.stack([inp[self.INPUT_KEY] for inp in context.inputs], dim=0)
+ # x has shape BCTHW
+ x = torch.stack([inp[self.INPUT_KEY].image for inp in context.inputs], dim=0)
  x = self._resize_data(x)
- num_timesteps = x.shape[1] // len(self.bands)
- x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
  features = self.model.encoder.forward_features(x)
  # prepare_features_for_image_model was slightly modified since we already
  # know the number of timesteps and don't need to recompute it.
  # in addition we average along the time dimension (instead of concatenating)
  # to keep the embeddings reasonably sized.
  result = self.model.encoder.prepare_features_for_image_model(
- features, num_timesteps
+ features, x.shape[2]
  )
  return FeatureMaps([torch.cat(result, dim=1)])
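The reshaping around F.interpolate above folds the time axis into channels so a single 2D bilinear resize handles every timestep, then restores the [B, C, T, H, W] layout. A minimal sketch with made-up sizes:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    x = torch.randn(2, 6, 4, 32, 32)  # B, C, T, H, W
    B, C, T, H, W = x.shape
    # Fold time into channels so 2D interpolation resizes all timesteps at once.
    flat = rearrange(x, "b c t h w -> b (c t) h w")
    flat = F.interpolate(flat, size=(224, 224), mode="bilinear", align_corners=False)
    # Restore the separate channel and time axes.
    out = rearrange(flat, "b (c t) h w -> b c t h w", c=C, t=T)
    print(out.shape)  # torch.Size([2, 6, 4, 224, 224])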
187
189
 
rslearn/models/sam2_enc.py CHANGED
@@ -95,7 +95,9 @@ class SAM2Encoder(FeatureExtractor):
  Returns:
  feature maps from the encoder.
  """
- images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+ images = torch.stack(
+ [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+ )
  features = self.encoder(images)
  return FeatureMaps(features)
rslearn/models/satlaspretrain.py CHANGED
@@ -76,7 +76,10 @@ class SatlasPretrain(FeatureExtractor):
  Returns:
  multi-resolution feature maps computed by the model.
  """
- images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+ # take the first (assumed to be only) timestep
+ images = torch.stack(
+ [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+ )
  feature_maps = self.model(self.maybe_resize(images))
  return FeatureMaps(feature_maps)
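Several single-image backbones (SAM2, SatlasPretrain, and below Swin, SSL4EO-S12, and Terramind) now call single_ts_to_chw_tensor() on the RasterImage input. The diff does not include its definition; the sketch below only illustrates the assumed behavior, consistent with the "take the first (assumed to be only) timestep" comment: drop the time axis of a (C, T, H, W) raster to get a (C, H, W) tensor.

    import torch

    # Assumed behavior only; the real method lives on RasterImage in
    # rslearn.train.model_context and may differ in detail.
    def single_ts_to_chw_tensor(image_cthw: torch.Tensor) -> torch.Tensor:
        if image_cthw.shape[1] != 1:
            raise ValueError("expected exactly one timestep")
        return image_cthw[:, 0]

    chw = single_ts_to_chw_tensor(torch.zeros(3, 1, 64, 64))
    print(chw.shape)  # torch.Size([3, 64, 64])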
rslearn/models/simple_time_series.py CHANGED
@@ -3,8 +3,9 @@
  from typing import Any
  
  import torch
+ from einops import rearrange
  
- from rslearn.train.model_context import ModelContext
+ from rslearn.train.model_context import ModelContext, RasterImage
  
  from .component import FeatureExtractor, FeatureMaps
  
@@ -163,23 +164,44 @@ class SimpleTimeSeries(FeatureExtractor):
  
  def _get_batched_images(
  self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
- ) -> torch.Tensor:
+ ) -> list[RasterImage]:
  """Collect and reshape images across input dicts.
  
  The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
  the forward pass of a per-image (unitemporal) model.
  """
  images = torch.stack(
- [input_dict[image_key] for input_dict in input_dicts], dim=0
+ [input_dict[image_key].image for input_dict in input_dicts], dim=0
+ ) # B, C, T, H, W
+ timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
+ # if image channels is not equal to the actual number of channels, then
+ # then every N images should be batched together. For example, if the
+ # number of input channels c == 2, and image_channels == 4, then we
+ # want to pass 2 timesteps to the model.
+ # TODO is probably to make this behaviour clearer but lets leave it like
+ # this for now to not break things.
+ num_timesteps = images.shape[1] // image_channels
+ batched_timesteps = images.shape[2] // num_timesteps
+ images = rearrange(
+ images,
+ "b c (b_t k_t) h w -> (b b_t) c k_t h w",
+ b_t=batched_timesteps,
+ k_t=num_timesteps,
  )
- n_batch = images.shape[0]
- n_images = images.shape[1] // image_channels
- n_height = images.shape[2]
- n_width = images.shape[3]
- batched_images = images.reshape(
- n_batch * n_images, image_channels, n_height, n_width
- )
- return batched_images
+ if timestamps[0] is None:
+ new_timestamps = [None] * images.shape[0]
+ else:
+ # we also need to split the timestamps
+ new_timestamps = []
+ for t in timestamps:
+ for i in range(batched_timesteps):
+ new_timestamps.append(
+ t[i * num_timesteps : (i + 1) * num_timesteps]
+ )
+ return [
+ RasterImage(image=image, timestamps=timestamps)
+ for image, timestamps in zip(images, new_timestamps)
+ ] # C, T, H, W
  
  def forward(
  self,
@@ -208,8 +230,8 @@
  
  if batched_inputs is None:
  batched_inputs = [{} for _ in batched_images]
- n_images = batched_images.shape[0] // n_batch
- elif n_images != batched_images.shape[0] // n_batch:
+ n_images = len(batched_images) // n_batch
+ elif n_images != len(batched_images) // n_batch:
  raise ValueError(
  "expected all modalities to have the same number of timesteps"
  )
@@ -223,13 +245,18 @@
  context.inputs, self.image_key, self.image_channels
  )
  batched_inputs = [{self.image_key: image} for image in batched_images]
- n_images = batched_images.shape[0] // n_batch
+ n_images = len(batched_images) // n_batch
  
  assert n_images is not None
-
  # Now we can apply the underlying FeatureExtractor.
  # Its output must be a FeatureMaps.
- encoder_output = self.encoder(batched_inputs)
+ assert batched_inputs is not None
+ encoder_output = self.encoder(
+ ModelContext(
+ inputs=batched_inputs,
+ metadatas=context.metadatas,
+ )
+ )
  if not isinstance(encoder_output, FeatureMaps):
  raise ValueError(
  "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
@@ -244,7 +271,6 @@
  )
  for feat_map in encoder_output.feature_maps
  ]
-
  # Groups defaults to flattening all the feature maps.
  groups = self.groups
  if not groups:
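A small shape check of the rearrange pattern used in _get_batched_images above (illustrative values; in the actual code b_t and k_t are derived from image_channels): consecutive timesteps are grouped and the group axis is folded into the batch, so the wrapped per-image model sees shorter time series.

    import torch
    from einops import rearrange

    x = torch.randn(1, 3, 6, 32, 32)  # B, C, T, H, W with 6 timesteps
    # Split the 6 timesteps into 3 consecutive groups of 2 and fold the group
    # axis into the batch dimension.
    y = rearrange(x, "b c (b_t k_t) h w -> (b b_t) c k_t h w", b_t=3, k_t=2)
    print(y.shape)  # torch.Size([3, 3, 2, 32, 32])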
rslearn/models/ssl4eo_s12.py CHANGED
@@ -13,7 +13,7 @@ class Ssl4eoS12(FeatureExtractor):
  
  def __init__(
  self,
- backbone_ckpt_path: str,
+ backbone_ckpt_path: str | None,
  arch: str = "resnet50",
  output_layers: list[int] = [0, 1, 2, 3],
  ) -> None:
@@ -39,19 +39,22 @@
  else:
  raise ValueError(f"unknown SSL4EO-S12 architecture {arch}")
  
- state_dict = torch.load(backbone_ckpt_path, weights_only=True)
- state_dict = state_dict["teacher"]
- prefix = "module.backbone."
- state_dict = {
- k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
- }
- missing_keys, unexpected_keys = self.model.load_state_dict(
- state_dict, strict=False
- )
- if missing_keys or unexpected_keys:
- print(
- f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+ if backbone_ckpt_path is not None:
+ state_dict = torch.load(backbone_ckpt_path, weights_only=True)
+ state_dict = state_dict["teacher"]
+ prefix = "module.backbone."
+ state_dict = {
+ k[len(prefix) :]: v
+ for k, v in state_dict.items()
+ if k.startswith(prefix)
+ }
+ missing_keys, unexpected_keys = self.model.load_state_dict(
+ state_dict, strict=False
  )
+ if missing_keys or unexpected_keys:
+ print(
+ f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+ )
  
  def get_backbone_channels(self) -> list[tuple[int, int]]:
  """Returns the output channels of this model when used as a backbone.
@@ -91,7 +94,9 @@
  Returns:
  feature maps computed by the pre-trained model.
  """
- x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+ x = torch.stack(
+ [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+ )
  x = self.model.conv1(x)
  x = self.model.bn1(x)
  x = self.model.relu(x)
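Since backbone_ckpt_path is now optional, the backbone can be instantiated without loading the pre-trained "teacher" weights, e.g. for quick architecture or shape tests (illustrative usage; the import path is inferred from the file layout above):

    from rslearn.models.ssl4eo_s12 import Ssl4eoS12

    # Passing None skips torch.load and the state-dict prefix stripping
    # entirely, leaving the randomly initialized backbone in place.
    model = Ssl4eoS12(backbone_ckpt_path=None, arch="resnet50")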
rslearn/models/swin.py CHANGED
@@ -151,7 +151,9 @@ class Swin(FeatureExtractor):
  a FeatureVector if the configured output_layers is None, or a FeatureMaps
  otherwise containing one feature map per configured output layer.
  """
- images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+ images = torch.stack(
+ [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+ )
  
  if self.output_layers:
  layer_features = []
rslearn/models/terramind.py CHANGED
@@ -143,7 +143,8 @@ class Terramind(FeatureExtractor):
  if modality not in context.inputs[0]:
  continue
  cur = torch.stack(
- [inp[modality] for inp in context.inputs], dim=0
+ [inp[modality].single_ts_to_chw_tensor() for inp in context.inputs],
+ dim=0,
  ) # (B, C, H, W)
  if self.do_resizing and (
  cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
@@ -219,7 +220,7 @@ class TerramindNormalize(Transform):
  Returns:
  The normalized image.
  """
- images = image.float() # (C, H, W)
+ images = image.float() # (C, 1, H, W)
  if images.shape[0] % len(means) != 0:
  raise ValueError(
  f"the number of image channels {images.shape[0]} is not multiple of expected number of bands {len(means)}"
@@ -247,8 +248,8 @@
  band_info = PRETRAINED_BANDS[modality]
  means = [band_info[band][0] for band in band_info]
  stds = [band_info[band][1] for band in band_info]
- input_dict[modality] = self.apply_image(
- input_dict[modality],
+ input_dict[modality].image = self.apply_image(
+ input_dict[modality].image,
  means,
  stds,
  )
rslearn/train/all_patches_dataset.py CHANGED
@@ -9,8 +9,8 @@ import shapely
  import torch
  
  from rslearn.dataset import Window
- from rslearn.train.dataset import ModelDataset
- from rslearn.train.model_context import SampleMetadata
+ from rslearn.train.dataset import DataInput, ModelDataset
+ from rslearn.train.model_context import RasterImage, SampleMetadata
  from rslearn.utils.geometry import PixelBounds, STGeometry
  
  
@@ -34,22 +34,28 @@ def get_window_patch_options(
  bottommost patches may extend beyond the provided bounds.
  """
  # We stride the patches by patch_size - overlap_size until the last patch.
+ # We handle the first patch with a special case to ensure it is always used.
  # We handle the last patch with a special case to ensure it does not exceed the
  # window bounds. Instead, it may overlap the previous patch.
- cols = list(
+ cols = [bounds[0]] + list(
  range(
- bounds[0],
+ bounds[0] + patch_size[0],
  bounds[2] - patch_size[0],
  patch_size[0] - overlap_size[0],
  )
- ) + [bounds[2] - patch_size[0]]
- rows = list(
+ )
+ rows = [bounds[1]] + list(
  range(
- bounds[1],
+ bounds[1] + patch_size[1],
  bounds[3] - patch_size[1],
  patch_size[1] - overlap_size[1],
  )
- ) + [bounds[3] - patch_size[1]]
+ )
+ # Add last patches only if the input is larger than one patch.
+ if bounds[2] - patch_size[0] > bounds[0]:
+ cols.append(bounds[2] - patch_size[0])
+ if bounds[3] - patch_size[1] > bounds[1]:
+ rows.append(bounds[3] - patch_size[1])
  
  patch_bounds: list[PixelBounds] = []
  for col in cols:
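To see what the revised column/row computation in get_window_patch_options does, here is a toy example with made-up bounds that mirrors the added lines: the first patch is always anchored at the window origin, the last patch is pulled back so it ends at the window edge, and a window smaller than one patch yields a single patch instead of a negative start.

    # Window bounds (x0, y0, x1, y1), 64x64 patches, no overlap.
    bounds, patch, overlap = (0, 0, 100, 100), (64, 64), (0, 0)
    cols = [bounds[0]] + list(
        range(bounds[0] + patch[0], bounds[2] - patch[0], patch[0] - overlap[0])
    )
    if bounds[2] - patch[0] > bounds[0]:
        cols.append(bounds[2] - patch[0])
    print(cols)  # [0, 36]: the last patch ends exactly at x1=100

    # A 50-pixel-wide window (smaller than one patch) yields a single patch at 0.
    bounds = (0, 0, 50, 50)
    cols = [bounds[0]] + list(
        range(bounds[0] + patch[0], bounds[2] - patch[0], patch[0] - overlap[0])
    )
    if bounds[2] - patch[0] > bounds[0]:
        cols.append(bounds[2] - patch[0])
    print(cols)  # [0]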
@@ -62,13 +68,17 @@ def pad_slice_protect(
  raw_inputs: dict[str, Any],
  passthrough_inputs: dict[str, Any],
  patch_size: tuple[int, int],
+ inputs: dict[str, DataInput],
  ) -> tuple[dict[str, Any], dict[str, Any]]:
  """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
  
+ The padding is scaled based on each input's resolution_factor.
+
  Args:
  raw_inputs: the raw inputs to pad.
  passthrough_inputs: the passthrough inputs to pad.
- patch_size: the size of the patches to extract.
+ patch_size: the size of the patches to extract (at window resolution).
+ inputs: the DataInput definitions, used to get resolution_factor per input.
  
  Returns:
  a tuple of (raw_inputs, passthrough_inputs).
@@ -77,12 +87,42 @@
  for input_name, value in list(d.items()):
  if not isinstance(value, torch.Tensor):
  continue
+ # Get resolution scale for this input
+ rf = inputs[input_name].resolution_factor
+ scale = rf.numerator / rf.denominator
+ # Scale the padding amount
+ scaled_pad_x = int(patch_size[0] * scale)
+ scaled_pad_y = int(patch_size[1] * scale)
  d[input_name] = torch.nn.functional.pad(
- value, pad=(0, patch_size[0], 0, patch_size[1])
+ value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
  )
  return raw_inputs, passthrough_inputs
  
  
+ def crop_tensor_or_rasterimage(
+ x: torch.Tensor | RasterImage, start: tuple[int, int], end: tuple[int, int]
+ ) -> torch.Tensor | RasterImage:
+ """Crop a tensor or a RasterImage."""
+ if isinstance(x, torch.Tensor):
+ # Crop the CHW tensor with scaled coordinates.
+ return x[
+ :,
+ start[1] : end[1],
+ start[0] : end[0],
+ ].clone()
+ else:
+ # Crop the CTHW tensor with scaled coordinates.
+ return RasterImage(
+ x.image[
+ :,
+ :,
+ start[1] : end[1],
+ start[0] : end[0],
+ ].clone(),
+ x.timestamps,
+ )
+
+
  class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
  """This wraps a ModelDataset to iterate over all patches in that dataset.
  
@@ -123,6 +163,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
  self.rank = rank
  self.world_size = world_size
  self.windows = self.dataset.get_dataset_examples()
+ self.inputs = dataset.inputs
  
  def set_name(self, name: str) -> None:
  """Sets dataset name.
@@ -235,8 +276,10 @@
  
  # For simplicity, pad tensors by patch size to ensure that any patch bounds
  # extending outside the window bounds will not have issues when we slice
- # the tensors later.
- pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+ # the tensors later. Padding is scaled per-input based on resolution_factor.
+ pad_slice_protect(
+ raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+ )
  
  # Now iterate over the patches and extract/yield the crops.
  # Note that, in case user is leveraging RslearnWriter, it is important that
@@ -258,16 +301,26 @@
  )
  
  # Define a helper function to handle each input dict.
+ # Crop coordinates are scaled based on each input's resolution_factor.
  def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
  cropped = {}
  for input_name, value in d.items():
- if isinstance(value, torch.Tensor):
- # Crop the CHW tensor.
- cropped[input_name] = value[
- :,
- start_offset[1] : end_offset[1],
- start_offset[0] : end_offset[0],
- ].clone()
+ if isinstance(value, torch.Tensor | RasterImage):
+ # Get resolution scale for this input
+ rf = self.inputs[input_name].resolution_factor
+ scale = rf.numerator / rf.denominator
+ # Scale the crop coordinates
+ scaled_start = (
+ int(start_offset[0] * scale),
+ int(start_offset[1] * scale),
+ )
+ scaled_end = (
+ int(end_offset[0] * scale),
+ int(end_offset[1] * scale),
+ )
+ cropped[input_name] = crop_tensor_or_rasterimage(
+ value, scaled_start, scaled_end
+ )
  elif isinstance(value, list):
  cropped[input_name] = [
  feat
@@ -348,6 +401,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
  round(self.patch_size[1] * overlap_ratio),
  )
  self.windows = self.dataset.get_dataset_examples()
+ self.inputs = dataset.inputs
  self.window_cache: dict[
  int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
  ] = {}
@@ -378,27 +432,41 @@
  return self.window_cache[index]
  
  raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
- pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+ pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
  
  self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
  return self.window_cache[index]
  
- @staticmethod
  def _crop_input_dict(
+ self,
  d: dict[str, Any],
  start_offset: tuple[int, int],
  end_offset: tuple[int, int],
  cur_geom: STGeometry,
  ) -> dict[str, Any]:
- """Crop a dictionary of inputs to the given bounds."""
+ """Crop a dictionary of inputs to the given bounds.
+
+ Crop coordinates are scaled based on each input's resolution_factor.
+ """
  cropped = {}
  for input_name, value in d.items():
- if isinstance(value, torch.Tensor):
- cropped[input_name] = value[
- :,
- start_offset[1] : end_offset[1],
- start_offset[0] : end_offset[0],
- ].clone()
+ if isinstance(value, torch.Tensor | RasterImage):
+ # Get resolution scale for this input
+ rf = self.inputs[input_name].resolution_factor
+ scale = rf.numerator / rf.denominator
+ # Scale the crop coordinates
+ scaled_start = (
+ int(start_offset[0] * scale),
+ int(start_offset[1] * scale),
+ )
+ scaled_end = (
+ int(end_offset[0] * scale),
+ int(end_offset[1] * scale),
+ )
+ cropped[input_name] = crop_tensor_or_rasterimage(
+ value, scaled_start, scaled_end
+ )
+
  elif isinstance(value, list):
  cropped[input_name] = [
  feat for feat in value if cur_geom.intersects(feat.geometry)
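Both dataset classes now scale patch padding and crop offsets by each input's resolution_factor before slicing. A minimal sketch of that arithmetic, assuming resolution_factor behaves like a Fraction (it exposes numerator and denominator in the code above):

    from fractions import Fraction

    # Hypothetical half-resolution input: window-resolution crop offsets are
    # halved before being applied to the input's own pixel grid.
    rf = Fraction(1, 2)
    scale = rf.numerator / rf.denominator
    start_offset, end_offset = (128, 0), (256, 128)  # at window resolution
    scaled_start = (int(start_offset[0] * scale), int(start_offset[1] * scale))
    scaled_end = (int(end_offset[0] * scale), int(end_offset[1] * scale))
    print(scaled_start, scaled_end)  # (64, 0) (128, 64)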