rslearn-0.0.1-py3-none-any.whl → rslearn-0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. rslearn/config/__init__.py +2 -2
  2. rslearn/config/dataset.py +164 -98
  3. rslearn/const.py +9 -15
  4. rslearn/data_sources/__init__.py +8 -0
  5. rslearn/data_sources/aws_landsat.py +235 -80
  6. rslearn/data_sources/aws_open_data.py +103 -118
  7. rslearn/data_sources/aws_sentinel1.py +142 -0
  8. rslearn/data_sources/climate_data_store.py +303 -0
  9. rslearn/data_sources/copernicus.py +943 -12
  10. rslearn/data_sources/data_source.py +17 -12
  11. rslearn/data_sources/earthdaily.py +489 -0
  12. rslearn/data_sources/earthdata_srtm.py +300 -0
  13. rslearn/data_sources/gcp_public_data.py +556 -203
  14. rslearn/data_sources/geotiff.py +1 -0
  15. rslearn/data_sources/google_earth_engine.py +454 -115
  16. rslearn/data_sources/local_files.py +153 -103
  17. rslearn/data_sources/openstreetmap.py +33 -39
  18. rslearn/data_sources/planet.py +17 -35
  19. rslearn/data_sources/planet_basemap.py +296 -0
  20. rslearn/data_sources/planetary_computer.py +764 -0
  21. rslearn/data_sources/raster_source.py +11 -297
  22. rslearn/data_sources/usda_cdl.py +206 -0
  23. rslearn/data_sources/usgs_landsat.py +130 -73
  24. rslearn/data_sources/utils.py +256 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +456 -0
  27. rslearn/data_sources/worldcover.py +142 -0
  28. rslearn/data_sources/worldpop.py +156 -0
  29. rslearn/data_sources/xyz_tiles.py +141 -79
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +1 -1
  32. rslearn/dataset/dataset.py +43 -7
  33. rslearn/dataset/index.py +173 -0
  34. rslearn/dataset/manage.py +137 -49
  35. rslearn/dataset/materialize.py +436 -95
  36. rslearn/dataset/window.py +225 -34
  37. rslearn/log_utils.py +24 -0
  38. rslearn/main.py +351 -130
  39. rslearn/models/clip.py +62 -0
  40. rslearn/models/conv.py +56 -0
  41. rslearn/models/croma.py +270 -0
  42. rslearn/models/detr/__init__.py +5 -0
  43. rslearn/models/detr/box_ops.py +103 -0
  44. rslearn/models/detr/detr.py +493 -0
  45. rslearn/models/detr/matcher.py +107 -0
  46. rslearn/models/detr/position_encoding.py +114 -0
  47. rslearn/models/detr/transformer.py +429 -0
  48. rslearn/models/detr/util.py +24 -0
  49. rslearn/models/faster_rcnn.py +10 -19
  50. rslearn/models/fpn.py +1 -1
  51. rslearn/models/module_wrapper.py +91 -0
  52. rslearn/models/moe/distributed.py +262 -0
  53. rslearn/models/moe/soft.py +676 -0
  54. rslearn/models/molmo.py +65 -0
  55. rslearn/models/multitask.py +351 -24
  56. rslearn/models/pick_features.py +15 -2
  57. rslearn/models/pooling_decoder.py +4 -2
  58. rslearn/models/satlaspretrain.py +4 -7
  59. rslearn/models/simple_time_series.py +75 -59
  60. rslearn/models/singletask.py +8 -4
  61. rslearn/models/ssl4eo_s12.py +10 -10
  62. rslearn/models/swin.py +22 -21
  63. rslearn/models/task_embedding.py +250 -0
  64. rslearn/models/terramind.py +219 -0
  65. rslearn/models/trunk.py +280 -0
  66. rslearn/models/unet.py +21 -5
  67. rslearn/models/upsample.py +35 -0
  68. rslearn/models/use_croma.py +508 -0
  69. rslearn/py.typed +0 -0
  70. rslearn/tile_stores/__init__.py +52 -18
  71. rslearn/tile_stores/default.py +382 -0
  72. rslearn/tile_stores/tile_store.py +236 -132
  73. rslearn/train/callbacks/freeze_unfreeze.py +32 -20
  74. rslearn/train/callbacks/gradients.py +109 -0
  75. rslearn/train/callbacks/peft.py +116 -0
  76. rslearn/train/data_module.py +407 -14
  77. rslearn/train/dataset.py +746 -200
  78. rslearn/train/lightning_module.py +164 -54
  79. rslearn/train/optimizer.py +31 -0
  80. rslearn/train/prediction_writer.py +235 -78
  81. rslearn/train/scheduler.py +62 -0
  82. rslearn/train/tasks/classification.py +13 -12
  83. rslearn/train/tasks/detection.py +101 -39
  84. rslearn/train/tasks/multi_task.py +24 -9
  85. rslearn/train/tasks/regression.py +113 -21
  86. rslearn/train/tasks/segmentation.py +353 -35
  87. rslearn/train/tasks/task.py +2 -2
  88. rslearn/train/transforms/__init__.py +1 -1
  89. rslearn/train/transforms/concatenate.py +9 -5
  90. rslearn/train/transforms/crop.py +8 -4
  91. rslearn/train/transforms/flip.py +5 -1
  92. rslearn/train/transforms/normalize.py +34 -10
  93. rslearn/train/transforms/pad.py +1 -1
  94. rslearn/train/transforms/transform.py +75 -73
  95. rslearn/utils/__init__.py +2 -6
  96. rslearn/utils/array.py +2 -2
  97. rslearn/utils/feature.py +2 -2
  98. rslearn/utils/fsspec.py +70 -1
  99. rslearn/utils/geometry.py +214 -7
  100. rslearn/utils/get_utm_ups_crs.py +2 -3
  101. rslearn/utils/grid_index.py +5 -5
  102. rslearn/utils/jsonargparse.py +33 -0
  103. rslearn/utils/mp.py +4 -3
  104. rslearn/utils/raster_format.py +211 -96
  105. rslearn/utils/rtree_index.py +64 -17
  106. rslearn/utils/sqlite_index.py +7 -1
  107. rslearn/utils/vector_format.py +235 -77
  108. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/METADATA +366 -284
  109. rslearn-0.0.3.dist-info/RECORD +123 -0
  110. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/WHEEL +1 -1
  111. rslearn/tile_stores/file.py +0 -242
  112. rslearn/utils/mgrs.py +0 -24
  113. rslearn/utils/utils.py +0 -22
  114. rslearn-0.0.1.dist-info/RECORD +0 -88
  115. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/entry_points.txt +0 -0
  116. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info/licenses}/LICENSE +0 -0
  117. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/segmentation.py
@@ -12,22 +12,23 @@ from rslearn.utils import Feature
 
 from .task import BasicTask
 
+# TODO: This is duplicated code fix it
 DEFAULT_COLORS = [
-    [255, 0, 0],
-    [0, 255, 0],
-    [0, 0, 255],
-    [255, 255, 0],
-    [0, 255, 255],
-    [255, 0, 255],
-    [0, 128, 0],
-    [255, 160, 122],
-    [139, 69, 19],
-    [128, 128, 128],
-    [255, 255, 255],
-    [143, 188, 143],
-    [95, 158, 160],
-    [255, 200, 0],
-    [128, 0, 0],
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (0, 255, 255),
+    (255, 0, 255),
+    (0, 128, 0),
+    (255, 160, 122),
+    (139, 69, 19),
+    (128, 128, 128),
+    (255, 255, 255),
+    (143, 188, 143),
+    (95, 158, 160),
+    (255, 200, 0),
+    (128, 0, 0),
 ]
 
 
@@ -39,28 +40,58 @@ class SegmentationTask(BasicTask):
         num_classes: int,
         colors: list[tuple[int, int, int]] = DEFAULT_COLORS,
         zero_is_invalid: bool = False,
+        enable_accuracy_metric: bool = True,
+        enable_miou_metric: bool = False,
+        enable_f1_metric: bool = False,
+        f1_metric_thresholds: list[list[float]] = [[0.5]],
         metric_kwargs: dict[str, Any] = {},
-        **kwargs,
-    ):
+        miou_metric_kwargs: dict[str, Any] = {},
+        prob_scales: list[float] | None = None,
+        other_metrics: dict[str, Metric] = {},
+        **kwargs: Any,
+    ) -> None:
         """Initialize a new SegmentationTask.
 
         Args:
             num_classes: the number of classes to predict
             colors: optional colors for each class
             zero_is_invalid: whether pixels labeled class 0 should be marked invalid
+            enable_accuracy_metric: whether to enable the accuracy metric (default
+                true).
+            enable_f1_metric: whether to enable the F1 metric (default false).
+            enable_miou_metric: whether to enable the mean IoU metric (default false).
+            f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
+                Each inner list is used to initialize a separate F1 metric where the
+                best F1 across the thresholds within the inner list is computed. If
+                there are multiple inner lists, then multiple F1 scores will be
+                reported.
             metric_kwargs: additional arguments to pass to underlying metric, see
                 torchmetrics.classification.MulticlassAccuracy.
+            miou_metric_kwargs: additional arguments to pass to MeanIoUMetric, if
+                enable_miou_metric is passed.
+            prob_scales: during inference, scale the output probabilities by this much
+                before computing the argmax. There is one scale per class. Note that
+                this is only applied during prediction, not when computing val or test
+                metrics.
+            other_metrics: additional metrics to configure on this task.
             kwargs: additional arguments to pass to BasicTask
         """
         super().__init__(**kwargs)
         self.num_classes = num_classes
         self.colors = colors
         self.zero_is_invalid = zero_is_invalid
+        self.enable_accuracy_metric = enable_accuracy_metric
+        self.enable_f1_metric = enable_f1_metric
+        self.enable_miou_metric = enable_miou_metric
+        self.f1_metric_thresholds = f1_metric_thresholds
         self.metric_kwargs = metric_kwargs
+        self.miou_metric_kwargs = miou_metric_kwargs
+        self.prob_scales = prob_scales
+        self.other_metrics = other_metrics
 
     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, torch.Tensor],
         metadata: dict[str, Any],
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
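For orientation, a minimal configuration sketch using the expanded signature. The module path follows the file list above; the flag values are illustrative, not defaults from this diff:

    from rslearn.train.tasks.segmentation import SegmentationTask

    # Enable mean IoU and best-threshold F1 alongside the default accuracy metric.
    task = SegmentationTask(
        num_classes=3,
        zero_is_invalid=True,
        enable_miou_metric=True,
        enable_f1_metric=True,
        # One F1 metric that reports the best score across three thresholds.
        f1_metric_thresholds=[[0.3, 0.5, 0.7]],
    )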
@@ -78,6 +109,7 @@ class SegmentationTask(BasicTask):
         if not load_targets:
             return {}, {}
 
+        # TODO: List[Feature] is currently not supported
         assert raw_inputs["targets"].shape[0] == 1
         labels = raw_inputs["targets"][0, :, :].long()
 
@@ -103,7 +135,11 @@ class SegmentationTask(BasicTask):
         Returns:
             either raster or vector data.
         """
-        classes = raw_output.cpu().numpy().argmax(axis=0).astype(np.uint8)
+        raw_output_np = raw_output.cpu().numpy()
+        if self.prob_scales is not None:
+            # Scale the channel dimension by the provided scales.
+            raw_output_np = raw_output_np * np.array(self.prob_scales)[:, None, None]
+        classes = raw_output_np.argmax(axis=0).astype(np.uint8)
         return classes[None, :, :]
 
     def visualize(
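The effect of prob_scales is easiest to see on a toy array; a quick sketch of the same arithmetic in plain NumPy, outside rslearn:

    import numpy as np

    # (C=2, H=1, W=1) softmax output: class 0 wins the unscaled argmax.
    probs = np.array([0.6, 0.4]).reshape(2, 1, 1)
    scales = np.array([1.0, 2.0])  # per-class scales, as in prob_scales
    scaled = probs * scales[:, None, None]  # class 1 becomes 0.8 and now wins
    assert probs.argmax(axis=0)[0, 0] == 0
    assert scaled.argmax(axis=0)[0, 0] == 1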
@@ -123,6 +159,8 @@ class SegmentationTask(BasicTask):
             a dictionary mapping image name to visualization image
         """
         image = super().visualize(input_dict, target_dict, output)["image"]
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         gt_classes = target_dict["classes"].cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
@@ -143,11 +181,53 @@ class SegmentationTask(BasicTask):
     def get_metrics(self) -> MetricCollection:
         """Get the metrics for this task."""
         metrics = {}
-        metric_kwargs = dict(num_classes=self.num_classes)
-        metric_kwargs.update(self.metric_kwargs)
-        metrics["accuracy"] = SegmentationMetric(
-            torchmetrics.classification.MulticlassAccuracy(**metric_kwargs)
-        )
+
+        if self.enable_accuracy_metric:
+            accuracy_metric_kwargs = dict(num_classes=self.num_classes)
+            accuracy_metric_kwargs.update(self.metric_kwargs)
+            metrics["accuracy"] = SegmentationMetric(
+                torchmetrics.classification.MulticlassAccuracy(**accuracy_metric_kwargs)
+            )
+
+        if self.enable_f1_metric:
+            for thresholds in self.f1_metric_thresholds:
+                if len(self.f1_metric_thresholds) == 1:
+                    suffix = ""
+                else:
+                    # Metric name can't contain "." so change to ",".
+                    suffix = "_" + str(thresholds[0]).replace(".", ",")
+
+                metrics["F1" + suffix] = SegmentationMetric(
+                    F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
+                )
+                metrics["precision" + suffix] = SegmentationMetric(
+                    F1Metric(
+                        num_classes=self.num_classes,
+                        score_thresholds=thresholds,
+                        metric_mode="precision",
+                    )
+                )
+                metrics["recall" + suffix] = SegmentationMetric(
+                    F1Metric(
+                        num_classes=self.num_classes,
+                        score_thresholds=thresholds,
+                        metric_mode="recall",
+                    )
+                )
+
+        if self.enable_miou_metric:
+            miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
+            if self.zero_is_invalid:
+                miou_metric_kwargs["zero_is_invalid"] = True
+            miou_metric_kwargs.update(self.miou_metric_kwargs)
+            metrics["mean_iou"] = SegmentationMetric(
+                MeanIoUMetric(**miou_metric_kwargs),
+                pass_probabilities=False,
+            )
+
+        if self.other_metrics:
+            metrics.update(self.other_metrics)
+
         return MetricCollection(metrics)
 
 
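Note that when several threshold lists are configured, each metric name is suffixed with the first threshold of its inner list, with "." changed to "," since metric names cannot contain dots. A quick sketch of the naming rule, mirroring the loop above:

    # With f1_metric_thresholds=[[0.25, 0.5], [0.75]], get_metrics() produces
    # keys "F1_0,25" and "F1_0,75" (plus matching precision_* / recall_* entries).
    for thresholds in [[0.25, 0.5], [0.75]]:
        suffix = "_" + str(thresholds[0]).replace(".", ",")
        print("F1" + suffix)  # F1_0,25 then F1_0,75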
@@ -159,7 +239,7 @@ class SegmentationHead(torch.nn.Module):
         logits: torch.Tensor,
         inputs: list[dict[str, Any]],
         targets: list[dict[str, Any]] | None = None,
-    ):
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
         """Compute the segmentation outputs from logits and targets.
 
         Args:
@@ -172,28 +252,51 @@ class SegmentationHead(torch.nn.Module):
         """
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
-        loss = None
+        losses = {}
         if targets:
             labels = torch.stack([target["classes"] for target in targets], dim=0)
             mask = torch.stack([target["valid"] for target in targets], dim=0)
-            loss = (
-                torch.nn.functional.cross_entropy(logits, labels, reduction="none")
-                * mask
+            per_pixel_loss = torch.nn.functional.cross_entropy(
+                logits, labels, reduction="none"
             )
-            loss = torch.mean(loss)
+            mask_sum = torch.sum(mask)
+            if mask_sum > 0:
+                # Compute average loss over valid pixels.
+                losses["cls"] = torch.sum(per_pixel_loss * mask) / torch.sum(mask)
+            else:
+                # If there are no valid pixels, we avoid dividing by zero and just let
+                # the summed mask loss be zero.
+                losses["cls"] = torch.sum(per_pixel_loss * mask)
 
-        return outputs, {"cls": loss}
+        return outputs, losses
 
 
 class SegmentationMetric(Metric):
     """Metric for segmentation task."""
 
-    def __init__(self, metric: Metric):
-        """Initialize a new SegmentationMetric."""
+    def __init__(
+        self,
+        metric: Metric,
+        pass_probabilities: bool = True,
+        class_idx: int | None = None,
+    ):
+        """Initialize a new SegmentationMetric.
+
+        Args:
+            metric: the metric to wrap. This wrapping class will handle selecting the
+                classes from the targets and masking out invalid pixels.
+            pass_probabilities: whether to pass predicted probabilities to the metric.
+                If False, argmax is applied to pass the predicted classes instead.
+            class_idx: if metric returns value for multiple classes, select this class.
+        """
         super().__init__()
         self.metric = metric
+        self.pass_probablities = pass_probabilities
+        self.class_idx = class_idx
 
-    def update(self, preds: list[Any], targets: list[dict[str, Any]]) -> None:
+    def update(
+        self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
+    ) -> None:
         """Update metric.
 
         Args:
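The loss change above replaces a plain mean over all pixels (which diluted the loss by however many pixels the mask zeroed out) with an average over valid pixels only, plus a guard for an all-invalid batch. A self-contained sketch of the same computation, with illustrative shapes:

    import torch

    logits = torch.randn(2, 3, 4, 4)         # (N, C, H, W)
    labels = torch.randint(0, 3, (2, 4, 4))  # (N, H, W)
    mask = torch.zeros(2, 4, 4)
    mask[0] = 1.0                            # only the first example has valid pixels

    per_pixel = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
    mask_sum = torch.sum(mask)
    if mask_sum > 0:
        loss = torch.sum(per_pixel * mask) / mask_sum  # mean over valid pixels only
    else:
        loss = torch.sum(per_pixel * mask)             # all-invalid batch: loss is zero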
@@ -213,11 +316,17 @@ class SegmentationMetric(Metric):
         if len(preds) == 0:
             return
 
+        if not self.pass_probablities:
+            preds = preds.argmax(dim=1)
+
         self.metric.update(preds, labels)
 
     def compute(self) -> Any:
         """Returns the computed metric."""
-        return self.metric.compute()
+        result = self.metric.compute()
+        if self.class_idx is not None:
+            result = result[self.class_idx]
+        return result
 
     def reset(self) -> None:
         """Reset metric."""
@@ -227,3 +336,212 @@ class SegmentationMetric(Metric):
     def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
         """Returns a plot of the metric."""
         return self.metric.plot(*args, **kwargs)
+
+
+class F1Metric(Metric):
+    """F1 score for segmentation.
+
+    It treats each class as a separate prediction task, and computes the maximum F1
+    score under the different configured thresholds per-class.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        score_thresholds: list[float],
+        metric_mode: str = "f1",
+    ):
+        """Create a new F1Metric.
+
+        Args:
+            num_classes: number of classes.
+            score_thresholds: list of score thresholds to check F1 score for. The final
+                metric is the best F1 across score thresholds.
+            metric_mode: set to "precision" or "recall" to return that instead of F1
+                (default "f1")
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.score_thresholds = score_thresholds
+        self.metric_mode = metric_mode
+
+        assert self.metric_mode in ["f1", "precision", "recall"]
+
+        for cls_idx in range(self.num_classes):
+            for thr_idx in range(len(self.score_thresholds)):
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                self.add_state(
+                    cur_prefix + "tp", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+                self.add_state(
+                    cur_prefix + "fp", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+                self.add_state(
+                    cur_prefix + "fn", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+
+    def _get_state_prefix(self, cls_idx: int, thr_idx: int) -> str:
+        return f"{cls_idx}_{thr_idx}_"
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Args:
+            preds: the predictions, NxC.
+            labels: the targets, N, with values from 0 to C-1.
+        """
+        for cls_idx in range(self.num_classes):
+            for thr_idx, score_threshold in enumerate(self.score_thresholds):
+                pred_bin = preds[:, cls_idx] > score_threshold
+                gt_bin = labels == cls_idx
+
+                tp = torch.count_nonzero(pred_bin & gt_bin).item()
+                fp = torch.count_nonzero(pred_bin & torch.logical_not(gt_bin)).item()
+                fn = torch.count_nonzero(torch.logical_not(pred_bin) & gt_bin).item()
+
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                setattr(self, cur_prefix + "tp", getattr(self, cur_prefix + "tp") + tp)
+                setattr(self, cur_prefix + "fp", getattr(self, cur_prefix + "fp") + fp)
+                setattr(self, cur_prefix + "fn", getattr(self, cur_prefix + "fn") + fn)
+
+    def compute(self) -> Any:
+        """Compute metric.
+
+        Returns:
+            the best F1 score across score thresholds and classes.
+        """
+        best_scores = []
+
+        for cls_idx in range(self.num_classes):
+            best_score = None
+
+            for thr_idx in range(len(self.score_thresholds)):
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                tp = getattr(self, cur_prefix + "tp")
+                fp = getattr(self, cur_prefix + "fp")
+                fn = getattr(self, cur_prefix + "fn")
+                device = tp.device
+
+                if tp + fp == 0:
+                    precision = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    precision = tp / (tp + fp)
+
+                if tp + fn == 0:
+                    recall = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    recall = tp / (tp + fn)
+
+                if precision + recall < 0.001:
+                    f1 = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    f1 = 2 * precision * recall / (precision + recall)
+
+                if self.metric_mode == "f1":
+                    score = f1
+                elif self.metric_mode == "precision":
+                    score = precision
+                elif self.metric_mode == "recall":
+                    score = recall
+
+                if best_score is None or score > best_score:
+                    best_score = score
+
+            best_scores.append(best_score)
+
+        return torch.mean(torch.stack(best_scores))
+
+
+class MeanIoUMetric(Metric):
+    """Mean IoU for segmentation.
+
+    This is the mean of the per-class intersection-over-union scores. The per-class
+    intersection is the number of pixels across all examples where the predicted label
+    and ground truth label are both that class, and the per-class union is defined
+    similarly.
+
+    This differs from torchmetrics.segmentation.MeanIoU, where the mean IoU is computed
+    per-image, and averaged across images.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        zero_is_invalid: bool = False,
+        ignore_missing_classes: bool = False,
+        class_idx: int | None = None,
+    ):
+        """Create a new MeanIoUMetric.
+
+        Args:
+            num_classes: the number of classes for the task.
+            zero_is_invalid: whether to ignore class 0 in computing mean IoU.
+            ignore_missing_classes: whether to ignore classes that don't appear in
+                either the predictions or the ground truth. If false, the IoU for a
+                missing class will be 0.
+            class_idx: only compute and return the IoU for this class. This option is
+                provided so the user can get per-class IoU results, since Lightning
+                only supports scalar return values from metrics.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.zero_is_invalid = zero_is_invalid
+        self.ignore_missing_classes = ignore_missing_classes
+        self.class_idx = class_idx
+
+        self.add_state(
+            "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
+        )
+        self.add_state(
+            "unions", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
+        )
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Like torchmetrics.segmentation.MeanIoU with input_format="index", we expect
+        predictions and labels to both be class integers. This is achieved by passing
+        pass_probabilities=False to the SegmentationMetric wrapper.
+
+        Args:
+            preds: the predictions, (N,), with values from 0 to C-1.
+            labels: the targets, (N,), with values from 0 to C-1.
+        """
+        if preds.min() < 0 or preds.max() >= self.num_classes:
+            raise ValueError("predicted class outside of expected range")
+        if labels.min() < 0 or labels.max() >= self.num_classes:
+            raise ValueError("label class outside of expected range")
+
+        new_intersections = torch.zeros(
+            self.num_classes, device=self.intersections.device
+        )
+        new_unions = torch.zeros(self.num_classes, device=self.unions.device)
+        for cls_idx in range(self.num_classes):
+            new_intersections[cls_idx] = (
+                (preds == cls_idx) & (labels == cls_idx)
+            ).sum()
+            new_unions[cls_idx] = ((preds == cls_idx) | (labels == cls_idx)).sum()
+        self.intersections += new_intersections
+        self.unions += new_unions
+
+    def compute(self) -> Any:
+        """Compute metric.
+
+        Returns:
+            the mean IoU across classes.
+        """
+        per_class_scores = []
+
+        for cls_idx in range(self.num_classes):
+            if cls_idx == 0 and self.zero_is_invalid:
+                continue
+
+            intersection = self.intersections[cls_idx]
+            union = self.unions[cls_idx]
+
+            if union == 0 and self.ignore_missing_classes:
+                continue
+
+            per_class_scores.append(intersection / union)
+
+        return torch.mean(torch.stack(per_class_scores))
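To make the dataset-level pooling concrete, a small worked example; the counts below follow the update()/compute() logic above, and the import path is inferred from the file list:

    import torch

    from rslearn.train.tasks.segmentation import MeanIoUMetric

    preds = torch.tensor([0, 0, 1, 1])
    labels = torch.tensor([0, 1, 1, 1])
    # class 0: intersection 1 (pixel 0), union 2 (pixels 0, 1) -> IoU 0.5
    # class 1: intersection 2 (pixels 2, 3), union 3 (pixels 1, 2, 3) -> IoU 2/3
    metric = MeanIoUMetric(num_classes=2)
    metric.update(preds, labels)
    print(metric.compute())  # tensor(0.5833)

Because intersections and unions keep accumulating across update() calls before the division, this is a dataset-level IoU rather than the per-image average computed by torchmetrics.segmentation.MeanIoU.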
rslearn/train/tasks/task.py
@@ -39,7 +39,7 @@ class Task:
 
     def process_output(
         self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+    ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
         """Processes an output into raster or vector data.
 
         Args:
@@ -47,7 +47,7 @@ class Task:
             metadata: metadata about the patch being read
 
         Returns:
-            either raster or vector data.
+            raster data, vector data, or multi-task dictionary output.
         """
         raise NotImplementedError
 
rslearn/train/transforms/transform.py
@@ -12,7 +12,7 @@ class Sequential(torch.nn.Module):
     tuple.
     """
 
-    def __init__(self, *args):
+    def __init__(self, *args: Any) -> None:
         """Initialize a new Sequential from a list of transforms."""
         super().__init__()
         self.transforms = torch.nn.ModuleList(args)
rslearn/train/transforms/concatenate.py
@@ -1,8 +1,10 @@
-"""Normalization transforms."""
+"""Concatenate bands across multiple image inputs."""
+
+from typing import Any
 
 import torch
 
-from .transform import Transform
+from .transform import Transform, read_selector, write_selector
 
 
 class Concatenate(Transform):
@@ -24,7 +26,9 @@ class Concatenate(Transform):
         self.selections = selections
         self.output_selector = output_selector
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply concatenation over the inputs and targets.
 
         Args:
@@ -36,10 +40,10 @@ class Concatenate(Transform):
         """
         images = []
         for selector, wanted_bands in self.selections.items():
-            image = self.read_selector(input_dict, target_dict, selector)
+            image = read_selector(input_dict, target_dict, selector)
             if wanted_bands:
                 image = image[wanted_bands, :, :]
            images.append(image)
         result = torch.concatenate(images, dim=0)
-        self.write_selector(input_dict, target_dict, self.output_selector, result)
+        write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
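A usage sketch for the transform as it now stands. The input keys and band indices are hypothetical, and the snippet assumes read_selector and write_selector treat plain keys as input_dict entries; an empty band list keeps every band of that input:

    import torch

    from rslearn.train.transforms.concatenate import Concatenate

    input_dict = {
        "sentinel2": torch.randn(13, 64, 64),  # hypothetical 13-band input
        "sentinel1": torch.randn(2, 64, 64),   # hypothetical 2-band input
    }
    target_dict: dict = {}

    concat = Concatenate(
        selections={
            "sentinel2": [2, 1, 0],  # pick three bands, in this order
            "sentinel1": [],         # empty list: keep every band
        },
        output_selector="image",
    )
    input_dict, target_dict = concat(input_dict, target_dict)
    # input_dict["image"] is now a (5, 64, 64) tensor.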
rslearn/train/transforms/crop.py
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 import torchvision
 
-from .transform import Transform
+from .transform import Transform, read_selector
 
 
 class Crop(Transform):
@@ -69,7 +69,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
+    def apply_image(self, image: torch.Tensor, state: dict[str, Any]) -> torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -97,7 +97,9 @@ class Crop(Transform):
         """
         raise NotImplementedError
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args:
@@ -109,13 +111,15 @@ class Crop(Transform):
         """
         smallest_image_shape = None
         for selector in self.image_selectors:
-            image = self.read_selector(input_dict, target_dict, selector)
+            image = read_selector(input_dict, target_dict, selector)
             if (
                 smallest_image_shape is None
                 or image.shape[-1] < smallest_image_shape[1]
             ):
                 smallest_image_shape = image.shape[-2:]
 
+        if smallest_image_shape is None:
+            raise ValueError("No image found to crop")
         state = self.sample_state(smallest_image_shape)
 
         self.apply_fn(
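The new guard turns a silent None into an explicit error. A minimal standalone sketch of the failure mode it catches, using the same shape comparison as the loop above:

    import torch

    images: list[torch.Tensor] = []  # as if no image_selectors matched anything
    smallest_image_shape = None
    for image in images:
        if smallest_image_shape is None or image.shape[-1] < smallest_image_shape[1]:
            smallest_image_shape = image.shape[-2:]
    if smallest_image_shape is None:
        # Previously sample_state(None) would presumably fail later with an
        # opaque error; now the misconfiguration is reported at the source.
        raise ValueError("No image found to crop")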
rslearn/train/transforms/flip.py
@@ -1,5 +1,7 @@
 """Flip transform."""
 
+from typing import Any
+
 import torch
 
 from .transform import Transform
@@ -90,7 +92,9 @@ class Flip(Transform):
         )
         return boxes
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args: