rslearn 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/local_files.py +2 -2
  3. rslearn/data_sources/utils.py +204 -64
  4. rslearn/dataset/materialize.py +5 -1
  5. rslearn/models/clay/clay.py +3 -3
  6. rslearn/models/detr/detr.py +4 -1
  7. rslearn/models/dinov3.py +0 -1
  8. rslearn/models/olmoearth_pretrain/model.py +3 -1
  9. rslearn/models/pooling_decoder.py +1 -1
  10. rslearn/models/prithvi.py +0 -1
  11. rslearn/models/simple_time_series.py +97 -35
  12. rslearn/train/data_module.py +5 -0
  13. rslearn/train/dataset.py +151 -55
  14. rslearn/train/dataset_index.py +156 -0
  15. rslearn/train/model_context.py +16 -0
  16. rslearn/train/tasks/per_pixel_regression.py +13 -13
  17. rslearn/train/tasks/segmentation.py +26 -13
  18. rslearn/train/transforms/concatenate.py +17 -27
  19. rslearn/train/transforms/crop.py +8 -19
  20. rslearn/train/transforms/flip.py +4 -10
  21. rslearn/train/transforms/mask.py +9 -15
  22. rslearn/train/transforms/normalize.py +31 -82
  23. rslearn/train/transforms/pad.py +7 -13
  24. rslearn/train/transforms/resize.py +5 -22
  25. rslearn/train/transforms/select_bands.py +16 -36
  26. rslearn/train/transforms/sentinel1.py +4 -16
  27. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/METADATA +1 -1
  28. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/RECORD +33 -32
  29. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/WHEEL +0 -0
  30. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
  31. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
  32. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
  33. {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
@@ -66,20 +66,18 @@ class PerPixelRegressionTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-        assert raw_inputs["targets"].image.shape[0] == 1
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
+        labels = raw_inputs["targets"].get_hw_tensor().float() * self.scale_factor
 
         if self.nodata_value is not None:
-            valid = (
-                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
-            ).float()
+            valid = (raw_inputs["targets"].get_hw_tensor() != self.nodata_value).float()
         else:
            valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so values and valid can be
+        # used in image transforms.
         return {}, {
-            "values": labels,
-            "valid": valid,
+            "values": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(
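The task now funnels target rasters through `RasterImage.get_hw_tensor()` instead of hand-indexing `image[0, 0, :, :]`, and wraps its outputs back into CTHW rasters. A minimal sketch of the pattern, assuming `get_hw_tensor()` asserts a single band and timestep and returns the HW slice (the real class lives in rslearn/train/model_context.py and may differ in detail):

```python
import torch


class RasterImage:
    """Sketch of rslearn's CTHW raster wrapper (channels, time, height, width)."""

    def __init__(self, image: torch.Tensor, timestamps=None):
        assert image.dim() == 4, "expected a CTHW tensor"
        self.image = image
        self.timestamps = timestamps

    def get_hw_tensor(self) -> torch.Tensor:
        # Centralizes the C == 1, T == 1 asserts that each task used to repeat inline.
        assert self.image.shape[0] == 1 and self.image.shape[1] == 1
        return self.image[0, 0, :, :]


labels = torch.zeros(64, 64)
# Targets are wrapped back into CTHW (C=1, T=1) so image transforms apply to them too.
target = RasterImage(labels[None, None, :, :], timestamps=None)
assert target.get_hw_tensor().shape == (64, 64)
```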
@@ -121,7 +119,7 @@ class PerPixelRegressionTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_values = target_dict["classes"].cpu().numpy()
+        gt_values = target_dict["values"].get_hw_tensor().cpu().numpy()
         pred_values = output.cpu().numpy()[0, :, :]
         gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
         pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
@@ -210,8 +208,10 @@ class PerPixelRegressionHead(Predictor):
 
         losses = {}
         if targets:
-            labels = torch.stack([target["values"] for target in targets])
-            mask = torch.stack([target["valid"] for target in targets])
+            labels = torch.stack(
+                [target["values"].get_hw_tensor() for target in targets]
+            )
+            mask = torch.stack([target["valid"].get_hw_tensor() for target in targets])
 
             if self.loss_mode == "mse":
                 scores = torch.square(outputs - labels)
@@ -262,14 +262,14 @@ class PerPixelRegressionMetricWrapper(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["values"] for target in targets])
+        labels = torch.stack([target["values"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         if len(preds.shape) == 4:
             assert preds.shape[1] == 1
             preds = preds[:, 0, :, :]
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds[mask]
         labels = labels[mask]
         if len(preds) == 0:
@@ -128,9 +128,7 @@ class SegmentationTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-        assert raw_inputs["targets"].image.shape[0] == 1
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].long()
+        labels = raw_inputs["targets"].get_hw_tensor().long()
 
         if self.class_id_mapping is not None:
             new_labels = labels.clone()
@@ -146,9 +144,11 @@ class SegmentationTask(BasicTask):
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so classes and valid can be
+        # used in image transforms.
         return {}, {
-            "classes": labels,
-            "valid": valid,
+            "classes": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(
@@ -206,7 +206,7 @@ class SegmentationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_classes = target_dict["classes"].cpu().numpy()
+        gt_classes = target_dict["classes"].get_hw_tensor().cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
         pred_vis = np.zeros(
@@ -291,12 +291,19 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
-    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+    def __init__(
+        self,
+        weights: list[float] | None = None,
+        dice_loss: bool = False,
+        temperature: float = 1.0,
+    ):
         """Initialize a new SegmentationHead.
 
         Args:
             weights: weights for cross entropy loss (Tensor of size C)
             dice_loss: whether to add dice loss to cross entropy
+            temperature: temperature scaling for softmax, does not affect the loss,
+                only the predictor outputs
         """
         super().__init__()
         if weights is not None:
@@ -304,6 +311,7 @@ class SegmentationHead(Predictor):
         else:
             self.weights = None
         self.dice_loss = dice_loss
+        self.temperature = temperature
 
     def forward(
         self,
@@ -332,12 +340,16 @@ class SegmentationHead(Predictor):
         )
 
         logits = intermediates.feature_maps[0]
-        outputs = torch.nn.functional.softmax(logits, dim=1)
+        outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)
 
         losses = {}
         if targets:
-            labels = torch.stack([target["classes"] for target in targets], dim=0)
-            mask = torch.stack([target["valid"] for target in targets], dim=0)
+            labels = torch.stack(
+                [target["classes"].get_hw_tensor() for target in targets], dim=0
+            )
+            mask = torch.stack(
+                [target["valid"].get_hw_tensor() for target in targets], dim=0
+            )
             per_pixel_loss = torch.nn.functional.cross_entropy(
                 logits, labels, weight=self.weights, reduction="none"
             )
@@ -350,7 +362,8 @@ class SegmentationHead(Predictor):
             # the summed mask loss be zero.
             losses["cls"] = torch.sum(per_pixel_loss * mask)
             if self.dice_loss:
-                dice_loss = DiceLoss()(outputs, labels, mask)
+                softmax_woT = torch.nn.functional.softmax(logits, dim=1)
+                dice_loss = DiceLoss()(softmax_woT, labels, mask)
                 losses["dice"] = dice_loss
 
         return ModelOutput(
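The new `temperature` argument only rescales the logits feeding the predictor's softmax; cross-entropy still sees raw logits and dice now gets an explicitly un-tempered softmax, so training losses are unchanged. A standalone check of the behavior (plain PyTorch, not rslearn code):

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])

for temperature in (0.5, 1.0, 2.0):
    probs = torch.nn.functional.softmax(logits / temperature, dim=1)
    # Lower temperature sharpens the distribution, higher flattens it,
    # but the argmax (the predicted class) is identical in every case.
    print(temperature, probs, probs.argmax(dim=1).item())
```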
@@ -401,12 +414,12 @@ class SegmentationMetric(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["classes"] for target in targets])
+        labels = torch.stack([target["classes"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         # Prediction is changed from BCHW to BHWC so we can select the valid BHW mask.
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds.permute(0, 2, 3, 1)[mask]
         labels = labels[mask]
         if len(preds) == 0:
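Both metric wrappers keep the same valid-pixel trick: a boolean BHW mask flattens predictions and labels to vectors of valid pixels, with a BCHW-to-BHWC permute so per-pixel class scores stay together. A standalone illustration of the indexing:

```python
import torch

preds = torch.rand(2, 5, 16, 16)           # BCHW class scores
labels = torch.randint(0, 5, (2, 16, 16))  # BHW labels
valid = torch.rand(2, 16, 16) > 0.3        # BHW validity mask

# Permuting to BHWC lets the BHW mask select whole per-pixel score vectors.
flat_preds = preds.permute(0, 2, 3, 1)[valid]  # (N_valid, 5)
flat_labels = labels[valid]                    # (N_valid,)
assert flat_preds.shape[0] == flat_labels.shape[0] == int(valid.sum())
```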
@@ -54,36 +54,26 @@ class Concatenate(Transform):
             target_dict: the target
 
         Returns:
-            concatenated (input_dicts, target_dicts) tuple. If one of the
-            specified inputs is a RasterImage, a RasterImage will be returned.
-            Otherwise it will be a torch.Tensor.
+            (input_dicts, target_dicts) where the entry corresponding to
+            output_selector contains the concatenated RasterImage.
         """
-        images = []
-        return_raster_image: bool = False
+        tensors: list[torch.Tensor] = []
         timestamps: list[tuple[datetime, datetime]] | None = None
+
         for selector, wanted_bands in self.selections.items():
             image = read_selector(input_dict, target_dict, selector)
-            if isinstance(image, torch.Tensor):
-                if wanted_bands:
-                    image = image[wanted_bands, :, :]
-                images.append(image)
-            elif isinstance(image, RasterImage):
-                return_raster_image = True
-                if wanted_bands:
-                    images.append(image.image[wanted_bands, :, :])
-                else:
-                    images.append(image.image)
-                if timestamps is None:
-                    if image.timestamps is not None:
-                        # assume all concatenated modalities have the same
-                        # number of timestamps
-                        timestamps = image.timestamps
-        if return_raster_image:
-            result = RasterImage(
-                torch.concatenate(images, dim=self.concatenate_dim),
-                timestamps=timestamps,
-            )
-        else:
-            result = torch.concatenate(images, dim=self.concatenate_dim)
+            if wanted_bands:
+                tensors.append(image.image[wanted_bands, :, :])
+            else:
+                tensors.append(image.image)
+            if timestamps is None and image.timestamps is not None:
+                # assume all concatenated modalities have the same
+                # number of timestamps
+                timestamps = image.timestamps
+
+        result = RasterImage(
+            torch.concatenate(tensors, dim=self.concatenate_dim),
+            timestamps=timestamps,
+        )
         write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
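With plain-tensor inputs gone, Concatenate can assume every selected entry is a RasterImage and always return one. The core operation is channel-wise concatenation of CTHW tensors; a rough standalone sketch (shapes are illustrative):

```python
import torch

# Two modalities over the same 3 timesteps: a 4-band and a 2-band CTHW tensor.
s2 = torch.rand(4, 3, 32, 32)
s1 = torch.rand(2, 3, 32, 32)

# With concatenate_dim=0 the bands stack along C while T, H, W are unchanged;
# timestamps are taken from the first modality that carries them.
merged = torch.concatenate([s2, s1], dim=0)
assert merged.shape == (6, 3, 32, 32)
```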
@@ -71,9 +71,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, Any]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -84,22 +82,13 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        if isinstance(image, RasterImage):
-            image.image = torchvision.transforms.functional.crop(
-                image.image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
-        else:
-            image = torchvision.transforms.functional.crop(
-                image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
+        image.image = torchvision.transforms.functional.crop(
+            image.image,
+            top=remove_from_top,
+            left=remove_from_left,
+            height=crop_size,
+            width=crop_size,
+        )
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
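Dropping the tensor branch works because torchvision's functional ops only touch the trailing H and W dimensions, so one call crops every band and timestep of a CTHW tensor. Quick standalone check:

```python
import torch
import torchvision

image = torch.rand(4, 2, 64, 64)  # CTHW
cropped = torchvision.transforms.functional.crop(
    image, top=8, left=8, height=32, width=32
)
assert cropped.shape == (4, 2, 32, 32)
```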
@@ -57,16 +57,10 @@ class Flip(Transform):
             image: the image to transform.
             state: the sampled state.
         """
-        if isinstance(image, RasterImage):
-            if state["horizontal"]:
-                image.image = torch.flip(image.image, dims=[-1])
-            if state["vertical"]:
-                image.image = torch.flip(image.image, dims=[-2])
-        elif isinstance(image, torch.Tensor):
-            if state["horizontal"]:
-                image = torch.flip(image, dims=[-1])
-            if state["vertical"]:
-                image = torch.flip(image, dims=[-2])
+        if state["horizontal"]:
+            image.image = torch.flip(image.image, dims=[-1])
+        if state["vertical"]:
+            image.image = torch.flip(image.image, dims=[-2])
         return image
 
     def apply_boxes(
@@ -1,7 +1,5 @@
 """Mask transform."""
 
-import torch
-
 from rslearn.train.model_context import RasterImage
 from rslearn.train.transforms.transform import Transform, read_selector
 
@@ -32,9 +30,7 @@ class Mask(Transform):
         self.mask_selector = mask_selector
         self.mask_value = mask_value
 
-    def apply_image(
-        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_image(self, image: RasterImage, mask: RasterImage) -> RasterImage:
         """Apply the mask on the image.
 
         Args:
@@ -44,21 +40,19 @@ class Mask(Transform):
         Returns:
             masked image
         """
-        # Tile the mask to have same number of bands as the image.
-        if isinstance(mask, RasterImage):
-            mask = mask.image
+        # Extract the mask tensor (CTHW format)
+        mask_tensor = mask.image
 
-        if image.shape[0] != mask.shape[0]:
-            if mask.shape[0] != 1:
+        # Tile the mask to have same number of bands (C dimension) as the image.
+        if image.shape[0] != mask_tensor.shape[0]:
+            if mask_tensor.shape[0] != 1:
                 raise ValueError(
                     "expected mask to either have same bands as image, or one band"
                 )
-            mask = mask.repeat(image.shape[0], 1, 1)
+            # Repeat along C dimension, keep T, H, W the same
+            mask_tensor = mask_tensor.repeat(image.shape[0], 1, 1, 1)
 
-        if isinstance(image, torch.Tensor):
-            image[mask == 0] = self.mask_value
-        else:
-            image.image[mask == 0] = self.mask_value
+        image.image[mask_tensor == 0] = self.mask_value
         return image
 
     def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
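The one-band mask is tiled with `repeat(C, 1, 1, 1)` so it covers every band of a CTHW image while T, H, and W stay fixed. A standalone illustration (shapes and the zero mask value are illustrative):

```python
import torch

image = torch.rand(4, 2, 8, 8)            # CTHW: 4 bands, 2 timesteps
mask = torch.randint(0, 2, (1, 2, 8, 8))  # one-band mask with the same T/H/W

mask_tensor = mask.repeat(image.shape[0], 1, 1, 1)  # -> (4, 2, 8, 8)
image[mask_tensor == 0] = 0.0             # e.g. mask_value = 0.0
assert (image[:, mask[0] == 0] == 0).all()
```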
@@ -1,5 +1,6 @@
 """Normalization transforms."""
 
+import warnings
 from typing import Any
 
 import torch
@@ -35,14 +36,17 @@ class Normalize(Transform):
             bands: optionally restrict the normalization to these band indices. If set,
                 mean and std must either be one value, or have length equal to the
                 number of band indices passed here.
-            num_bands: the number of bands per image, to distinguish different images
-                in a time series. If set, then the bands list is repeated for each
-                image, e.g. if bands=[2] then we apply normalization on images[2],
-                images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
-                is not set, then we apply the mean and std on each image in the time
-                series.
+            num_bands: deprecated, no longer used. Will be removed after 2026-04-01.
         """
         super().__init__()
+
+        if num_bands is not None:
+            warnings.warn(
+                "num_bands is deprecated and no longer used. "
+                "It will be removed after 2026-04-01.",
+                FutureWarning,
+            )
+
         self.mean = torch.tensor(mean)
         self.std = torch.tensor(std)
 
@@ -55,92 +59,37 @@ class Normalize(Transform):
 
         self.selectors = selectors
         self.bands = torch.tensor(bands) if bands is not None else None
-        self.num_bands = num_bands
 
-    def apply_image(
-        self, image: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_image(self, image: RasterImage) -> RasterImage:
         """Normalize the specified image.
 
         Args:
             image: the image to transform.
         """
-
-        def _repeat_mean_and_std(
-            image_channels: int, num_bands: int | None, is_raster_image: bool
-        ) -> tuple[torch.Tensor, torch.Tensor]:
-            """Get mean and std tensor that are suitable for applying on the image."""
-            # We only need to repeat the tensor if both of these are true:
-            # - The mean/std are not just one scalar.
-            # - self.num_bands is set, otherwise we treat the input as a single image.
-            if len(self.mean.shape) == 0:
-                return self.mean, self.std
-            if num_bands is None:
-                return self.mean, self.std
-            num_images = image_channels // num_bands
-            if is_raster_image:
-                # add an extra T dimension, CTHW
-                return self.mean.repeat(num_images)[
-                    :, None, None, None
-                ], self.std.repeat(num_images)[:, None, None, None]
-            else:
-                # add an extra T dimension, CTHW
-                return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
-                    num_images
-                )[:, None, None]
+        # Get mean/std with singleton dims for broadcasting over CTHW.
+        if len(self.mean.shape) == 0:
+            # Scalar - broadcasts naturally.
+            mean, std = self.mean, self.std
+        else:
+            # Vector of length C - add singleton dims for T, H, W.
+            mean = self.mean[:, None, None, None]
+            std = self.std[:, None, None, None]
 
         if self.bands is not None:
-            # User has provided band indices to normalize.
-            # If num_bands is set, then we repeat these for each image in the input
-            # image time series.
-            band_indices = self.bands
-            if self.num_bands:
-                num_images = image.shape[0] // self.num_bands
-                band_indices = torch.cat(
-                    [
-                        band_indices + image_idx * self.num_bands
-                        for image_idx in range(num_images)
-                    ],
-                    dim=0,
+            # Normalize only specific band indices.
+            image.image[self.bands] = (image.image[self.bands] - mean) / std
+            if self.valid_min is not None:
+                image.image[self.bands] = torch.clamp(
+                    image.image[self.bands],
+                    min=self.valid_min,
+                    max=self.valid_max,
                 )
-
-            # We use len(self.bands) here because that is how many bands per timestep
-            # we are actually processing with the mean/std.
-            mean, std = _repeat_mean_and_std(
-                image_channels=len(band_indices),
-                num_bands=len(self.bands),
-                is_raster_image=isinstance(image, RasterImage),
-            )
-            if isinstance(image, torch.Tensor):
-                image[band_indices] = (image[band_indices] - mean) / std
-                if self.valid_min is not None:
-                    image[band_indices] = torch.clamp(
-                        image[band_indices], min=self.valid_min, max=self.valid_max
-                    )
-            else:
-                image.image[band_indices] = (image.image[band_indices] - mean) / std
-                if self.valid_min is not None:
-                    image.image[band_indices] = torch.clamp(
-                        image.image[band_indices],
-                        min=self.valid_min,
-                        max=self.valid_max,
-                    )
         else:
-            mean, std = _repeat_mean_and_std(
-                image_channels=image.shape[0],
-                num_bands=self.num_bands,
-                is_raster_image=isinstance(image, RasterImage),
-            )
-            if isinstance(image, torch.Tensor):
-                image = (image - mean) / std
-                if self.valid_min is not None:
-                    image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
-            else:
-                image.image = (image.image - mean) / std
-                if self.valid_min is not None:
-                    image.image = torch.clamp(
-                        image.image, min=self.valid_min, max=self.valid_max
-                    )
+            image.image = (image.image - mean) / std
+            if self.valid_min is not None:
+                image.image = torch.clamp(
+                    image.image, min=self.valid_min, max=self.valid_max
+                )
         return image
 
     def forward(
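Without `num_bands`, per-band statistics are just length-C vectors reshaped to (C, 1, 1, 1) so they broadcast across every timestep and pixel of the CTHW tensor. A standalone sketch of that broadcasting step:

```python
import torch

image = torch.rand(3, 5, 16, 16)  # CTHW: 3 bands, 5 timesteps
mean = torch.tensor([0.2, 0.3, 0.4])[:, None, None, None]  # (3, 1, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5])[:, None, None, None]

# Each band is normalized with its own statistics across all timesteps.
normalized = (image - mean) / std
assert normalized.shape == image.shape
```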
@@ -50,9 +50,7 @@ class Pad(Transform):
         """
         return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
 
-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, bool]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -105,16 +103,12 @@ class Pad(Transform):
         horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
         vertical_pad = (vertical_half, vertical_extra - vertical_half)
 
-        if isinstance(image, RasterImage):
-            image.image = apply_padding(
-                image.image, True, horizontal_pad[0], horizontal_pad[1]
-            )
-            image.image = apply_padding(
-                image.image, False, vertical_pad[0], vertical_pad[1]
-            )
-        else:
-            image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
-            image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
+        image.image = apply_padding(
+            image.image, True, horizontal_pad[0], horizontal_pad[1]
+        )
+        image.image = apply_padding(
+            image.image, False, vertical_pad[0], vertical_pad[1]
+        )
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
@@ -2,7 +2,6 @@
 
 from typing import Any
 
-import torch
 import torchvision
 from torchvision.transforms import InterpolationMode
 
@@ -40,32 +39,16 @@ class Resize(Transform):
         self.selectors = selectors
         self.interpolation = INTERPOLATION_MODES[interpolation]
 
-    def apply_resize(
-        self, image: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_resize(self, image: RasterImage) -> RasterImage:
         """Apply resizing on the specified image.
 
-        If the image is 2D, it is unsqueezed to 3D and then squeezed
-        back after resizing.
-
         Args:
             image: the image to transform.
         """
-        if isinstance(image, torch.Tensor):
-            if image.dim() == 2:
-                image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
-                result = torchvision.transforms.functional.resize(
-                    image, self.target_size, self.interpolation
-                )
-                return result.squeeze(0)  # (1, H, W) -> (H, W)
-            return torchvision.transforms.functional.resize(
-                image, self.target_size, self.interpolation
-            )
-        else:
-            image.image = torchvision.transforms.functional.resize(
-                image.image, self.target_size, self.interpolation
-            )
-            return image
+        image.image = torchvision.transforms.functional.resize(
+            image.image, self.target_size, self.interpolation
+        )
+        return image
 
     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
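The 2D special case disappears for the same reason the other transforms simplified: torchvision's resize accepts tensors shaped (..., H, W) and only resizes the last two dimensions, so CTHW input needs no reshaping. Quick check:

```python
import torch
import torchvision
from torchvision.transforms import InterpolationMode

image = torch.rand(4, 3, 32, 32)  # CTHW
resized = torchvision.transforms.functional.resize(
    image, [64, 64], InterpolationMode.BILINEAR
)
assert resized.shape == (4, 3, 64, 64)
```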
@@ -1,9 +1,8 @@
 """The SelectBands transform."""
 
+import warnings
 from typing import Any
 
-from rslearn.train.model_context import RasterImage
-
 from .transform import Transform, read_selector, write_selector
 
 
@@ -17,60 +16,41 @@ class SelectBands(Transform):
         output_selector: str = "image",
         num_bands_per_timestep: int | None = None,
     ):
-        """Initialize a new Concatenate.
+        """Initialize a new SelectBands.
 
         Args:
-            band_indices: the bands to select.
+            band_indices: the bands to select from the channel dimension.
             input_selector: the selector to read the input image.
             output_selector: the output selector under which to save the output image.
-            num_bands_per_timestep: the number of bands per image, to distinguish
-                between stacked images in an image time series. If set, then the
-                band_indices are selected for each image in the time series.
+            num_bands_per_timestep: deprecated, no longer used. Will be removed after
+                2026-04-01.
         """
         super().__init__()
+
+        if num_bands_per_timestep is not None:
+            warnings.warn(
+                "num_bands_per_timestep is deprecated and no longer used. "
+                "It will be removed after 2026-04-01.",
+                FutureWarning,
+            )
+
         self.input_selector = input_selector
         self.output_selector = output_selector
         self.band_indices = band_indices
-        self.num_bands_per_timestep = num_bands_per_timestep
 
     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
     ) -> tuple[dict[str, Any], dict[str, Any]]:
-        """Apply concatenation over the inputs and targets.
+        """Apply band selection over the inputs and targets.
 
         Args:
             input_dict: the input
             target_dict: the target
 
         Returns:
-            normalized (input_dicts, target_dicts) tuple
+            (input_dicts, target_dicts) tuple with selected bands
         """
         image = read_selector(input_dict, target_dict, self.input_selector)
-        num_bands_per_timestep = (
-            self.num_bands_per_timestep
-            if self.num_bands_per_timestep is not None
-            else image.shape[0]
-        )
-        if isinstance(image, RasterImage):
-            assert num_bands_per_timestep == image.shape[0], (
-                "Expect a seperate dimension for timesteps in RasterImages."
-            )
-
-        if image.shape[0] % num_bands_per_timestep != 0:
-            raise ValueError(
-                f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
-            )
-
-        # Copy the band indices for each timestep in the input.
-        wanted_bands: list[int] = []
-        for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
-            wanted_bands.extend(
-                [(start_channel_idx + band_idx) for band_idx in self.band_indices]
-            )
-
-        if isinstance(image, RasterImage):
-            image.image = image.image[wanted_bands]
-        else:
-            image = image[wanted_bands]
+        image.image = image.image[self.band_indices]
         write_selector(input_dict, target_dict, self.output_selector, image)
         return input_dict, target_dict
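Band selection likewise reduces to a single fancy-index on the channel dimension; replicating indices per timestep is obsolete now that timesteps occupy their own axis. Standalone sketch:

```python
import torch

image = torch.rand(12, 4, 32, 32)  # CTHW: 12 bands, 4 timesteps
band_indices = [3, 2, 1]           # e.g. pick an RGB subset

selected = image[band_indices]     # indexes C only; T, H, W untouched
assert selected.shape == (3, 4, 32, 32)
```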