rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0

rslearn/train/lightning_module.py
@@ -365,7 +365,7 @@ class RslearnLightningModule(L.LightningModule):
         for image_suffix, image in images.items():
             out_fname = os.path.join(
                 self.visualize_dir,
-                f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
+                f"{metadata.window_name}_{metadata.crop_bounds[0]}_{metadata.crop_bounds[1]}_{image_suffix}.png",
             )
             Image.fromarray(image).save(out_fname)
 

rslearn/train/model_context.py
@@ -43,6 +43,22 @@ class RasterImage:
             raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
         return self.image[:, 0]
 
+    def get_hw_tensor(self) -> torch.Tensor:
+        """Get a 2D HW tensor from a single-channel, single-timestep RasterImage.
+
+        This function checks that C=1 and T=1, then returns the HW tensor.
+        Useful for per-pixel labels like segmentation masks.
+        """
+        if self.image.shape[0] != 1:
+            raise ValueError(
+                f"Expected single channel (C=1), got {self.image.shape[0]}"
+            )
+        if self.image.shape[1] != 1:
+            raise ValueError(
+                f"Expected single timestep (T=1), got {self.image.shape[1]}"
+            )
+        return self.image[0, 0]
+
 
 @dataclass
 class SampleMetadata:
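
The new helper collapses RasterImage's CTHW layout to a plain HW tensor when both the channel and time axes are singletons. A minimal usage sketch (the module path is inferred from the file list; the constructor call appears elsewhere in this diff):

    import torch
    from rslearn.train.model_context import RasterImage  # path assumed

    # A single-band segmentation mask in CTHW layout (C=1, T=1, H=64, W=64).
    mask = RasterImage(torch.zeros((1, 1, 64, 64), dtype=torch.long), timestamps=None)
    labels = mask.get_hw_tensor()  # 2D tensor of shape (64, 64)
    # Any C > 1 or T > 1 raises ValueError rather than silently squeezing.
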
@@ -51,9 +67,9 @@ class SampleMetadata:
     window_group: str
     window_name: str
     window_bounds: PixelBounds
-    patch_bounds: PixelBounds
-    patch_idx: int
-    num_patches_in_window: int
+    crop_bounds: PixelBounds
+    crop_idx: int
+    num_crops_in_window: int
     time_range: tuple[datetime, datetime] | None
     projection: Projection
 

rslearn/train/prediction_writer.py
@@ -1,6 +1,7 @@
 """rslearn PredictionWriter implementation."""
 
 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from pathlib import Path
@@ -39,20 +40,20 @@ logger = get_logger(__name__)
 
 
 @dataclass
-class PendingPatchOutput:
-    """A patch output that hasn't been merged yet."""
+class PendingCropOutput:
+    """A crop output that hasn't been merged yet."""
 
     bounds: PixelBounds
     output: Any
 
 
-class PatchPredictionMerger:
-    """Base class for merging predictions from multiple patches."""
+class CropPredictionMerger:
+    """Base class for merging predictions from multiple crops."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> Any:
         """Merge the outputs.
@@ -68,39 +69,60 @@ class PatchPredictionMerger:
         raise NotImplementedError
 
 
-class VectorMerger(PatchPredictionMerger):
+class VectorMerger(CropPredictionMerger):
     """Merger for vector data that simply concatenates the features."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> list[Feature]:
         """Concatenate the vector features."""
         return [feat for output in outputs for feat in output.output]
 
 
-class RasterMerger(PatchPredictionMerger):
+class RasterMerger(CropPredictionMerger):
     """Merger for raster data that copies the rasters to the output."""
 
-    def __init__(self, padding: int | None = None, downsample_factor: int = 1):
+    def __init__(
+        self,
+        overlap_pixels: int | None = None,
+        downsample_factor: int = 1,
+        # Deprecated parameter (for backwards compatibility)
+        padding: int | None = None,
+    ):
         """Create a new RasterMerger.
 
         Args:
-            padding: the padding around the individual patch outputs to remove. This is
-                typically used when leveraging overlapping patches. Portions of outputs
-                at the border of the window will still be retained.
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference. Half of this overlap is removed from each
+                crop during merging (except at window boundaries where the full crop
+                is retained).
             downsample_factor: the factor by which the rasters output by the task are
                 lower in resolution relative to the window resolution.
+            padding: deprecated, use overlap_pixels instead. The old padding value
+                equals overlap_pixels // 2.
         """
-        self.padding = padding
+        # Handle deprecated padding parameter
+        if padding is not None:
+            warnings.warn(
+                "padding is deprecated, use overlap_pixels instead. "
+                "Note: overlap_pixels = padding * 2",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both padding and overlap_pixels")
+            overlap_pixels = padding * 2
+
+        self.overlap_pixels = overlap_pixels
         self.downsample_factor = downsample_factor
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> npt.NDArray:
         """Merge the raster outputs."""
@@ -114,6 +136,12 @@ class RasterMerger(PatchPredictionMerger):
             dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
         )
 
+        # Compute how many pixels to trim from each side.
+        # We remove half of the overlap from each side (not at window boundaries).
+        trim_pixels = (
+            self.overlap_pixels // 2 if self.overlap_pixels is not None else None
+        )
+
         # Ensure the outputs are sorted by height then width.
         # This way when we merge we can be sure that outputs that are lower or further
         # to the right will overwrite earlier outputs.
@@ -123,18 +151,18 @@
         for output in sorted_outputs:
             # So now we just need to compute the src_offset to copy.
             # If the output is not on the left or top boundary, then we should apply
-            # the padding (if set).
+            # the trim (if set).
             src = output.output
             src_offset = (
                 output.bounds[0] // self.downsample_factor,
                 output.bounds[1] // self.downsample_factor,
             )
-            if self.padding is not None and output.bounds[0] != window.bounds[0]:
-                src = src[:, :, self.padding :]
-                src_offset = (src_offset[0] + self.padding, src_offset[1])
-            if self.padding is not None and output.bounds[1] != window.bounds[1]:
-                src = src[:, self.padding :, :]
-                src_offset = (src_offset[0], src_offset[1] + self.padding)
+            if trim_pixels is not None and output.bounds[0] != window.bounds[0]:
+                src = src[:, :, trim_pixels:]
+                src_offset = (src_offset[0] + trim_pixels, src_offset[1])
+            if trim_pixels is not None and output.bounds[1] != window.bounds[1]:
+                src = src[:, trim_pixels:, :]
+                src_offset = (src_offset[0], src_offset[1] + trim_pixels)
 
             copy_spatial_array(
                 src=src,
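
A worked example of the trimming arithmetic, under assumed numbers:

    # Assume 256x256 crops with overlap_pixels=64 (stride 192) in a window whose
    # bounds start at x=0, so trim_pixels = 64 // 2 = 32.
    #
    # Crop at x=0:   on the window boundary, kept in full -> covers [0, 256).
    # Crop at x=192: interior, 32 px trimmed on the left  -> covers [224, 448).
    #
    # The later crop overwrites [224, 256), so each seam falls 32 px inside both
    # crops, away from the least reliable border predictions.
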
@@ -162,7 +190,7 @@ class RslearnWriter(BasePredictionWriter):
         output_layer: str,
         path_options: dict[str, Any] | None = None,
         selector: list[str] | None = None,
-        merger: PatchPredictionMerger | None = None,
+        merger: CropPredictionMerger | None = None,
         output_path: str | Path | None = None,
         layer_config: LayerConfig | None = None,
         storage_config: StorageConfig | None = None,
@@ -175,7 +203,7 @@
             path_options: additional options for path to pass to fsspec
             selector: keys to access the desired output in the output dict if needed.
                 e.g ["key1", "key2"] gets output["key1"]["key2"]
-            merger: merger to use to merge outputs from overlapped patches.
+            merger: merger to use to merge outputs from overlapped crops.
             output_path: optional custom path for writing predictions. If provided,
                 predictions will be written to this path instead of deriving from dataset path.
             layer_config: optional layer configuration. If provided, this config will be
@@ -217,9 +245,9 @@
             self.merger = VectorMerger()
 
         # Map from window name to pending data to write.
-        # This is used when windows are split up into patches, so the data from all the
-        # patches of each window need to be reconstituted.
-        self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
+        # This is used when windows are split up into crops, so the data from all the
+        # crops of each window need to be reconstituted.
+        self.pending_outputs: dict[str, list[PendingCropOutput]] = {}
 
     def _get_layer_config_and_dataset_storage(
         self,
@@ -327,7 +355,7 @@
                 will be processed by the task to obtain a vector (list[Feature]) or
                 raster (npt.NDArray) output.
             metadatas: corresponding list of metadatas from the batch describing the
-                patches that were processed.
+                crops that were processed.
         """
         # Process the predictions into outputs that can be written.
         outputs: list = [
@@ -349,17 +377,17 @@
             )
             self.process_output(
                 window,
-                metadata.patch_idx,
-                metadata.num_patches_in_window,
-                metadata.patch_bounds,
+                metadata.crop_idx,
+                metadata.num_crops_in_window,
+                metadata.crop_bounds,
                 output,
             )
 
     def process_output(
         self,
         window: Window,
-        patch_idx: int,
-        num_patches: int,
+        crop_idx: int,
+        num_crops: int,
         cur_bounds: PixelBounds,
         output: npt.NDArray | list[Feature],
     ) -> None:
@@ -367,28 +395,28 @@
 
         Args:
             window: the window that the output pertains to.
-            patch_idx: the index of this patch for the window.
-            num_patches: the total number of patches to be processed for the window.
-            cur_bounds: the bounds of the current patch.
+            crop_idx: the index of this crop for the window.
+            num_crops: the total number of crops to be processed for the window.
+            cur_bounds: the bounds of the current crop.
             output: the output data.
         """
-        # Incorporate the output into our list of pending patch outputs.
+        # Incorporate the output into our list of pending crop outputs.
         if window.name not in self.pending_outputs:
             self.pending_outputs[window.name] = []
-        self.pending_outputs[window.name].append(PendingPatchOutput(cur_bounds, output))
+        self.pending_outputs[window.name].append(PendingCropOutput(cur_bounds, output))
         logger.debug(
-            f"Stored PendingPatchOutput for patch #{patch_idx}/{num_patches} at window {window.name}"
+            f"Stored PendingCropOutput for crop #{crop_idx}/{num_crops} at window {window.name}"
         )
 
-        if patch_idx < num_patches - 1:
+        if crop_idx < num_crops - 1:
             return
 
-        # This is the last patch so it's time to write it.
+        # This is the last crop so it's time to write it.
         # First get the pending output and clear it.
         pending_output = self.pending_outputs[window.name]
         del self.pending_outputs[window.name]
 
-        # Merge outputs from overlapped patches if merger is set.
+        # Merge outputs from overlapped crops if merger is set.
         logger.debug(f"Merging and writing for window {window.name}")
         merged_output = self.merger.merge(window, pending_output, self.layer_config)

rslearn/train/tasks/classification.py
@@ -201,7 +201,7 @@ class ClassificationTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {

rslearn/train/tasks/detection.py
@@ -128,7 +128,7 @@ class DetectionTask(BasicTask):
         if not load_targets:
             return {}, {}
 
-        bounds = metadata.patch_bounds
+        bounds = metadata.crop_bounds
 
         boxes = []
         class_labels = []
@@ -244,10 +244,10 @@
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata.patch_bounds[0] + float(box[0]),
-                metadata.patch_bounds[1] + float(box[1]),
-                metadata.patch_bounds[0] + float(box[2]),
-                metadata.patch_bounds[1] + float(box[3]),
+                metadata.crop_bounds[0] + float(box[0]),
+                metadata.crop_bounds[1] + float(box[1]),
+                metadata.crop_bounds[0] + float(box[2]),
+                metadata.crop_bounds[1] + float(box[3]),
             )
             geom = STGeometry(metadata.projection, shp, None)
             properties: dict[str, Any] = {

rslearn/train/tasks/per_pixel_regression.py
@@ -66,20 +66,18 @@ class PerPixelRegressionTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-        assert raw_inputs["targets"].image.shape[0] == 1
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
+        labels = raw_inputs["targets"].get_hw_tensor().float() * self.scale_factor
 
         if self.nodata_value is not None:
-            valid = (
-                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
-            ).float()
+            valid = (raw_inputs["targets"].get_hw_tensor() != self.nodata_value).float()
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so values and valid can be
+        # used in image transforms.
         return {}, {
-            "values": labels,
-            "valid": valid,
+            "values": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(
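
The downstream effect is that regression targets, which were plain HW tensors in 0.0.25, now arrive wrapped as RasterImage, and heads and metrics unwrap them with get_hw_tensor() (see the hunks below). A sketch with illustrative shapes (import path assumed):

    import torch
    from rslearn.train.model_context import RasterImage  # path assumed

    target_dict = {
        "values": RasterImage(torch.rand(1, 1, 64, 64), timestamps=None),
        "valid": RasterImage(torch.ones(1, 1, 64, 64), timestamps=None),
    }
    labels = target_dict["values"].get_hw_tensor()   # HW float tensor
    mask = target_dict["valid"].get_hw_tensor() > 0  # HW bool mask
    valid_labels = labels[mask]                      # flattened valid pixels
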
@@ -121,7 +119,7 @@ class PerPixelRegressionTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_values = target_dict["classes"].cpu().numpy()
+        gt_values = target_dict["values"].get_hw_tensor().cpu().numpy()
         pred_values = output.cpu().numpy()[0, :, :]
         gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
         pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
@@ -210,8 +208,10 @@ class PerPixelRegressionHead(Predictor):
 
         losses = {}
         if targets:
-            labels = torch.stack([target["values"] for target in targets])
-            mask = torch.stack([target["valid"] for target in targets])
+            labels = torch.stack(
+                [target["values"].get_hw_tensor() for target in targets]
+            )
+            mask = torch.stack([target["valid"].get_hw_tensor() for target in targets])
 
             if self.loss_mode == "mse":
                 scores = torch.square(outputs - labels)
@@ -262,14 +262,14 @@ class PerPixelRegressionMetricWrapper(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["values"] for target in targets])
+        labels = torch.stack([target["values"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
        # We flatten the prediction and label images at valid pixels.
         if len(preds.shape) == 4:
             assert preds.shape[1] == 1
             preds = preds[:, 0, :, :]
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds[mask]
         labels = labels[mask]
         if len(preds) == 0:

rslearn/train/tasks/regression.py
@@ -130,7 +130,7 @@ class RegressionTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {

rslearn/train/tasks/segmentation.py
@@ -128,9 +128,7 @@ class SegmentationTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-        assert raw_inputs["targets"].image.shape[0] == 1
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].long()
+        labels = raw_inputs["targets"].get_hw_tensor().long()
 
         if self.class_id_mapping is not None:
             new_labels = labels.clone()
@@ -146,9 +144,11 @@ class SegmentationTask(BasicTask):
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so classes and valid can be
+        # used in image transforms.
         return {}, {
-            "classes": labels,
-            "valid": valid,
+            "classes": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(
@@ -206,7 +206,7 @@ class SegmentationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_classes = target_dict["classes"].cpu().numpy()
+        gt_classes = target_dict["classes"].get_hw_tensor().cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
         pred_vis = np.zeros(
@@ -291,12 +291,19 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
-    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+    def __init__(
+        self,
+        weights: list[float] | None = None,
+        dice_loss: bool = False,
+        temperature: float = 1.0,
+    ):
         """Initialize a new SegmentationTask.
 
         Args:
             weights: weights for cross entropy loss (Tensor of size C)
             dice_loss: weather to add dice loss to cross entropy
+            temperature: temperature scaling for softmax, does not affect the loss,
+                only the predictor outputs
         """
         super().__init__()
         if weights is not None:
@@ -304,6 +311,7 @@ class SegmentationHead(Predictor):
         else:
             self.weights = None
         self.dice_loss = dice_loss
+        self.temperature = temperature
 
     def forward(
         self,
@@ -332,12 +340,16 @@ class SegmentationHead(Predictor):
         )
 
         logits = intermediates.feature_maps[0]
-        outputs = torch.nn.functional.softmax(logits, dim=1)
+        outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)
 
         losses = {}
         if targets:
-            labels = torch.stack([target["classes"] for target in targets], dim=0)
-            mask = torch.stack([target["valid"] for target in targets], dim=0)
+            labels = torch.stack(
+                [target["classes"].get_hw_tensor() for target in targets], dim=0
+            )
+            mask = torch.stack(
+                [target["valid"].get_hw_tensor() for target in targets], dim=0
+            )
             per_pixel_loss = torch.nn.functional.cross_entropy(
                 logits, labels, weight=self.weights, reduction="none"
             )
@@ -350,7 +362,8 @@ class SegmentationHead(Predictor):
             # the summed mask loss be zero.
             losses["cls"] = torch.sum(per_pixel_loss * mask)
             if self.dice_loss:
-                dice_loss = DiceLoss()(outputs, labels, mask)
+                softmax_woT = torch.nn.functional.softmax(logits, dim=1)
+                dice_loss = DiceLoss()(softmax_woT, labels, mask)
                 losses["dice"] = dice_loss
 
         return ModelOutput(
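
For reference, the temperature only rescales the predicted probabilities as softmax(logits / T); the cross-entropy and dice losses are still computed from the unscaled logits. A small numeric sketch:

    import torch

    logits = torch.tensor([[2.0, 0.0, -2.0]])
    torch.softmax(logits / 1.0, dim=1)  # ~[0.867, 0.117, 0.016]: default, unchanged
    torch.softmax(logits / 2.0, dim=1)  # ~[0.665, 0.245, 0.090]: softer predictions
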
@@ -401,12 +414,12 @@ class SegmentationMetric(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["classes"] for target in targets])
+        labels = torch.stack([target["classes"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         # Prediction is changed from BCHW to BHWC so we can select the valid BHW mask.
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds.permute(0, 2, 3, 1)[mask]
         labels = labels[mask]
         if len(preds) == 0:

rslearn/train/transforms/concatenate.py
@@ -54,36 +54,26 @@ class Concatenate(Transform):
             target_dict: the target
 
         Returns:
-            concatenated (input_dicts, target_dicts) tuple. If one of the
-            specified inputs is a RasterImage, a RasterImage will be returned.
-            Otherwise it will be a torch.Tensor.
+            (input_dicts, target_dicts) where the entry corresponding to
+            output_selector contains the concatenated RasterImage.
         """
-        images = []
-        return_raster_image: bool = False
+        tensors: list[torch.Tensor] = []
         timestamps: list[tuple[datetime, datetime]] | None = None
+
         for selector, wanted_bands in self.selections.items():
             image = read_selector(input_dict, target_dict, selector)
-            if isinstance(image, torch.Tensor):
-                if wanted_bands:
-                    image = image[wanted_bands, :, :]
-                images.append(image)
-            elif isinstance(image, RasterImage):
-                return_raster_image = True
-                if wanted_bands:
-                    images.append(image.image[wanted_bands, :, :])
-                else:
-                    images.append(image.image)
-                if timestamps is None:
-                    if image.timestamps is not None:
-                        # assume all concatenated modalities have the same
-                        # number of timestamps
-                        timestamps = image.timestamps
-        if return_raster_image:
-            result = RasterImage(
-                torch.concatenate(images, dim=self.concatenate_dim),
-                timestamps=timestamps,
-            )
-        else:
-            result = torch.concatenate(images, dim=self.concatenate_dim)
+            if wanted_bands:
+                tensors.append(image.image[wanted_bands, :, :])
+            else:
+                tensors.append(image.image)
+            if timestamps is None and image.timestamps is not None:
+                # assume all concatenated modalities have the same
+                # number of timestamps
+                timestamps = image.timestamps
+
+        result = RasterImage(
+            torch.concatenate(tensors, dim=self.concatenate_dim),
+            timestamps=timestamps,
+        )
         write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
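
The simplified contract is that every selected input must now be a RasterImage, and the concatenated result is always a RasterImage. A hedged sketch (constructor argument names inferred from the attributes used above):

    # Concatenate two modalities along the channel dimension; plain torch.Tensor
    # inputs are no longer handled by this transform in 0.0.27.
    transform = Concatenate(
        selections={"sentinel2": [], "sentinel1": []},  # [] keeps all bands
        output_selector="image",
    )
    input_dict, target_dict = transform(input_dict, target_dict)
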

rslearn/train/transforms/crop.py
@@ -71,9 +71,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, Any]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -84,22 +82,13 @@
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        if isinstance(image, RasterImage):
-            image.image = torchvision.transforms.functional.crop(
-                image.image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
-        else:
-            image = torchvision.transforms.functional.crop(
-                image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
+        image.image = torchvision.transforms.functional.crop(
+            image.image,
+            top=remove_from_top,
+            left=remove_from_left,
+            height=crop_size,
+            width=crop_size,
+        )
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:

rslearn/train/transforms/flip.py
@@ -57,16 +57,10 @@ class Flip(Transform):
             image: the image to transform.
             state: the sampled state.
         """
-        if isinstance(image, RasterImage):
-            if state["horizontal"]:
-                image.image = torch.flip(image.image, dims=[-1])
-            if state["vertical"]:
-                image.image = torch.flip(image.image, dims=[-2])
-        elif isinstance(image, torch.Tensor):
-            if state["horizontal"]:
-                image = torch.flip(image, dims=[-1])
-            if state["vertical"]:
-                image = torch.flip(image, dims=[-2])
+        if state["horizontal"]:
+            image.image = torch.flip(image.image, dims=[-1])
+        if state["vertical"]:
+            image.image = torch.flip(image.image, dims=[-2])
         return image
 
     def apply_boxes(
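
Since Crop and Flip likewise drop their torch.Tensor branches, callers holding bare tensors should wrap them before entering the transform pipeline. A minimal sketch (module path assumed):

    import torch
    from rslearn.train.model_context import RasterImage  # path assumed

    tensor_chw = torch.rand(3, 128, 128)
    # Transforms in 0.0.27 operate on RasterImage in CTHW layout only, so add a
    # singleton time axis to get (C, T, H, W).
    image = RasterImage(tensor_chw[:, None], timestamps=None)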