rslearn 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -2
- rslearn/config/dataset.py +164 -98
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +235 -80
- rslearn/data_sources/aws_open_data.py +103 -118
- rslearn/data_sources/aws_sentinel1.py +142 -0
- rslearn/data_sources/climate_data_store.py +303 -0
- rslearn/data_sources/copernicus.py +943 -12
- rslearn/data_sources/data_source.py +17 -12
- rslearn/data_sources/earthdaily.py +489 -0
- rslearn/data_sources/earthdata_srtm.py +300 -0
- rslearn/data_sources/gcp_public_data.py +556 -203
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +454 -115
- rslearn/data_sources/local_files.py +153 -103
- rslearn/data_sources/openstreetmap.py +33 -39
- rslearn/data_sources/planet.py +17 -35
- rslearn/data_sources/planet_basemap.py +296 -0
- rslearn/data_sources/planetary_computer.py +764 -0
- rslearn/data_sources/raster_source.py +11 -297
- rslearn/data_sources/usda_cdl.py +206 -0
- rslearn/data_sources/usgs_landsat.py +130 -73
- rslearn/data_sources/utils.py +256 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +456 -0
- rslearn/data_sources/worldcover.py +142 -0
- rslearn/data_sources/worldpop.py +156 -0
- rslearn/data_sources/xyz_tiles.py +141 -79
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +43 -7
- rslearn/dataset/index.py +173 -0
- rslearn/dataset/manage.py +137 -49
- rslearn/dataset/materialize.py +436 -95
- rslearn/dataset/window.py +225 -34
- rslearn/log_utils.py +24 -0
- rslearn/main.py +351 -130
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/croma.py +270 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +493 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/faster_rcnn.py +10 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +91 -0
- rslearn/models/moe/distributed.py +262 -0
- rslearn/models/moe/soft.py +676 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +351 -24
- rslearn/models/pick_features.py +15 -2
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +75 -59
- rslearn/models/singletask.py +8 -4
- rslearn/models/ssl4eo_s12.py +10 -10
- rslearn/models/swin.py +22 -21
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +219 -0
- rslearn/models/trunk.py +280 -0
- rslearn/models/unet.py +21 -5
- rslearn/models/upsample.py +35 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/py.typed +0 -0
- rslearn/tile_stores/__init__.py +52 -18
- rslearn/tile_stores/default.py +382 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/callbacks/freeze_unfreeze.py +32 -20
- rslearn/train/callbacks/gradients.py +109 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +407 -14
- rslearn/train/dataset.py +746 -200
- rslearn/train/lightning_module.py +164 -54
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +235 -78
- rslearn/train/scheduler.py +62 -0
- rslearn/train/tasks/classification.py +13 -12
- rslearn/train/tasks/detection.py +101 -39
- rslearn/train/tasks/multi_task.py +24 -9
- rslearn/train/tasks/regression.py +113 -21
- rslearn/train/tasks/segmentation.py +353 -35
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +9 -5
- rslearn/train/transforms/crop.py +8 -4
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +34 -10
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +75 -73
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +214 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +33 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +211 -96
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +235 -77
- {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/METADATA +366 -284
- rslearn-0.0.3.dist-info/RECORD +123 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/WHEEL +1 -1
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/segmentation.py
CHANGED

@@ -12,22 +12,23 @@ from rslearn.utils import Feature
 
 from .task import BasicTask
 
+# TODO: This is duplicated code fix it
 DEFAULT_COLORS = [
-… (15 removed lines, old lines 16-30; their content is not shown in this view)
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (0, 255, 255),
+    (255, 0, 255),
+    (0, 128, 0),
+    (255, 160, 122),
+    (139, 69, 19),
+    (128, 128, 128),
+    (255, 255, 255),
+    (143, 188, 143),
+    (95, 158, 160),
+    (255, 200, 0),
+    (128, 0, 0),
 ]
 
 
@@ -39,28 +40,58 @@ class SegmentationTask(BasicTask):
         num_classes: int,
         colors: list[tuple[int, int, int]] = DEFAULT_COLORS,
         zero_is_invalid: bool = False,
+        enable_accuracy_metric: bool = True,
+        enable_miou_metric: bool = False,
+        enable_f1_metric: bool = False,
+        f1_metric_thresholds: list[list[float]] = [[0.5]],
         metric_kwargs: dict[str, Any] = {},
-… (2 removed lines, old lines 43-44; their content is not shown in this view)
+        miou_metric_kwargs: dict[str, Any] = {},
+        prob_scales: list[float] | None = None,
+        other_metrics: dict[str, Metric] = {},
+        **kwargs: Any,
+    ) -> None:
         """Initialize a new SegmentationTask.
 
         Args:
             num_classes: the number of classes to predict
             colors: optional colors for each class
             zero_is_invalid: whether pixels labeled class 0 should be marked invalid
+            enable_accuracy_metric: whether to enable the accuracy metric (default
+                true).
+            enable_f1_metric: whether to enable the F1 metric (default false).
+            enable_miou_metric: whether to enable the mean IoU metric (default false).
+            f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
+                Each inner list is used to initialize a separate F1 metric where the
+                best F1 across the thresholds within the inner list is computed. If
+                there are multiple inner lists, then multiple F1 scores will be
+                reported.
             metric_kwargs: additional arguments to pass to underlying metric, see
                 torchmetrics.classification.MulticlassAccuracy.
+            miou_metric_kwargs: additional arguments to pass to MeanIoUMetric, if
+                enable_miou_metric is passed.
+            prob_scales: during inference, scale the output probabilities by this much
+                before computing the argmax. There is one scale per class. Note that
+                this is only applied during prediction, not when computing val or test
+                metrics.
+            other_metrics: additional metrics to configure on this task.
             kwargs: additional arguments to pass to BasicTask
         """
         super().__init__(**kwargs)
         self.num_classes = num_classes
         self.colors = colors
         self.zero_is_invalid = zero_is_invalid
+        self.enable_accuracy_metric = enable_accuracy_metric
+        self.enable_f1_metric = enable_f1_metric
+        self.enable_miou_metric = enable_miou_metric
+        self.f1_metric_thresholds = f1_metric_thresholds
         self.metric_kwargs = metric_kwargs
+        self.miou_metric_kwargs = miou_metric_kwargs
+        self.prob_scales = prob_scales
+        self.other_metrics = other_metrics
 
     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor
+        raw_inputs: dict[str, torch.Tensor],
         metadata: dict[str, Any],
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
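The hunk above adds several opt-in metric and inference options to the SegmentationTask constructor. Here is a minimal sketch of how they might be combined; the argument values are invented for illustration, and it assumes the BasicTask defaults are sufficient for the remaining keyword arguments.

```python
from rslearn.train.tasks.segmentation import SegmentationTask

# Hypothetical configuration for a 3-class task; values chosen for illustration.
task = SegmentationTask(
    num_classes=3,
    zero_is_invalid=True,
    enable_accuracy_metric=True,
    enable_miou_metric=True,                  # reported as "mean_iou"
    enable_f1_metric=True,
    f1_metric_thresholds=[[0.3, 0.5, 0.7]],   # best F1 across these thresholds
    prob_scales=[1.0, 1.0, 2.0],              # boost class 2 at prediction time only
)
metrics = task.get_metrics()                  # a torchmetrics MetricCollection
```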
@@ -78,6 +109,7 @@ class SegmentationTask(BasicTask):
         if not load_targets:
             return {}, {}
 
+        # TODO: List[Feature] is currently not supported
         assert raw_inputs["targets"].shape[0] == 1
         labels = raw_inputs["targets"][0, :, :].long()
 
@@ -103,7 +135,11 @@ class SegmentationTask(BasicTask):
         Returns:
             either raster or vector data.
         """
-… (removed line, old line 106; its content is not shown in this view)
+        raw_output_np = raw_output.cpu().numpy()
+        if self.prob_scales is not None:
+            # Scale the channel dimension by the provided scales.
+            raw_output_np = raw_output_np * np.array(self.prob_scales)[:, None, None]
+        classes = raw_output_np.argmax(axis=0).astype(np.uint8)
         return classes[None, :, :]
 
     def visualize(
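To make the prob_scales behavior in process_output concrete, here is a toy calculation with invented values: scaling a class's probability before the argmax can flip the predicted class for borderline pixels.

```python
import numpy as np

probs = np.array([[[0.55]], [[0.45]]])            # C x H x W = 2 x 1 x 1 softmax output
scales = [1.0, 1.5]                               # hypothetical per-class scales
scaled = probs * np.array(scales)[:, None, None]  # class 1 becomes 0.675
print(probs.argmax(axis=0)[0, 0])                 # 0 (unscaled winner)
print(scaled.argmax(axis=0)[0, 0])                # 1 (scaled winner)
```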
@@ -123,6 +159,8 @@ class SegmentationTask(BasicTask):
             a dictionary mapping image name to visualization image
         """
         image = super().visualize(input_dict, target_dict, output)["image"]
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         gt_classes = target_dict["classes"].cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
@@ -143,11 +181,53 @@ class SegmentationTask(BasicTask):
     def get_metrics(self) -> MetricCollection:
         """Get the metrics for this task."""
         metrics = {}
-… (5 removed lines, old lines 146-150; their content is not shown in this view)
+
+        if self.enable_accuracy_metric:
+            accuracy_metric_kwargs = dict(num_classes=self.num_classes)
+            accuracy_metric_kwargs.update(self.metric_kwargs)
+            metrics["accuracy"] = SegmentationMetric(
+                torchmetrics.classification.MulticlassAccuracy(**accuracy_metric_kwargs)
+            )
+
+        if self.enable_f1_metric:
+            for thresholds in self.f1_metric_thresholds:
+                if len(self.f1_metric_thresholds) == 1:
+                    suffix = ""
+                else:
+                    # Metric name can't contain "." so change to ",".
+                    suffix = "_" + str(thresholds[0]).replace(".", ",")
+
+                metrics["F1" + suffix] = SegmentationMetric(
+                    F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
+                )
+                metrics["precision" + suffix] = SegmentationMetric(
+                    F1Metric(
+                        num_classes=self.num_classes,
+                        score_thresholds=thresholds,
+                        metric_mode="precision",
+                    )
+                )
+                metrics["recall" + suffix] = SegmentationMetric(
+                    F1Metric(
+                        num_classes=self.num_classes,
+                        score_thresholds=thresholds,
+                        metric_mode="recall",
+                    )
+                )
+
+        if self.enable_miou_metric:
+            miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
+            if self.zero_is_invalid:
+                miou_metric_kwargs["zero_is_invalid"] = True
+            miou_metric_kwargs.update(self.miou_metric_kwargs)
+            metrics["mean_iou"] = SegmentationMetric(
+                MeanIoUMetric(**miou_metric_kwargs),
+                pass_probabilities=False,
+            )
+
+        if self.other_metrics:
+            metrics.update(self.other_metrics)
+
         return MetricCollection(metrics)
 
 
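One detail worth noting in get_metrics above: when more than one inner threshold list is configured, each group of F1/precision/recall metrics is suffixed with its first threshold, with "." replaced by "," because metric names cannot contain ".". A small sketch (threshold values invented):

```python
from rslearn.train.tasks.segmentation import SegmentationTask

task = SegmentationTask(
    num_classes=2,
    enable_f1_metric=True,
    f1_metric_thresholds=[[0.5], [0.9]],
)
print(sorted(task.get_metrics().keys()))
# Expected keys include: F1_0,5  F1_0,9  precision_0,5  precision_0,9
# recall_0,5  recall_0,9  (plus "accuracy", which is enabled by default)
```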
@@ -159,7 +239,7 @@ class SegmentationHead(torch.nn.Module):
         logits: torch.Tensor,
         inputs: list[dict[str, Any]],
         targets: list[dict[str, Any]] | None = None,
-    ):
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
         """Compute the segmentation outputs from logits and targets.
 
         Args:
@@ -172,28 +252,51 @@ class SegmentationHead(torch.nn.Module):
         """
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
-… (removed line, old line 175; its content is not shown in this view)
+        losses = {}
         if targets:
             labels = torch.stack([target["classes"] for target in targets], dim=0)
             mask = torch.stack([target["valid"] for target in targets], dim=0)
-… (2 removed lines, old lines 179-180; their content is not shown in this view)
-                * mask
+            per_pixel_loss = torch.nn.functional.cross_entropy(
+                logits, labels, reduction="none"
             )
-… (removed line, old line 183; its content is not shown in this view)
+            mask_sum = torch.sum(mask)
+            if mask_sum > 0:
+                # Compute average loss over valid pixels.
+                losses["cls"] = torch.sum(per_pixel_loss * mask) / torch.sum(mask)
+            else:
+                # If there are no valid pixels, we avoid dividing by zero and just let
+                # the summed mask loss be zero.
+                losses["cls"] = torch.sum(per_pixel_loss * mask)
 
-        return outputs,
+        return outputs, losses
 
 
 class SegmentationMetric(Metric):
     """Metric for segmentation task."""
 
-    def __init__(
-… (removed line, old line 192; its content is not shown in this view)
+    def __init__(
+        self,
+        metric: Metric,
+        pass_probabilities: bool = True,
+        class_idx: int | None = None,
+    ):
+        """Initialize a new SegmentationMetric.
+
+        Args:
+            metric: the metric to wrap. This wrapping class will handle selecting the
+                classes from the targets and masking out invalid pixels.
+            pass_probabilities: whether to pass predicted probabilities to the metric.
+                If False, argmax is applied to pass the predicted classes instead.
+            class_idx: if metric returns value for multiple classes, select this class.
+        """
         super().__init__()
         self.metric = metric
+        self.pass_probablities = pass_probabilities
+        self.class_idx = class_idx
 
-    def update(
+    def update(
+        self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
+    ) -> None:
         """Update metric.
 
         Args:
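The reworked SegmentationHead.forward above averages the per-pixel cross-entropy only over valid pixels, and falls back to a zero-valued sum when no pixel is valid. A standalone sketch of that masking logic, with invented shapes:

```python
import torch

logits = torch.randn(1, 3, 4, 4)           # N x C x H x W
labels = torch.randint(0, 3, (1, 4, 4))    # N x H x W class indices
mask = torch.zeros(1, 4, 4)
mask[:, :2, :] = 1                          # only the top half is valid

per_pixel_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
if mask.sum() > 0:
    loss = (per_pixel_loss * mask).sum() / mask.sum()  # mean over valid pixels only
else:
    loss = (per_pixel_loss * mask).sum()               # zero when nothing is valid
print(loss)
```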
@@ -213,11 +316,17 @@ class SegmentationMetric(Metric):
         if len(preds) == 0:
             return
 
+        if not self.pass_probablities:
+            preds = preds.argmax(dim=1)
+
         self.metric.update(preds, labels)
 
     def compute(self) -> Any:
         """Returns the computed metric."""
-… (removed line, old line 220; its content is not shown in this view)
+        result = self.metric.compute()
+        if self.class_idx is not None:
+            result = result[self.class_idx]
+        return result
 
     def reset(self) -> None:
         """Reset metric."""
@@ -227,3 +336,212 @@ class SegmentationMetric(Metric):
     def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
         """Returns a plot of the metric."""
         return self.metric.plot(*args, **kwargs)
+
+
+class F1Metric(Metric):
+    """F1 score for segmentation.
+
+    It treats each class as a separate prediction task, and computes the maximum F1
+    score under the different configured thresholds per-class.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        score_thresholds: list[float],
+        metric_mode: str = "f1",
+    ):
+        """Create a new F1Metric.
+
+        Args:
+            num_classes: number of classes.
+            score_thresholds: list of score thresholds to check F1 score for. The final
+                metric is the best F1 across score thresholds.
+            metric_mode: set to "precision" or "recall" to return that instead of F1
+                (default "f1")
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.score_thresholds = score_thresholds
+        self.metric_mode = metric_mode
+
+        assert self.metric_mode in ["f1", "precision", "recall"]
+
+        for cls_idx in range(self.num_classes):
+            for thr_idx in range(len(self.score_thresholds)):
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                self.add_state(
+                    cur_prefix + "tp", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+                self.add_state(
+                    cur_prefix + "fp", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+                self.add_state(
+                    cur_prefix + "fn", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+
+    def _get_state_prefix(self, cls_idx: int, thr_idx: int) -> str:
+        return f"{cls_idx}_{thr_idx}_"
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Args:
+            preds: the predictions, NxC.
+            labels: the targets, N, with values from 0 to C-1.
+        """
+        for cls_idx in range(self.num_classes):
+            for thr_idx, score_threshold in enumerate(self.score_thresholds):
+                pred_bin = preds[:, cls_idx] > score_threshold
+                gt_bin = labels == cls_idx
+
+                tp = torch.count_nonzero(pred_bin & gt_bin).item()
+                fp = torch.count_nonzero(pred_bin & torch.logical_not(gt_bin)).item()
+                fn = torch.count_nonzero(torch.logical_not(pred_bin) & gt_bin).item()
+
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                setattr(self, cur_prefix + "tp", getattr(self, cur_prefix + "tp") + tp)
+                setattr(self, cur_prefix + "fp", getattr(self, cur_prefix + "fp") + fp)
+                setattr(self, cur_prefix + "fn", getattr(self, cur_prefix + "fn") + fn)
+
+    def compute(self) -> Any:
+        """Compute metric.
+
+        Returns:
+            the best F1 score across score thresholds and classes.
+        """
+        best_scores = []
+
+        for cls_idx in range(self.num_classes):
+            best_score = None
+
+            for thr_idx in range(len(self.score_thresholds)):
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                tp = getattr(self, cur_prefix + "tp")
+                fp = getattr(self, cur_prefix + "fp")
+                fn = getattr(self, cur_prefix + "fn")
+                device = tp.device
+
+                if tp + fp == 0:
+                    precision = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    precision = tp / (tp + fp)
+
+                if tp + fn == 0:
+                    recall = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    recall = tp / (tp + fn)
+
+                if precision + recall < 0.001:
+                    f1 = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    f1 = 2 * precision * recall / (precision + recall)
+
+                if self.metric_mode == "f1":
+                    score = f1
+                elif self.metric_mode == "precision":
+                    score = precision
+                elif self.metric_mode == "recall":
+                    score = recall
+
+                if best_score is None or score > best_score:
+                    best_score = score
+
+            best_scores.append(best_score)
+
+        return torch.mean(torch.stack(best_scores))
+
+
+class MeanIoUMetric(Metric):
+    """Mean IoU for segmentation.
+
+    This is the mean of the per-class intersection-over-union scores. The per-class
+    intersection is the number of pixels across all examples where the predicted label
+    and ground truth label are both that class, and the per-class union is defined
+    similarly.
+
+    This differs from torchmetrics.segmentation.MeanIoU, where the mean IoU is computed
+    per-image, and averaged across images.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        zero_is_invalid: bool = False,
+        ignore_missing_classes: bool = False,
+        class_idx: int | None = None,
+    ):
+        """Create a new MeanIoUMetric.
+
+        Args:
+            num_classes: the number of classes for the task.
+            zero_is_invalid: whether to ignore class 0 in computing mean IoU.
+            ignore_missing_classes: whether to ignore classes that don't appear in
+                either the predictions or the ground truth. If false, the IoU for a
+                missing class will be 0.
+            class_idx: only compute and return the IoU for this class. This option is
+                provided so the user can get per-class IoU results, since Lightning
+                only supports scalar return values from metrics.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.zero_is_invalid = zero_is_invalid
+        self.ignore_missing_classes = ignore_missing_classes
+        self.class_idx = class_idx
+
+        self.add_state(
+            "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
+        )
+        self.add_state(
+            "unions", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
+        )
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Like torchmetrics.segmentation.MeanIoU with input_format="index", we expect
+        predictions and labels to both be class integers. This is achieved by passing
+        pass_probabilities=False to the SegmentationMetric wrapper.
+
+        Args:
+            preds: the predictions, (N,), with values from 0 to C-1.
+            labels: the targets, (N,), with values from 0 to C-1.
+        """
+        if preds.min() < 0 or preds.max() >= self.num_classes:
+            raise ValueError("predicted class outside of expected range")
+        if labels.min() < 0 or labels.max() >= self.num_classes:
+            raise ValueError("label class outside of expected range")
+
+        new_intersections = torch.zeros(
+            self.num_classes, device=self.intersections.device
+        )
+        new_unions = torch.zeros(self.num_classes, device=self.unions.device)
+        for cls_idx in range(self.num_classes):
+            new_intersections[cls_idx] = (
+                (preds == cls_idx) & (labels == cls_idx)
+            ).sum()
+            new_unions[cls_idx] = ((preds == cls_idx) | (labels == cls_idx)).sum()
+        self.intersections += new_intersections
+        self.unions += new_unions
+
+    def compute(self) -> Any:
+        """Compute metric.
+
+        Returns:
+            the mean IoU across classes.
+        """
+        per_class_scores = []
+
+        for cls_idx in range(self.num_classes):
+            if cls_idx == 0 and self.zero_is_invalid:
+                continue
+
+            intersection = self.intersections[cls_idx]
+            union = self.unions[cls_idx]
+
+            if union == 0 and self.ignore_missing_classes:
+                continue
+
+            per_class_scores.append(intersection / union)
+
+        return torch.mean(torch.stack(per_class_scores))
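A short sketch of how the two new metric classes consume data (toy values; both classes are defined in rslearn/train/tasks/segmentation.py as shown above). F1Metric sees per-pixel class probabilities, while MeanIoUMetric sees hard class indices, which is why get_metrics wraps the latter with pass_probabilities=False.

```python
import torch

from rslearn.train.tasks.segmentation import F1Metric, MeanIoUMetric

preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # N=3 pixels, C=2 probabilities
labels = torch.tensor([0, 1, 1])

f1 = F1Metric(num_classes=2, score_thresholds=[0.5])
f1.update(preds, labels)
print(f1.compute())    # mean over classes of the best per-class F1

miou = MeanIoUMetric(num_classes=2)
miou.update(preds.argmax(dim=1), labels)
print(miou.compute())  # mean of per-class intersection / union
```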
rslearn/train/tasks/task.py
CHANGED

@@ -39,7 +39,7 @@ class Task:
 
     def process_output(
         self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+    ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
         """Processes an output into raster or vector data.
 
         Args:
@@ -47,7 +47,7 @@ class Task:
             metadata: metadata about the patch being read
 
         Returns:
-… (removed line, old line 50; its content is not shown in this view)
+            raster data, vector data, or multi-task dictionary output.
         """
         raise NotImplementedError
 
rslearn/train/transforms/transform.py
CHANGED

@@ -12,7 +12,7 @@ class Sequential(torch.nn.Module):
     tuple.
     """
 
-    def __init__(self, *args):
+    def __init__(self, *args: Any) -> None:
         """Initialize a new Sequential from a list of transforms."""
         super().__init__()
         self.transforms = torch.nn.ModuleList(args)
rslearn/train/transforms/concatenate.py
CHANGED

@@ -1,8 +1,10 @@
-"""
+"""Concatenate bands across multiple image inputs."""
+
+from typing import Any
 
 import torch
 
-from .transform import Transform
+from .transform import Transform, read_selector, write_selector
 
 
 class Concatenate(Transform):
@@ -24,7 +26,9 @@ class Concatenate(Transform):
         self.selections = selections
         self.output_selector = output_selector
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply concatenation over the inputs and targets.
 
         Args:
@@ -36,10 +40,10 @@ class Concatenate(Transform):
         """
         images = []
         for selector, wanted_bands in self.selections.items():
-            image =
+            image = read_selector(input_dict, target_dict, selector)
             if wanted_bands:
                 image = image[wanted_bands, :, :]
             images.append(image)
         result = torch.concatenate(images, dim=0)
-… (removed line, old line 44; its content is not shown in this view)
+        write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
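A minimal usage sketch for the Concatenate transform above. The constructor arguments mirror the attributes set in __init__, and the selector names ("sentinel2", "sentinel1", "image") are invented for illustration rather than taken from the rslearn documentation.

```python
import torch

from rslearn.train.transforms.concatenate import Concatenate

# Hypothetical selectors; an empty band list keeps all bands of that input.
concatenate = Concatenate(
    selections={"sentinel2": [], "sentinel1": [0, 1]},
    output_selector="image",
)
input_dict = {
    "sentinel2": torch.zeros(4, 32, 32),
    "sentinel1": torch.zeros(2, 32, 32),
}
input_dict, target_dict = concatenate(input_dict, {})
print(input_dict["image"].shape)  # 4 + 2 selected bands -> torch.Size([6, 32, 32])
```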
rslearn/train/transforms/crop.py
CHANGED

@@ -5,7 +5,7 @@ from typing import Any
 import torch
 import torchvision
 
-from .transform import Transform
+from .transform import Transform, read_selector
 
 
 class Crop(Transform):
@@ -69,7 +69,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(self, image: torch.Tensor, state: dict[str,
+    def apply_image(self, image: torch.Tensor, state: dict[str, Any]) -> torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -97,7 +97,9 @@ class Crop(Transform):
         """
         raise NotImplementedError
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args:
@@ -109,13 +111,15 @@ class Crop(Transform):
         """
         smallest_image_shape = None
         for selector in self.image_selectors:
-            image =
+            image = read_selector(input_dict, target_dict, selector)
             if (
                 smallest_image_shape is None
                 or image.shape[-1] < smallest_image_shape[1]
             ):
                 smallest_image_shape = image.shape[-2:]
 
+        if smallest_image_shape is None:
+            raise ValueError("No image found to crop")
         state = self.sample_state(smallest_image_shape)
 
         self.apply_fn(
rslearn/train/transforms/flip.py
CHANGED

@@ -1,5 +1,7 @@
 """Flip transform."""
 
+from typing import Any
+
 import torch
 
 from .transform import Transform
@@ -90,7 +92,9 @@ class Flip(Transform):
         )
         return boxes
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args: