rslearn 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/local_files.py +2 -2
  3. rslearn/data_sources/utils.py +204 -64
  4. rslearn/dataset/materialize.py +5 -1
  5. rslearn/models/clay/clay.py +3 -3
  6. rslearn/models/detr/detr.py +4 -1
  7. rslearn/models/dinov3.py +0 -1
  8. rslearn/models/olmoearth_pretrain/model.py +3 -1
  9. rslearn/models/pooling_decoder.py +1 -1
  10. rslearn/models/prithvi.py +0 -1
  11. rslearn/models/simple_time_series.py +97 -35
  12. rslearn/train/data_module.py +5 -0
  13. rslearn/train/dataset.py +186 -49
  14. rslearn/train/dataset_index.py +156 -0
  15. rslearn/train/model_context.py +16 -0
  16. rslearn/train/tasks/detection.py +1 -18
  17. rslearn/train/tasks/per_pixel_regression.py +13 -13
  18. rslearn/train/tasks/segmentation.py +27 -32
  19. rslearn/train/transforms/concatenate.py +17 -27
  20. rslearn/train/transforms/crop.py +8 -19
  21. rslearn/train/transforms/flip.py +4 -10
  22. rslearn/train/transforms/mask.py +9 -15
  23. rslearn/train/transforms/normalize.py +31 -82
  24. rslearn/train/transforms/pad.py +7 -13
  25. rslearn/train/transforms/resize.py +5 -22
  26. rslearn/train/transforms/select_bands.py +16 -36
  27. rslearn/train/transforms/sentinel1.py +4 -16
  28. rslearn/utils/colors.py +20 -0
  29. rslearn/vis/__init__.py +1 -0
  30. rslearn/vis/normalization.py +127 -0
  31. rslearn/vis/render_raster_label.py +96 -0
  32. rslearn/vis/render_sensor_image.py +27 -0
  33. rslearn/vis/render_vector_label.py +439 -0
  34. rslearn/vis/utils.py +99 -0
  35. rslearn/vis/vis_server.py +574 -0
  36. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/METADATA +14 -1
  37. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/RECORD +42 -33
  38. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/WHEEL +1 -1
  39. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
  40. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
  41. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
  42. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,156 @@
1
+ """Dataset index for caching window lists to speed up ModelDataset initialization."""
2
+
3
+ import hashlib
4
+ import json
5
+ from datetime import datetime
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from upath import UPath
9
+
10
+ from rslearn.dataset.window import Window
11
+ from rslearn.log_utils import get_logger
12
+ from rslearn.utils.fsspec import open_atomic
13
+
14
+ if TYPE_CHECKING:
15
+ from rslearn.dataset.storage.storage import WindowStorage
16
+
17
+ logger = get_logger(__name__)
18
+
19
+ # Increment this when the index format changes to force rebuild
20
+ INDEX_VERSION = 1
21
+
22
+ # Directory name for storing index files
23
+ INDEX_DIR_NAME = ".rslearn_dataset_index"
24
+
25
+
26
class DatasetIndex:
    """Manages indexed window lists for faster ModelDataset initialization.

    Note: The index does NOT automatically detect when windows are added or removed
    from the dataset. Use refresh=True after modifying dataset windows.
    """

    def __init__(
        self,
        storage: "WindowStorage",
        dataset_path: UPath,
        groups: list[str] | None,
        names: list[str] | None,
        tags: dict[str, Any] | None,
        num_samples: int | None,
        skip_targets: bool,
        inputs: dict[str, Any],
    ) -> None:
        """Initialize DatasetIndex with specific configuration.

        Args:
            storage: WindowStorage for deserializing windows.
            dataset_path: Path to the dataset directory.
            groups: list of window groups to include.
            names: list of window names to include.
            tags: tags to filter windows by.
            num_samples: limit on number of samples.
            skip_targets: whether targets are skipped.
            inputs: dict mapping input names to DataInput objects.
        """
        self.storage = storage
        self.dataset_path = dataset_path
        self.index_dir = dataset_path / INDEX_DIR_NAME

        # The index key digests every option that affects the resulting window
        # list, so each distinct configuration gets its own index file.
        inputs_data = {
            name: {
                "layers": inp.layers,
                "required": inp.required,
                "load_all_layers": inp.load_all_layers,
                "is_target": inp.is_target,
            }
            for name, inp in inputs.items()
        }
        key_data = {
            "groups": groups,
            "names": names,
            "tags": tags,
            "num_samples": num_samples,
            "skip_targets": skip_targets,
            "inputs": inputs_data,
        }
        # sort_keys=True makes the serialized form (and hence the hash)
        # deterministic across runs.
        self.index_key = hashlib.sha256(
            json.dumps(key_data, sort_keys=True).encode()
        ).hexdigest()

    def _get_config_hash(self) -> str:
        """Get hash of config.json for quick validation.

        Returns:
            A 16-character hex string hash of the config, or empty string if no config.
        """
        config_path = self.dataset_path / "config.json"
        # EAFP: open directly instead of exists()+open(), so a concurrent
        # removal between the check and the read cannot raise unexpectedly.
        try:
            with config_path.open() as f:
                return hashlib.sha256(f.read().encode()).hexdigest()[:16]
        except FileNotFoundError:
            return ""

    def load_windows(self, refresh: bool = False) -> list[Window] | None:
        """Load indexed window list if valid, else return None.

        Args:
            refresh: If True, ignore existing index and return None.

        Returns:
            List of Window objects if index is valid, None otherwise.
        """
        if refresh:
            logger.info("refresh=True, rebuilding index")
            return None

        index_file = self.index_dir / f"{self.index_key}.json"
        if not index_file.exists():
            logger.info(f"No index found at {index_file}, will build")
            return None

        try:
            with index_file.open() as f:
                index_data = json.load(f)
        except (OSError, json.JSONDecodeError):
            # Treat an unreadable or truncated index as missing.
            logger.warning(f"Corrupted index file at {index_file}, will rebuild")
            return None

        # Check index version; a bump of INDEX_VERSION invalidates all
        # previously written index files.
        if index_data.get("version") != INDEX_VERSION:
            logger.info(
                f"Index version mismatch (got {index_data.get('version')}, "
                f"expected {INDEX_VERSION}), will rebuild"
            )
            return None

        # Quick validation: the dataset config must not have changed since the
        # index was written.
        if index_data.get("config_hash") != self._get_config_hash():
            logger.info("Config hash mismatch, index invalidated")
            return None

        # Deserialize windows
        return [Window.from_metadata(self.storage, w) for w in index_data["windows"]]

    def save_windows(self, windows: list[Window]) -> None:
        """Save processed windows to index with atomic write.

        Args:
            windows: List of Window objects to index.
        """
        self.index_dir.mkdir(parents=True, exist_ok=True)
        index_file = self.index_dir / f"{self.index_key}.json"

        # Serialize windows
        serialized_windows = [w.get_metadata() for w in windows]

        index_data = {
            "version": INDEX_VERSION,
            "config_hash": self._get_config_hash(),
            "created_at": datetime.now().isoformat(),
            "num_windows": len(windows),
            "windows": serialized_windows,
        }
        # Atomic write prevents readers from observing a partially written index.
        with open_atomic(index_file, "w") as f:
            json.dump(index_data, f)
        logger.info(f"Saved {len(windows)} windows to index at {index_file}")
@@ -43,6 +43,22 @@ class RasterImage:
43
43
  raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
44
44
  return self.image[:, 0]
45
45
 
46
+ def get_hw_tensor(self) -> torch.Tensor:
47
+ """Get a 2D HW tensor from a single-channel, single-timestep RasterImage.
48
+
49
+ This function checks that C=1 and T=1, then returns the HW tensor.
50
+ Useful for per-pixel labels like segmentation masks.
51
+ """
52
+ if self.image.shape[0] != 1:
53
+ raise ValueError(
54
+ f"Expected single channel (C=1), got {self.image.shape[0]}"
55
+ )
56
+ if self.image.shape[1] != 1:
57
+ raise ValueError(
58
+ f"Expected single timestep (T=1), got {self.image.shape[1]}"
59
+ )
60
+ return self.image[0, 0]
61
+
46
62
 
47
63
  @dataclass
48
64
  class SampleMetadata:
@@ -14,27 +14,10 @@ from torchmetrics import Metric, MetricCollection
14
14
 
15
15
  from rslearn.train.model_context import RasterImage, SampleMetadata
16
16
  from rslearn.utils import Feature, STGeometry
17
+ from rslearn.utils.colors import DEFAULT_COLORS
17
18
 
18
19
  from .task import BasicTask
19
20
 
20
- DEFAULT_COLORS = [
21
- (255, 0, 0),
22
- (0, 255, 0),
23
- (0, 0, 255),
24
- (255, 255, 0),
25
- (0, 255, 255),
26
- (255, 0, 255),
27
- (0, 128, 0),
28
- (255, 160, 122),
29
- (139, 69, 19),
30
- (128, 128, 128),
31
- (255, 255, 255),
32
- (143, 188, 143),
33
- (95, 158, 160),
34
- (255, 200, 0),
35
- (128, 0, 0),
36
- ]
37
-
38
21
 
39
22
  class DetectionTask(BasicTask):
40
23
  """A point or bounding box detection task."""
@@ -66,20 +66,18 @@ class PerPixelRegressionTask(BasicTask):
66
66
  return {}, {}
67
67
 
68
68
  assert isinstance(raw_inputs["targets"], RasterImage)
69
- assert raw_inputs["targets"].image.shape[0] == 1
70
- assert raw_inputs["targets"].image.shape[1] == 1
71
- labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
69
+ labels = raw_inputs["targets"].get_hw_tensor().float() * self.scale_factor
72
70
 
73
71
  if self.nodata_value is not None:
74
- valid = (
75
- raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
76
- ).float()
72
+ valid = (raw_inputs["targets"].get_hw_tensor() != self.nodata_value).float()
77
73
  else:
78
74
  valid = torch.ones(labels.shape, dtype=torch.float32)
79
75
 
76
+ # Wrap in RasterImage with CTHW format (C=1, T=1) so values and valid can be
77
+ # used in image transforms.
80
78
  return {}, {
81
- "values": labels,
82
- "valid": valid,
79
+ "values": RasterImage(labels[None, None, :, :], timestamps=None),
80
+ "valid": RasterImage(valid[None, None, :, :], timestamps=None),
83
81
  }
84
82
 
85
83
  def process_output(
@@ -121,7 +119,7 @@ class PerPixelRegressionTask(BasicTask):
121
119
  image = super().visualize(input_dict, target_dict, output)["image"]
122
120
  if target_dict is None:
123
121
  raise ValueError("target_dict is required for visualization")
124
- gt_values = target_dict["classes"].cpu().numpy()
122
+ gt_values = target_dict["values"].get_hw_tensor().cpu().numpy()
125
123
  pred_values = output.cpu().numpy()[0, :, :]
126
124
  gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
127
125
  pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
@@ -210,8 +208,10 @@ class PerPixelRegressionHead(Predictor):
210
208
 
211
209
  losses = {}
212
210
  if targets:
213
- labels = torch.stack([target["values"] for target in targets])
214
- mask = torch.stack([target["valid"] for target in targets])
211
+ labels = torch.stack(
212
+ [target["values"].get_hw_tensor() for target in targets]
213
+ )
214
+ mask = torch.stack([target["valid"].get_hw_tensor() for target in targets])
215
215
 
216
216
  if self.loss_mode == "mse":
217
217
  scores = torch.square(outputs - labels)
@@ -262,14 +262,14 @@ class PerPixelRegressionMetricWrapper(Metric):
262
262
  """
263
263
  if not isinstance(preds, torch.Tensor):
264
264
  preds = torch.stack(preds)
265
- labels = torch.stack([target["values"] for target in targets])
265
+ labels = torch.stack([target["values"].get_hw_tensor() for target in targets])
266
266
 
267
267
  # Sub-select the valid labels.
268
268
  # We flatten the prediction and label images at valid pixels.
269
269
  if len(preds.shape) == 4:
270
270
  assert preds.shape[1] == 1
271
271
  preds = preds[:, 0, :, :]
272
- mask = torch.stack([target["valid"] > 0 for target in targets])
272
+ mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
273
273
  preds = preds[mask]
274
274
  labels = labels[mask]
275
275
  if len(preds) == 0:
@@ -17,28 +17,10 @@ from rslearn.train.model_context import (
17
17
  SampleMetadata,
18
18
  )
19
19
  from rslearn.utils import Feature
20
+ from rslearn.utils.colors import DEFAULT_COLORS
20
21
 
21
22
  from .task import BasicTask
22
23
 
23
- # TODO: This is duplicated code fix it
24
- DEFAULT_COLORS = [
25
- (255, 0, 0),
26
- (0, 255, 0),
27
- (0, 0, 255),
28
- (255, 255, 0),
29
- (0, 255, 255),
30
- (255, 0, 255),
31
- (0, 128, 0),
32
- (255, 160, 122),
33
- (139, 69, 19),
34
- (128, 128, 128),
35
- (255, 255, 255),
36
- (143, 188, 143),
37
- (95, 158, 160),
38
- (255, 200, 0),
39
- (128, 0, 0),
40
- ]
41
-
42
24
 
43
25
  class SegmentationTask(BasicTask):
44
26
  """A segmentation (per-pixel classification) task."""
@@ -146,9 +128,7 @@ class SegmentationTask(BasicTask):
146
128
  return {}, {}
147
129
 
148
130
  assert isinstance(raw_inputs["targets"], RasterImage)
149
- assert raw_inputs["targets"].image.shape[0] == 1
150
- assert raw_inputs["targets"].image.shape[1] == 1
151
- labels = raw_inputs["targets"].image[0, 0, :, :].long()
131
+ labels = raw_inputs["targets"].get_hw_tensor().long()
152
132
 
153
133
  if self.class_id_mapping is not None:
154
134
  new_labels = labels.clone()
@@ -164,9 +144,11 @@ class SegmentationTask(BasicTask):
164
144
  else:
165
145
  valid = torch.ones(labels.shape, dtype=torch.float32)
166
146
 
147
+ # Wrap in RasterImage with CTHW format (C=1, T=1) so classes and valid can be
148
+ # used in image transforms.
167
149
  return {}, {
168
- "classes": labels,
169
- "valid": valid,
150
+ "classes": RasterImage(labels[None, None, :, :], timestamps=None),
151
+ "valid": RasterImage(valid[None, None, :, :], timestamps=None),
170
152
  }
171
153
 
172
154
  def process_output(
@@ -224,7 +206,7 @@ class SegmentationTask(BasicTask):
224
206
  image = super().visualize(input_dict, target_dict, output)["image"]
225
207
  if target_dict is None:
226
208
  raise ValueError("target_dict is required for visualization")
227
- gt_classes = target_dict["classes"].cpu().numpy()
209
+ gt_classes = target_dict["classes"].get_hw_tensor().cpu().numpy()
228
210
  pred_classes = output.cpu().numpy().argmax(axis=0)
229
211
  gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
230
212
  pred_vis = np.zeros(
@@ -309,12 +291,19 @@ class SegmentationTask(BasicTask):
309
291
  class SegmentationHead(Predictor):
310
292
  """Head for segmentation task."""
311
293
 
312
- def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
294
+ def __init__(
295
+ self,
296
+ weights: list[float] | None = None,
297
+ dice_loss: bool = False,
298
+ temperature: float = 1.0,
299
+ ):
313
300
  """Initialize a new SegmentationTask.
314
301
 
315
302
  Args:
316
303
  weights: weights for cross entropy loss (Tensor of size C)
317
304
  dice_loss: whether to add dice loss to cross entropy
305
+ temperature: temperature scaling for softmax, does not affect the loss,
306
+ only the predictor outputs
318
307
  """
319
308
  super().__init__()
320
309
  if weights is not None:
@@ -322,6 +311,7 @@ class SegmentationHead(Predictor):
322
311
  else:
323
312
  self.weights = None
324
313
  self.dice_loss = dice_loss
314
+ self.temperature = temperature
325
315
 
326
316
  def forward(
327
317
  self,
@@ -350,12 +340,16 @@ class SegmentationHead(Predictor):
350
340
  )
351
341
 
352
342
  logits = intermediates.feature_maps[0]
353
- outputs = torch.nn.functional.softmax(logits, dim=1)
343
+ outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)
354
344
 
355
345
  losses = {}
356
346
  if targets:
357
- labels = torch.stack([target["classes"] for target in targets], dim=0)
358
- mask = torch.stack([target["valid"] for target in targets], dim=0)
347
+ labels = torch.stack(
348
+ [target["classes"].get_hw_tensor() for target in targets], dim=0
349
+ )
350
+ mask = torch.stack(
351
+ [target["valid"].get_hw_tensor() for target in targets], dim=0
352
+ )
359
353
  per_pixel_loss = torch.nn.functional.cross_entropy(
360
354
  logits, labels, weight=self.weights, reduction="none"
361
355
  )
@@ -368,7 +362,8 @@ class SegmentationHead(Predictor):
368
362
  # the summed mask loss be zero.
369
363
  losses["cls"] = torch.sum(per_pixel_loss * mask)
370
364
  if self.dice_loss:
371
- dice_loss = DiceLoss()(outputs, labels, mask)
365
+ softmax_woT = torch.nn.functional.softmax(logits, dim=1)
366
+ dice_loss = DiceLoss()(softmax_woT, labels, mask)
372
367
  losses["dice"] = dice_loss
373
368
 
374
369
  return ModelOutput(
@@ -419,12 +414,12 @@ class SegmentationMetric(Metric):
419
414
  """
420
415
  if not isinstance(preds, torch.Tensor):
421
416
  preds = torch.stack(preds)
422
- labels = torch.stack([target["classes"] for target in targets])
417
+ labels = torch.stack([target["classes"].get_hw_tensor() for target in targets])
423
418
 
424
419
  # Sub-select the valid labels.
425
420
  # We flatten the prediction and label images at valid pixels.
426
421
  # Prediction is changed from BCHW to BHWC so we can select the valid BHW mask.
427
- mask = torch.stack([target["valid"] > 0 for target in targets])
422
+ mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
428
423
  preds = preds.permute(0, 2, 3, 1)[mask]
429
424
  labels = labels[mask]
430
425
  if len(preds) == 0:
@@ -54,36 +54,26 @@ class Concatenate(Transform):
54
54
  target_dict: the target
55
55
 
56
56
  Returns:
57
- concatenated (input_dicts, target_dicts) tuple. If one of the
58
- specified inputs is a RasterImage, a RasterImage will be returned.
59
- Otherwise it will be a torch.Tensor.
57
+ (input_dicts, target_dicts) where the entry corresponding to
58
+ output_selector contains the concatenated RasterImage.
60
59
  """
61
- images = []
62
- return_raster_image: bool = False
60
+ tensors: list[torch.Tensor] = []
63
61
  timestamps: list[tuple[datetime, datetime]] | None = None
62
+
64
63
  for selector, wanted_bands in self.selections.items():
65
64
  image = read_selector(input_dict, target_dict, selector)
66
- if isinstance(image, torch.Tensor):
67
- if wanted_bands:
68
- image = image[wanted_bands, :, :]
69
- images.append(image)
70
- elif isinstance(image, RasterImage):
71
- return_raster_image = True
72
- if wanted_bands:
73
- images.append(image.image[wanted_bands, :, :])
74
- else:
75
- images.append(image.image)
76
- if timestamps is None:
77
- if image.timestamps is not None:
78
- # assume all concatenated modalities have the same
79
- # number of timestamps
80
- timestamps = image.timestamps
81
- if return_raster_image:
82
- result = RasterImage(
83
- torch.concatenate(images, dim=self.concatenate_dim),
84
- timestamps=timestamps,
85
- )
86
- else:
87
- result = torch.concatenate(images, dim=self.concatenate_dim)
65
+ if wanted_bands:
66
+ tensors.append(image.image[wanted_bands, :, :])
67
+ else:
68
+ tensors.append(image.image)
69
+ if timestamps is None and image.timestamps is not None:
70
+ # assume all concatenated modalities have the same
71
+ # number of timestamps
72
+ timestamps = image.timestamps
73
+
74
+ result = RasterImage(
75
+ torch.concatenate(tensors, dim=self.concatenate_dim),
76
+ timestamps=timestamps,
77
+ )
88
78
  write_selector(input_dict, target_dict, self.output_selector, result)
89
79
  return input_dict, target_dict
@@ -71,9 +71,7 @@ class Crop(Transform):
71
71
  "remove_from_top": remove_from_top,
72
72
  }
73
73
 
74
- def apply_image(
75
- self, image: RasterImage | torch.Tensor, state: dict[str, Any]
76
- ) -> RasterImage | torch.Tensor:
74
+ def apply_image(self, image: RasterImage, state: dict[str, Any]) -> RasterImage:
77
75
  """Apply the sampled state on the specified image.
78
76
 
79
77
  Args:
@@ -84,22 +82,13 @@ class Crop(Transform):
84
82
  crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
85
83
  remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
86
84
  remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
87
- if isinstance(image, RasterImage):
88
- image.image = torchvision.transforms.functional.crop(
89
- image.image,
90
- top=remove_from_top,
91
- left=remove_from_left,
92
- height=crop_size,
93
- width=crop_size,
94
- )
95
- else:
96
- image = torchvision.transforms.functional.crop(
97
- image,
98
- top=remove_from_top,
99
- left=remove_from_left,
100
- height=crop_size,
101
- width=crop_size,
102
- )
85
+ image.image = torchvision.transforms.functional.crop(
86
+ image.image,
87
+ top=remove_from_top,
88
+ left=remove_from_left,
89
+ height=crop_size,
90
+ width=crop_size,
91
+ )
103
92
  return image
104
93
 
105
94
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
@@ -57,16 +57,10 @@ class Flip(Transform):
57
57
  image: the image to transform.
58
58
  state: the sampled state.
59
59
  """
60
- if isinstance(image, RasterImage):
61
- if state["horizontal"]:
62
- image.image = torch.flip(image.image, dims=[-1])
63
- if state["vertical"]:
64
- image.image = torch.flip(image.image, dims=[-2])
65
- elif isinstance(image, torch.Tensor):
66
- if state["horizontal"]:
67
- image = torch.flip(image, dims=[-1])
68
- if state["vertical"]:
69
- image = torch.flip(image, dims=[-2])
60
+ if state["horizontal"]:
61
+ image.image = torch.flip(image.image, dims=[-1])
62
+ if state["vertical"]:
63
+ image.image = torch.flip(image.image, dims=[-2])
70
64
  return image
71
65
 
72
66
  def apply_boxes(
@@ -1,7 +1,5 @@
1
1
  """Mask transform."""
2
2
 
3
- import torch
4
-
5
3
  from rslearn.train.model_context import RasterImage
6
4
  from rslearn.train.transforms.transform import Transform, read_selector
7
5
 
@@ -32,9 +30,7 @@ class Mask(Transform):
32
30
  self.mask_selector = mask_selector
33
31
  self.mask_value = mask_value
34
32
 
35
- def apply_image(
36
- self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
37
- ) -> torch.Tensor | RasterImage:
33
+ def apply_image(self, image: RasterImage, mask: RasterImage) -> RasterImage:
38
34
  """Apply the mask on the image.
39
35
 
40
36
  Args:
@@ -44,21 +40,19 @@ class Mask(Transform):
44
40
  Returns:
45
41
  masked image
46
42
  """
47
- # Tile the mask to have same number of bands as the image.
48
- if isinstance(mask, RasterImage):
49
- mask = mask.image
43
+ # Extract the mask tensor (CTHW format)
44
+ mask_tensor = mask.image
50
45
 
51
- if image.shape[0] != mask.shape[0]:
52
- if mask.shape[0] != 1:
46
+ # Tile the mask to have same number of bands (C dimension) as the image.
47
+ if image.shape[0] != mask_tensor.shape[0]:
48
+ if mask_tensor.shape[0] != 1:
53
49
  raise ValueError(
54
50
  "expected mask to either have same bands as image, or one band"
55
51
  )
56
- mask = mask.repeat(image.shape[0], 1, 1)
52
+ # Repeat along C dimension, keep T, H, W the same
53
+ mask_tensor = mask_tensor.repeat(image.shape[0], 1, 1, 1)
57
54
 
58
- if isinstance(image, torch.Tensor):
59
- image[mask == 0] = self.mask_value
60
- else:
61
- image.image[mask == 0] = self.mask_value
55
+ image.image[mask_tensor == 0] = self.mask_value
62
56
  return image
63
57
 
64
58
  def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]: