rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  """Segmentation task."""
2
2
 
3
+ from collections.abc import Mapping
3
4
  from typing import Any
4
5
 
5
6
  import numpy as np
@@ -8,26 +9,34 @@ import torch
8
9
  import torchmetrics.classification
9
10
  from torchmetrics import Metric, MetricCollection
10
11
 
12
+ from rslearn.models.component import FeatureMaps, Predictor
13
+ from rslearn.train.model_context import (
14
+ ModelContext,
15
+ ModelOutput,
16
+ RasterImage,
17
+ SampleMetadata,
18
+ )
11
19
  from rslearn.utils import Feature
12
20
 
13
21
  from .task import BasicTask
14
22
 
23
+ # TODO: This is duplicated code fix it
15
24
  DEFAULT_COLORS = [
16
- [255, 0, 0],
17
- [0, 255, 0],
18
- [0, 0, 255],
19
- [255, 255, 0],
20
- [0, 255, 255],
21
- [255, 0, 255],
22
- [0, 128, 0],
23
- [255, 160, 122],
24
- [139, 69, 19],
25
- [128, 128, 128],
26
- [255, 255, 255],
27
- [143, 188, 143],
28
- [95, 158, 160],
29
- [255, 200, 0],
30
- [128, 0, 0],
25
+ (255, 0, 0),
26
+ (0, 255, 0),
27
+ (0, 0, 255),
28
+ (255, 255, 0),
29
+ (0, 255, 255),
30
+ (255, 0, 255),
31
+ (0, 128, 0),
32
+ (255, 160, 122),
33
+ (139, 69, 19),
34
+ (128, 128, 128),
35
+ (255, 255, 255),
36
+ (143, 188, 143),
37
+ (95, 158, 160),
38
+ (255, 200, 0),
39
+ (128, 0, 0),
31
40
  ]
32
41
 
33
42
 
@@ -37,31 +46,77 @@ class SegmentationTask(BasicTask):
37
46
  def __init__(
38
47
  self,
39
48
  num_classes: int,
49
+ class_id_mapping: dict[int, int] | None = None,
40
50
  colors: list[tuple[int, int, int]] = DEFAULT_COLORS,
41
51
  zero_is_invalid: bool = False,
52
+ nodata_value: int | None = None,
53
+ enable_accuracy_metric: bool = True,
54
+ enable_miou_metric: bool = False,
55
+ enable_f1_metric: bool = False,
56
+ f1_metric_thresholds: list[list[float]] = [[0.5]],
42
57
  metric_kwargs: dict[str, Any] = {},
43
- **kwargs,
44
- ):
58
+ miou_metric_kwargs: dict[str, Any] = {},
59
+ prob_scales: list[float] | None = None,
60
+ other_metrics: dict[str, Metric] = {},
61
+ **kwargs: Any,
62
+ ) -> None:
45
63
  """Initialize a new SegmentationTask.
46
64
 
47
65
  Args:
48
66
  num_classes: the number of classes to predict
49
67
  colors: optional colors for each class
50
68
  zero_is_invalid: whether pixels labeled class 0 should be marked invalid
69
+ Mutually exclusive with nodata_value.
70
+ nodata_value: the value to use for nodata pixels. If None, all pixels are
71
+ considered valid. Mutually exclusive with zero_is_invalid.
72
+ class_id_mapping: optional mapping from original class IDs to new class IDs.
73
+ If provided, class labels will be remapped according to this dictionary.
74
+ enable_accuracy_metric: whether to enable the accuracy metric (default
75
+ true).
76
+ enable_f1_metric: whether to enable the F1 metric (default false).
77
+ enable_miou_metric: whether to enable the mean IoU metric (default false).
78
+ f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
79
+ Each inner list is used to initialize a separate F1 metric where the
80
+ best F1 across the thresholds within the inner list is computed. If
81
+ there are multiple inner lists, then multiple F1 scores will be
82
+ reported.
51
83
  metric_kwargs: additional arguments to pass to underlying metric, see
52
84
  torchmetrics.classification.MulticlassAccuracy.
85
+ miou_metric_kwargs: additional arguments to pass to MeanIoUMetric, if
86
+ enable_miou_metric is passed.
87
+ prob_scales: during inference, scale the output probabilities by this much
88
+ before computing the argmax. There is one scale per class. Note that
89
+ this is only applied during prediction, not when computing val or test
90
+ metrics.
91
+ other_metrics: additional metrics to configure on this task.
53
92
  kwargs: additional arguments to pass to BasicTask
54
93
  """
55
94
  super().__init__(**kwargs)
56
95
  self.num_classes = num_classes
96
+ self.class_id_mapping = class_id_mapping
57
97
  self.colors = colors
58
- self.zero_is_invalid = zero_is_invalid
98
+ self.nodata_value: int | None
99
+
100
+ if zero_is_invalid and nodata_value is not None:
101
+ raise ValueError("zero_is_invalid and nodata_value cannot both be set")
102
+ if zero_is_invalid:
103
+ self.nodata_value = 0
104
+ else:
105
+ self.nodata_value = nodata_value
106
+
107
+ self.enable_accuracy_metric = enable_accuracy_metric
108
+ self.enable_f1_metric = enable_f1_metric
109
+ self.enable_miou_metric = enable_miou_metric
110
+ self.f1_metric_thresholds = f1_metric_thresholds
59
111
  self.metric_kwargs = metric_kwargs
112
+ self.miou_metric_kwargs = miou_metric_kwargs
113
+ self.prob_scales = prob_scales
114
+ self.other_metrics = other_metrics
60
115
 
61
116
  def process_inputs(
62
117
  self,
63
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
64
- metadata: dict[str, Any],
118
+ raw_inputs: Mapping[str, RasterImage | list[Feature]],
119
+ metadata: SampleMetadata,
65
120
  load_targets: bool = True,
66
121
  ) -> tuple[dict[str, Any], dict[str, Any]]:
67
122
  """Processes the data into targets.
@@ -78,11 +133,22 @@ class SegmentationTask(BasicTask):
78
133
  if not load_targets:
79
134
  return {}, {}
80
135
 
81
- assert raw_inputs["targets"].shape[0] == 1
82
- labels = raw_inputs["targets"][0, :, :].long()
83
-
84
- if self.zero_is_invalid:
85
- valid = (labels > 0).float()
136
+ assert isinstance(raw_inputs["targets"], RasterImage)
137
+ assert raw_inputs["targets"].image.shape[0] == 1
138
+ assert raw_inputs["targets"].image.shape[1] == 1
139
+ labels = raw_inputs["targets"].image[0, 0, :, :].long()
140
+
141
+ if self.class_id_mapping is not None:
142
+ new_labels = labels.clone()
143
+ for old_id, new_id in self.class_id_mapping.items():
144
+ new_labels[labels == old_id] = new_id
145
+ labels = new_labels
146
+
147
+ if self.nodata_value is not None:
148
+ valid = (labels != self.nodata_value).float()
149
+ # Labels, even masked ones, must be in the range 0 to num_classes-1
150
+ if self.nodata_value >= self.num_classes:
151
+ labels[labels == self.nodata_value] = 0
86
152
  else:
87
153
  valid = torch.ones(labels.shape, dtype=torch.float32)
88
154
 
@@ -92,18 +158,28 @@ class SegmentationTask(BasicTask):
92
158
  }
93
159
 
94
160
  def process_output(
95
- self, raw_output: Any, metadata: dict[str, Any]
96
- ) -> npt.NDArray[Any] | list[Feature]:
161
+ self, raw_output: Any, metadata: SampleMetadata
162
+ ) -> npt.NDArray[Any]:
97
163
  """Processes an output into raster or vector data.
98
164
 
99
165
  Args:
100
- raw_output: the output from prediction head.
166
+ raw_output: the output from prediction head, which must be a CHW tensor.
101
167
  metadata: metadata about the patch being read
102
168
 
103
169
  Returns:
104
- either raster or vector data.
170
+ CHW numpy array with one channel, containing the predicted class IDs.
105
171
  """
106
- classes = raw_output.cpu().numpy().argmax(axis=0).astype(np.uint8)
172
+ if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
173
+ raise ValueError("the output for SegmentationTask must be a CHW tensor")
174
+
175
+ if self.prob_scales is not None:
176
+ raw_output = (
177
+ raw_output
178
+ * torch.tensor(
179
+ self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
180
+ )[:, None, None]
181
+ )
182
+ classes = raw_output.argmax(dim=0).cpu().numpy()
107
183
  return classes[None, :, :]
108
184
 
109
185
  def visualize(
@@ -123,6 +199,8 @@ class SegmentationTask(BasicTask):
123
199
  a dictionary mapping image name to visualization image
124
200
  """
125
201
  image = super().visualize(input_dict, target_dict, output)["image"]
202
+ if target_dict is None:
203
+ raise ValueError("target_dict is required for visualization")
126
204
  gt_classes = target_dict["classes"].cpu().numpy()
127
205
  pred_classes = output.cpu().numpy().argmax(axis=0)
128
206
  gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
@@ -143,57 +221,136 @@ class SegmentationTask(BasicTask):
143
221
  def get_metrics(self) -> MetricCollection:
144
222
  """Get the metrics for this task."""
145
223
  metrics = {}
146
- metric_kwargs = dict(num_classes=self.num_classes)
147
- metric_kwargs.update(self.metric_kwargs)
148
- metrics["accuracy"] = SegmentationMetric(
149
- torchmetrics.classification.MulticlassAccuracy(**metric_kwargs)
150
- )
224
+
225
+ if self.enable_accuracy_metric:
226
+ accuracy_metric_kwargs = dict(num_classes=self.num_classes)
227
+ accuracy_metric_kwargs.update(self.metric_kwargs)
228
+ metrics["accuracy"] = SegmentationMetric(
229
+ torchmetrics.classification.MulticlassAccuracy(**accuracy_metric_kwargs)
230
+ )
231
+
232
+ if self.enable_f1_metric:
233
+ for thresholds in self.f1_metric_thresholds:
234
+ if len(self.f1_metric_thresholds) == 1:
235
+ suffix = ""
236
+ else:
237
+ # Metric name can't contain "." so change to ",".
238
+ suffix = "_" + str(thresholds[0]).replace(".", ",")
239
+
240
+ metrics["F1" + suffix] = SegmentationMetric(
241
+ F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
242
+ )
243
+ metrics["precision" + suffix] = SegmentationMetric(
244
+ F1Metric(
245
+ num_classes=self.num_classes,
246
+ score_thresholds=thresholds,
247
+ metric_mode="precision",
248
+ )
249
+ )
250
+ metrics["recall" + suffix] = SegmentationMetric(
251
+ F1Metric(
252
+ num_classes=self.num_classes,
253
+ score_thresholds=thresholds,
254
+ metric_mode="recall",
255
+ )
256
+ )
257
+
258
+ if self.enable_miou_metric:
259
+ miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
260
+ if self.nodata_value is not None:
261
+ miou_metric_kwargs["nodata_value"] = self.nodata_value
262
+ miou_metric_kwargs.update(self.miou_metric_kwargs)
263
+ metrics["mean_iou"] = SegmentationMetric(
264
+ MeanIoUMetric(**miou_metric_kwargs),
265
+ pass_probabilities=False,
266
+ )
267
+
268
+ if self.other_metrics:
269
+ metrics.update(self.other_metrics)
270
+
151
271
  return MetricCollection(metrics)
152
272
 
153
273
 
154
- class SegmentationHead(torch.nn.Module):
274
+ class SegmentationHead(Predictor):
155
275
  """Head for segmentation task."""
156
276
 
157
277
  def forward(
158
278
  self,
159
- logits: torch.Tensor,
160
- inputs: list[dict[str, Any]],
279
+ intermediates: Any,
280
+ context: ModelContext,
161
281
  targets: list[dict[str, Any]] | None = None,
162
- ):
282
+ ) -> ModelOutput:
163
283
  """Compute the segmentation outputs from logits and targets.
164
284
 
165
285
  Args:
166
- logits: tensor that is (BatchSize, NumClasses, Height, Width) in shape.
167
- inputs: original inputs (ignored).
168
- targets: should contain classes key that stores the per-pixel class labels.
286
+ intermediates: a FeatureMaps with a single feature map containing the
287
+ segmentation logits.
288
+ context: the model context
289
+ targets: list of target dicts, where each target dict must contain a key
290
+ "classes" containing the per-pixel class labels, along with "valid"
291
+ containing a mask indicating where the example is valid.
169
292
 
170
293
  Returns:
171
294
  tuple of outputs and loss dict
172
295
  """
296
+ if not isinstance(intermediates, FeatureMaps):
297
+ raise ValueError("input to SegmentationHead must be a FeatureMaps")
298
+ if len(intermediates.feature_maps) != 1:
299
+ raise ValueError(
300
+ f"input to SegmentationHead must have one feature map, but got {len(intermediates.feature_maps)}"
301
+ )
302
+
303
+ logits = intermediates.feature_maps[0]
173
304
  outputs = torch.nn.functional.softmax(logits, dim=1)
174
305
 
175
- loss = None
306
+ losses = {}
176
307
  if targets:
177
308
  labels = torch.stack([target["classes"] for target in targets], dim=0)
178
309
  mask = torch.stack([target["valid"] for target in targets], dim=0)
179
- loss = (
180
- torch.nn.functional.cross_entropy(logits, labels, reduction="none")
181
- * mask
310
+ per_pixel_loss = torch.nn.functional.cross_entropy(
311
+ logits, labels, reduction="none"
182
312
  )
183
- loss = torch.mean(loss)
184
-
185
- return outputs, {"cls": loss}
313
+ mask_sum = torch.sum(mask)
314
+ if mask_sum > 0:
315
+ # Compute average loss over valid pixels.
316
+ losses["cls"] = torch.sum(per_pixel_loss * mask) / torch.sum(mask)
317
+ else:
318
+ # If there are no valid pixels, we avoid dividing by zero and just let
319
+ # the summed mask loss be zero.
320
+ losses["cls"] = torch.sum(per_pixel_loss * mask)
321
+
322
+ return ModelOutput(
323
+ outputs=outputs,
324
+ loss_dict=losses,
325
+ )
186
326
 
187
327
 
188
328
  class SegmentationMetric(Metric):
189
329
  """Metric for segmentation task."""
190
330
 
191
- def __init__(self, metric: Metric):
192
- """Initialize a new SegmentationMetric."""
331
+ def __init__(
332
+ self,
333
+ metric: Metric,
334
+ pass_probabilities: bool = True,
335
+ class_idx: int | None = None,
336
+ ):
337
+ """Initialize a new SegmentationMetric.
338
+
339
+ Args:
340
+ metric: the metric to wrap. This wrapping class will handle selecting the
341
+ classes from the targets and masking out invalid pixels.
342
+ pass_probabilities: whether to pass predicted probabilities to the metric.
343
+ If False, argmax is applied to pass the predicted classes instead.
344
+ class_idx: if metric returns value for multiple classes, select this class.
345
+ """
193
346
  super().__init__()
194
347
  self.metric = metric
348
+ self.pass_probablities = pass_probabilities
349
+ self.class_idx = class_idx
195
350
 
196
- def update(self, preds: list[Any], targets: list[dict[str, Any]]) -> None:
351
+ def update(
352
+ self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
353
+ ) -> None:
197
354
  """Update metric.
198
355
 
199
356
  Args:
@@ -213,11 +370,17 @@ class SegmentationMetric(Metric):
213
370
  if len(preds) == 0:
214
371
  return
215
372
 
373
+ if not self.pass_probablities:
374
+ preds = preds.argmax(dim=1)
375
+
216
376
  self.metric.update(preds, labels)
217
377
 
218
378
  def compute(self) -> Any:
219
379
  """Returns the computed metric."""
220
- return self.metric.compute()
380
+ result = self.metric.compute()
381
+ if self.class_idx is not None:
382
+ result = result[self.class_idx]
383
+ return result
221
384
 
222
385
  def reset(self) -> None:
223
386
  """Reset metric."""
@@ -227,3 +390,215 @@ class SegmentationMetric(Metric):
227
390
  def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
228
391
  """Returns a plot of the metric."""
229
392
  return self.metric.plot(*args, **kwargs)
393
+
394
+
395
+ class F1Metric(Metric):
396
+ """F1 score for segmentation.
397
+
398
+ It treats each class as a separate prediction task, and computes the maximum F1
399
+ score under the different configured thresholds per-class.
400
+ """
401
+
402
+ def __init__(
403
+ self,
404
+ num_classes: int,
405
+ score_thresholds: list[float],
406
+ metric_mode: str = "f1",
407
+ ):
408
+ """Create a new F1Metric.
409
+
410
+ Args:
411
+ num_classes: number of classes.
412
+ score_thresholds: list of score thresholds to check F1 score for. The final
413
+ metric is the best F1 across score thresholds.
414
+ metric_mode: set to "precision" or "recall" to return that instead of F1
415
+ (default "f1")
416
+ """
417
+ super().__init__()
418
+ self.num_classes = num_classes
419
+ self.score_thresholds = score_thresholds
420
+ self.metric_mode = metric_mode
421
+
422
+ assert self.metric_mode in ["f1", "precision", "recall"]
423
+
424
+ for cls_idx in range(self.num_classes):
425
+ for thr_idx in range(len(self.score_thresholds)):
426
+ cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
427
+ self.add_state(
428
+ cur_prefix + "tp", default=torch.tensor(0), dist_reduce_fx="sum"
429
+ )
430
+ self.add_state(
431
+ cur_prefix + "fp", default=torch.tensor(0), dist_reduce_fx="sum"
432
+ )
433
+ self.add_state(
434
+ cur_prefix + "fn", default=torch.tensor(0), dist_reduce_fx="sum"
435
+ )
436
+
437
+ def _get_state_prefix(self, cls_idx: int, thr_idx: int) -> str:
438
+ return f"{cls_idx}_{thr_idx}_"
439
+
440
+ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
441
+ """Update metric.
442
+
443
+ Args:
444
+ preds: the predictions, NxC.
445
+ labels: the targets, N, with values from 0 to C-1.
446
+ """
447
+ for cls_idx in range(self.num_classes):
448
+ for thr_idx, score_threshold in enumerate(self.score_thresholds):
449
+ pred_bin = preds[:, cls_idx] > score_threshold
450
+ gt_bin = labels == cls_idx
451
+
452
+ tp = torch.count_nonzero(pred_bin & gt_bin).item()
453
+ fp = torch.count_nonzero(pred_bin & torch.logical_not(gt_bin)).item()
454
+ fn = torch.count_nonzero(torch.logical_not(pred_bin) & gt_bin).item()
455
+
456
+ cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
457
+ setattr(self, cur_prefix + "tp", getattr(self, cur_prefix + "tp") + tp)
458
+ setattr(self, cur_prefix + "fp", getattr(self, cur_prefix + "fp") + fp)
459
+ setattr(self, cur_prefix + "fn", getattr(self, cur_prefix + "fn") + fn)
460
+
461
+ def compute(self) -> Any:
462
+ """Compute metric.
463
+
464
+ Returns:
465
+ the best F1 score across score thresholds and classes.
466
+ """
467
+ best_scores = []
468
+
469
+ for cls_idx in range(self.num_classes):
470
+ best_score = None
471
+
472
+ for thr_idx in range(len(self.score_thresholds)):
473
+ cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
474
+ tp = getattr(self, cur_prefix + "tp")
475
+ fp = getattr(self, cur_prefix + "fp")
476
+ fn = getattr(self, cur_prefix + "fn")
477
+ device = tp.device
478
+
479
+ if tp + fp == 0:
480
+ precision = torch.tensor(0, dtype=torch.float32, device=device)
481
+ else:
482
+ precision = tp / (tp + fp)
483
+
484
+ if tp + fn == 0:
485
+ recall = torch.tensor(0, dtype=torch.float32, device=device)
486
+ else:
487
+ recall = tp / (tp + fn)
488
+
489
+ if precision + recall < 0.001:
490
+ f1 = torch.tensor(0, dtype=torch.float32, device=device)
491
+ else:
492
+ f1 = 2 * precision * recall / (precision + recall)
493
+
494
+ if self.metric_mode == "f1":
495
+ score = f1
496
+ elif self.metric_mode == "precision":
497
+ score = precision
498
+ elif self.metric_mode == "recall":
499
+ score = recall
500
+
501
+ if best_score is None or score > best_score:
502
+ best_score = score
503
+
504
+ best_scores.append(best_score)
505
+
506
+ return torch.mean(torch.stack(best_scores))
507
+
508
+
509
+ class MeanIoUMetric(Metric):
510
+ """Mean IoU for segmentation.
511
+
512
+ This is the mean of the per-class intersection-over-union scores. The per-class
513
+ intersection is the number of pixels across all examples where the predicted label
514
+ and ground truth label are both that class, and the per-class union is defined
515
+ similarly.
516
+
517
+ This differs from torchmetrics.segmentation.MeanIoU, where the mean IoU is computed
518
+ per-image, and averaged across images.
519
+ """
520
+
521
+ def __init__(
522
+ self,
523
+ num_classes: int,
524
+ nodata_value: int | None = None,
525
+ ignore_missing_classes: bool = False,
526
+ class_idx: int | None = None,
527
+ ):
528
+ """Create a new MeanIoUMetric.
529
+
530
+ Args:
531
+ num_classes: the number of classes for the task.
532
+ nodata_value: the value to treat as nodata/invalid. If set and is one of the
533
+ classes, IoU will not be calculated for it. If None, or not one of the
534
+ classes, IoU is calculated for all classes.
535
+ ignore_missing_classes: whether to ignore classes that don't appear in
536
+ either the predictions or the ground truth. If false, the IoU for a
537
+ missing class will be 0.
538
+ class_idx: only compute and return the IoU for this class. This option is
539
+ provided so the user can get per-class IoU results, since Lightning
540
+ only supports scalar return values from metrics.
541
+ """
542
+ super().__init__()
543
+ self.num_classes = num_classes
544
+ self.nodata_value = nodata_value
545
+ self.ignore_missing_classes = ignore_missing_classes
546
+ self.class_idx = class_idx
547
+
548
+ self.add_state(
549
+ "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
550
+ )
551
+ self.add_state(
552
+ "unions", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
553
+ )
554
+
555
+ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
556
+ """Update metric.
557
+
558
+ Like torchmetrics.segmentation.MeanIoU with input_format="index", we expect
559
+ predictions and labels to both be class integers. This is achieved by passing
560
+ pass_probabilities=False to the SegmentationMetric wrapper.
561
+
562
+ Args:
563
+ preds: the predictions, (N,), with values from 0 to C-1.
564
+ labels: the targets, (N,), with values from 0 to C-1.
565
+ """
566
+ if preds.min() < 0 or preds.max() >= self.num_classes:
567
+ raise ValueError("predicted class outside of expected range")
568
+ if labels.min() < 0 or labels.max() >= self.num_classes:
569
+ raise ValueError("label class outside of expected range")
570
+
571
+ new_intersections = torch.zeros(
572
+ self.num_classes, device=self.intersections.device
573
+ )
574
+ new_unions = torch.zeros(self.num_classes, device=self.unions.device)
575
+ for cls_idx in range(self.num_classes):
576
+ new_intersections[cls_idx] = (
577
+ (preds == cls_idx) & (labels == cls_idx)
578
+ ).sum()
579
+ new_unions[cls_idx] = ((preds == cls_idx) | (labels == cls_idx)).sum()
580
+ self.intersections += new_intersections
581
+ self.unions += new_unions
582
+
583
+ def compute(self) -> Any:
584
+ """Compute metric.
585
+
586
+ Returns:
587
+ the mean IoU across classes.
588
+ """
589
+ per_class_scores = []
590
+
591
+ for cls_idx in range(self.num_classes):
592
+ # Check if nodata_value is set and is one of the classes
593
+ if self.nodata_value is not None and cls_idx == self.nodata_value:
594
+ continue
595
+
596
+ intersection = self.intersections[cls_idx]
597
+ union = self.unions[cls_idx]
598
+
599
+ if union == 0 and self.ignore_missing_classes:
600
+ continue
601
+
602
+ per_class_scores.append(intersection / union)
603
+
604
+ return torch.mean(torch.stack(per_class_scores))
@@ -7,6 +7,7 @@ import numpy.typing as npt
7
7
  import torch
8
8
  from torchmetrics import MetricCollection
9
9
 
10
+ from rslearn.train.model_context import RasterImage, SampleMetadata
10
11
  from rslearn.utils import Feature
11
12
 
12
13
 
@@ -20,8 +21,8 @@ class Task:
20
21
 
21
22
  def process_inputs(
22
23
  self,
23
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
24
- metadata: dict[str, Any],
24
+ raw_inputs: dict[str, RasterImage | list[Feature]],
25
+ metadata: SampleMetadata,
25
26
  load_targets: bool = True,
26
27
  ) -> tuple[dict[str, Any], dict[str, Any]]:
27
28
  """Processes the data into targets.
@@ -38,8 +39,8 @@ class Task:
38
39
  raise NotImplementedError
39
40
 
40
41
  def process_output(
41
- self, raw_output: Any, metadata: dict[str, Any]
42
- ) -> npt.NDArray[Any] | list[Feature]:
42
+ self, raw_output: Any, metadata: SampleMetadata
43
+ ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
43
44
  """Processes an output into raster or vector data.
44
45
 
45
46
  Args:
@@ -47,7 +48,7 @@ class Task:
47
48
  metadata: metadata about the patch being read
48
49
 
49
50
  Returns:
50
- either raster or vector data.
51
+ raster data, vector data, or multi-task dictionary output.
51
52
  """
52
53
  raise NotImplementedError
53
54
 
@@ -12,7 +12,7 @@ class Sequential(torch.nn.Module):
12
12
  tuple.
13
13
  """
14
14
 
15
- def __init__(self, *args):
15
+ def __init__(self, *args: Any) -> None:
16
16
  """Initialize a new Sequential from a list of transforms."""
17
17
  super().__init__()
18
18
  self.transforms = torch.nn.ModuleList(args)