rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/segmentation.py
CHANGED

@@ -1,5 +1,6 @@
 """Segmentation task."""
 
+from collections.abc import Mapping
 from typing import Any
 
 import numpy as np
@@ -8,26 +9,34 @@ import torch
 import torchmetrics.classification
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
 from rslearn.utils import Feature
 
 from .task import BasicTask
 
+# TODO: This is duplicated code fix it
 DEFAULT_COLORS = [
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (0, 255, 255),
+    (255, 0, 255),
+    (0, 128, 0),
+    (255, 160, 122),
+    (139, 69, 19),
+    (128, 128, 128),
+    (255, 255, 255),
+    (143, 188, 143),
+    (95, 158, 160),
+    (255, 200, 0),
+    (128, 0, 0),
 ]
 
 
@@ -37,31 +46,77 @@ class SegmentationTask(BasicTask):
     def __init__(
         self,
         num_classes: int,
+        class_id_mapping: dict[int, int] | None = None,
         colors: list[tuple[int, int, int]] = DEFAULT_COLORS,
         zero_is_invalid: bool = False,
+        nodata_value: int | None = None,
+        enable_accuracy_metric: bool = True,
+        enable_miou_metric: bool = False,
+        enable_f1_metric: bool = False,
+        f1_metric_thresholds: list[list[float]] = [[0.5]],
         metric_kwargs: dict[str, Any] = {},
-
-
+        miou_metric_kwargs: dict[str, Any] = {},
+        prob_scales: list[float] | None = None,
+        other_metrics: dict[str, Metric] = {},
+        **kwargs: Any,
+    ) -> None:
         """Initialize a new SegmentationTask.
 
         Args:
             num_classes: the number of classes to predict
             colors: optional colors for each class
             zero_is_invalid: whether pixels labeled class 0 should be marked invalid
+                Mutually exclusive with nodata_value.
+            nodata_value: the value to use for nodata pixels. If None, all pixels are
+                considered valid. Mutually exclusive with zero_is_invalid.
+            class_id_mapping: optional mapping from original class IDs to new class IDs.
+                If provided, class labels will be remapped according to this dictionary.
+            enable_accuracy_metric: whether to enable the accuracy metric (default
+                true).
+            enable_f1_metric: whether to enable the F1 metric (default false).
+            enable_miou_metric: whether to enable the mean IoU metric (default false).
+            f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
+                Each inner list is used to initialize a separate F1 metric where the
+                best F1 across the thresholds within the inner list is computed. If
+                there are multiple inner lists, then multiple F1 scores will be
+                reported.
             metric_kwargs: additional arguments to pass to underlying metric, see
                 torchmetrics.classification.MulticlassAccuracy.
+            miou_metric_kwargs: additional arguments to pass to MeanIoUMetric, if
+                enable_miou_metric is passed.
+            prob_scales: during inference, scale the output probabilities by this much
+                before computing the argmax. There is one scale per class. Note that
+                this is only applied during prediction, not when computing val or test
+                metrics.
+            other_metrics: additional metrics to configure on this task.
            kwargs: additional arguments to pass to BasicTask
        """
        super().__init__(**kwargs)
        self.num_classes = num_classes
+        self.class_id_mapping = class_id_mapping
        self.colors = colors
-        self.
+        self.nodata_value: int | None
+
+        if zero_is_invalid and nodata_value is not None:
+            raise ValueError("zero_is_invalid and nodata_value cannot both be set")
+        if zero_is_invalid:
+            self.nodata_value = 0
+        else:
+            self.nodata_value = nodata_value
+
+        self.enable_accuracy_metric = enable_accuracy_metric
+        self.enable_f1_metric = enable_f1_metric
+        self.enable_miou_metric = enable_miou_metric
+        self.f1_metric_thresholds = f1_metric_thresholds
        self.metric_kwargs = metric_kwargs
+        self.miou_metric_kwargs = miou_metric_kwargs
+        self.prob_scales = prob_scales
+        self.other_metrics = other_metrics
 
     def process_inputs(
         self,
-        raw_inputs:
-        metadata:
+        raw_inputs: Mapping[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -78,11 +133,22 @@ class SegmentationTask(BasicTask):
         if not load_targets:
             return {}, {}
 
-        assert raw_inputs["targets"]
-
-
-
-
+        assert isinstance(raw_inputs["targets"], RasterImage)
+        assert raw_inputs["targets"].image.shape[0] == 1
+        assert raw_inputs["targets"].image.shape[1] == 1
+        labels = raw_inputs["targets"].image[0, 0, :, :].long()
+
+        if self.class_id_mapping is not None:
+            new_labels = labels.clone()
+            for old_id, new_id in self.class_id_mapping.items():
+                new_labels[labels == old_id] = new_id
+            labels = new_labels
+
+        if self.nodata_value is not None:
+            valid = (labels != self.nodata_value).float()
+            # Labels, even masked ones, must be in the range 0 to num_classes-1
+            if self.nodata_value >= self.num_classes:
+                labels[labels == self.nodata_value] = 0
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
@@ -92,18 +158,28 @@ class SegmentationTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata:
-    ) -> npt.NDArray[Any]
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> npt.NDArray[Any]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be a CHW tensor.
             metadata: metadata about the patch being read
 
         Returns:
-
+            CHW numpy array with one channel, containing the predicted class IDs.
         """
-
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
+            raise ValueError("the output for SegmentationTask must be a CHW tensor")
+
+        if self.prob_scales is not None:
+            raw_output = (
+                raw_output
+                * torch.tensor(
+                    self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
+                )[:, None, None]
+            )
+        classes = raw_output.argmax(dim=0).cpu().numpy()
         return classes[None, :, :]
 
     def visualize(
@@ -123,6 +199,8 @@ class SegmentationTask(BasicTask):
             a dictionary mapping image name to visualization image
         """
         image = super().visualize(input_dict, target_dict, output)["image"]
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         gt_classes = target_dict["classes"].cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
@@ -143,57 +221,136 @@ class SegmentationTask(BasicTask):
     def get_metrics(self) -> MetricCollection:
         """Get the metrics for this task."""
         metrics = {}
-
-
-
-
-
+
+        if self.enable_accuracy_metric:
+            accuracy_metric_kwargs = dict(num_classes=self.num_classes)
+            accuracy_metric_kwargs.update(self.metric_kwargs)
+            metrics["accuracy"] = SegmentationMetric(
+                torchmetrics.classification.MulticlassAccuracy(**accuracy_metric_kwargs)
+            )
+
+        if self.enable_f1_metric:
+            for thresholds in self.f1_metric_thresholds:
+                if len(self.f1_metric_thresholds) == 1:
+                    suffix = ""
+                else:
+                    # Metric name can't contain "." so change to ",".
+                    suffix = "_" + str(thresholds[0]).replace(".", ",")
+
+                metrics["F1" + suffix] = SegmentationMetric(
+                    F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
+                )
+                metrics["precision" + suffix] = SegmentationMetric(
+                    F1Metric(
+                        num_classes=self.num_classes,
+                        score_thresholds=thresholds,
+                        metric_mode="precision",
+                    )
+                )
+                metrics["recall" + suffix] = SegmentationMetric(
+                    F1Metric(
+                        num_classes=self.num_classes,
+                        score_thresholds=thresholds,
+                        metric_mode="recall",
+                    )
+                )
+
+        if self.enable_miou_metric:
+            miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
+            if self.nodata_value is not None:
+                miou_metric_kwargs["nodata_value"] = self.nodata_value
+            miou_metric_kwargs.update(self.miou_metric_kwargs)
+            metrics["mean_iou"] = SegmentationMetric(
+                MeanIoUMetric(**miou_metric_kwargs),
+                pass_probabilities=False,
+            )
+
+        if self.other_metrics:
+            metrics.update(self.other_metrics)
+
         return MetricCollection(metrics)
 
 
-class SegmentationHead(
+class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ):
+    ) -> ModelOutput:
         """Compute the segmentation outputs from logits and targets.
 
         Args:
-
-
-
+            intermediates: a FeatureMaps with a single feature map containing the
+                segmentation logits.
+            context: the model context
+            targets: list of target dicts, where each target dict must contain a key
+                "classes" containing the per-pixel class labels, along with "valid"
+                containing a mask indicating where the example is valid.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to SegmentationHead must be a FeatureMaps")
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                f"input to SegmentationHead must have one feature map, but got {len(intermediates.feature_maps)}"
+            )
+
+        logits = intermediates.feature_maps[0]
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
-
+        losses = {}
         if targets:
             labels = torch.stack([target["classes"] for target in targets], dim=0)
             mask = torch.stack([target["valid"] for target in targets], dim=0)
-
-
-                * mask
+            per_pixel_loss = torch.nn.functional.cross_entropy(
+                logits, labels, reduction="none"
             )
-
-
-
+            mask_sum = torch.sum(mask)
+            if mask_sum > 0:
+                # Compute average loss over valid pixels.
+                losses["cls"] = torch.sum(per_pixel_loss * mask) / torch.sum(mask)
+            else:
+                # If there are no valid pixels, we avoid dividing by zero and just let
+                # the summed mask loss be zero.
+                losses["cls"] = torch.sum(per_pixel_loss * mask)
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class SegmentationMetric(Metric):
     """Metric for segmentation task."""
 
-    def __init__(
-
+    def __init__(
+        self,
+        metric: Metric,
+        pass_probabilities: bool = True,
+        class_idx: int | None = None,
+    ):
+        """Initialize a new SegmentationMetric.
+
+        Args:
+            metric: the metric to wrap. This wrapping class will handle selecting the
+                classes from the targets and masking out invalid pixels.
+            pass_probabilities: whether to pass predicted probabilities to the metric.
+                If False, argmax is applied to pass the predicted classes instead.
+            class_idx: if metric returns value for multiple classes, select this class.
+        """
        super().__init__()
        self.metric = metric
+        self.pass_probablities = pass_probabilities
+        self.class_idx = class_idx
 
-    def update(
+    def update(
+        self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
+    ) -> None:
         """Update metric.
 
         Args:
@@ -213,11 +370,17 @@ class SegmentationMetric(Metric):
         if len(preds) == 0:
             return
 
+        if not self.pass_probablities:
+            preds = preds.argmax(dim=1)
+
         self.metric.update(preds, labels)
 
     def compute(self) -> Any:
         """Returns the computed metric."""
-
+        result = self.metric.compute()
+        if self.class_idx is not None:
+            result = result[self.class_idx]
+        return result
 
     def reset(self) -> None:
         """Reset metric."""
@@ -227,3 +390,215 @@ class SegmentationMetric(Metric):
     def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
         """Returns a plot of the metric."""
         return self.metric.plot(*args, **kwargs)
+
+
+class F1Metric(Metric):
+    """F1 score for segmentation.
+
+    It treats each class as a separate prediction task, and computes the maximum F1
+    score under the different configured thresholds per-class.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        score_thresholds: list[float],
+        metric_mode: str = "f1",
+    ):
+        """Create a new F1Metric.
+
+        Args:
+            num_classes: number of classes.
+            score_thresholds: list of score thresholds to check F1 score for. The final
+                metric is the best F1 across score thresholds.
+            metric_mode: set to "precision" or "recall" to return that instead of F1
+                (default "f1")
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.score_thresholds = score_thresholds
+        self.metric_mode = metric_mode
+
+        assert self.metric_mode in ["f1", "precision", "recall"]
+
+        for cls_idx in range(self.num_classes):
+            for thr_idx in range(len(self.score_thresholds)):
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                self.add_state(
+                    cur_prefix + "tp", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+                self.add_state(
+                    cur_prefix + "fp", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+                self.add_state(
+                    cur_prefix + "fn", default=torch.tensor(0), dist_reduce_fx="sum"
+                )
+
+    def _get_state_prefix(self, cls_idx: int, thr_idx: int) -> str:
+        return f"{cls_idx}_{thr_idx}_"
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Args:
+            preds: the predictions, NxC.
+            labels: the targets, N, with values from 0 to C-1.
+        """
+        for cls_idx in range(self.num_classes):
+            for thr_idx, score_threshold in enumerate(self.score_thresholds):
+                pred_bin = preds[:, cls_idx] > score_threshold
+                gt_bin = labels == cls_idx
+
+                tp = torch.count_nonzero(pred_bin & gt_bin).item()
+                fp = torch.count_nonzero(pred_bin & torch.logical_not(gt_bin)).item()
+                fn = torch.count_nonzero(torch.logical_not(pred_bin) & gt_bin).item()
+
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                setattr(self, cur_prefix + "tp", getattr(self, cur_prefix + "tp") + tp)
+                setattr(self, cur_prefix + "fp", getattr(self, cur_prefix + "fp") + fp)
+                setattr(self, cur_prefix + "fn", getattr(self, cur_prefix + "fn") + fn)
+
+    def compute(self) -> Any:
+        """Compute metric.
+
+        Returns:
+            the best F1 score across score thresholds and classes.
+        """
+        best_scores = []
+
+        for cls_idx in range(self.num_classes):
+            best_score = None
+
+            for thr_idx in range(len(self.score_thresholds)):
+                cur_prefix = self._get_state_prefix(cls_idx, thr_idx)
+                tp = getattr(self, cur_prefix + "tp")
+                fp = getattr(self, cur_prefix + "fp")
+                fn = getattr(self, cur_prefix + "fn")
+                device = tp.device
+
+                if tp + fp == 0:
+                    precision = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    precision = tp / (tp + fp)
+
+                if tp + fn == 0:
+                    recall = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    recall = tp / (tp + fn)
+
+                if precision + recall < 0.001:
+                    f1 = torch.tensor(0, dtype=torch.float32, device=device)
+                else:
+                    f1 = 2 * precision * recall / (precision + recall)
+
+                if self.metric_mode == "f1":
+                    score = f1
+                elif self.metric_mode == "precision":
+                    score = precision
+                elif self.metric_mode == "recall":
+                    score = recall
+
+                if best_score is None or score > best_score:
+                    best_score = score
+
+            best_scores.append(best_score)
+
+        return torch.mean(torch.stack(best_scores))
+
+
+class MeanIoUMetric(Metric):
+    """Mean IoU for segmentation.
+
+    This is the mean of the per-class intersection-over-union scores. The per-class
+    intersection is the number of pixels across all examples where the predicted label
+    and ground truth label are both that class, and the per-class union is defined
+    similarly.
+
+    This differs from torchmetrics.segmentation.MeanIoU, where the mean IoU is computed
+    per-image, and averaged across images.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        nodata_value: int | None = None,
+        ignore_missing_classes: bool = False,
+        class_idx: int | None = None,
+    ):
+        """Create a new MeanIoUMetric.
+
+        Args:
+            num_classes: the number of classes for the task.
+            nodata_value: the value to treat as nodata/invalid. If set and is one of the
+                classes, IoU will not be calculated for it. If None, or not one of the
+                classes, IoU is calculated for all classes.
+            ignore_missing_classes: whether to ignore classes that don't appear in
+                either the predictions or the ground truth. If false, the IoU for a
+                missing class will be 0.
+            class_idx: only compute and return the IoU for this class. This option is
+                provided so the user can get per-class IoU results, since Lightning
+                only supports scalar return values from metrics.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.nodata_value = nodata_value
+        self.ignore_missing_classes = ignore_missing_classes
+        self.class_idx = class_idx
+
+        self.add_state(
+            "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
+        )
+        self.add_state(
+            "unions", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
+        )
+
+    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
+        """Update metric.
+
+        Like torchmetrics.segmentation.MeanIoU with input_format="index", we expect
+        predictions and labels to both be class integers. This is achieved by passing
+        pass_probabilities=False to the SegmentationMetric wrapper.
+
+        Args:
+            preds: the predictions, (N,), with values from 0 to C-1.
+            labels: the targets, (N,), with values from 0 to C-1.
+        """
+        if preds.min() < 0 or preds.max() >= self.num_classes:
+            raise ValueError("predicted class outside of expected range")
+        if labels.min() < 0 or labels.max() >= self.num_classes:
+            raise ValueError("label class outside of expected range")
+
+        new_intersections = torch.zeros(
+            self.num_classes, device=self.intersections.device
+        )
+        new_unions = torch.zeros(self.num_classes, device=self.unions.device)
+        for cls_idx in range(self.num_classes):
+            new_intersections[cls_idx] = (
+                (preds == cls_idx) & (labels == cls_idx)
+            ).sum()
+            new_unions[cls_idx] = ((preds == cls_idx) | (labels == cls_idx)).sum()
+        self.intersections += new_intersections
+        self.unions += new_unions
+
+    def compute(self) -> Any:
+        """Compute metric.
+
+        Returns:
+            the mean IoU across classes.
+        """
+        per_class_scores = []
+
+        for cls_idx in range(self.num_classes):
+            # Check if nodata_value is set and is one of the classes
+            if self.nodata_value is not None and cls_idx == self.nodata_value:
+                continue
+
+            intersection = self.intersections[cls_idx]
+            union = self.unions[cls_idx]
+
+            if union == 0 and self.ignore_missing_classes:
+                continue
+
+            per_class_scores.append(intersection / union)
+
+        return torch.mean(torch.stack(per_class_scores))
rslearn/train/tasks/task.py
CHANGED
@@ -7,6 +7,7 @@ import numpy.typing as npt
 import torch
 from torchmetrics import MetricCollection
 
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature
 
 
@@ -20,8 +21,8 @@ class Task:
 
     def process_inputs(
         self,
-        raw_inputs: dict[str,
-        metadata:
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -38,8 +39,8 @@ class Task:
         raise NotImplementedError
 
     def process_output(
-        self, raw_output: Any, metadata:
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
         """Processes an output into raster or vector data.
 
         Args:
@@ -47,7 +48,7 @@ class Task:
             metadata: metadata about the patch being read
 
         Returns:
-
+            raster data, vector data, or multi-task dictionary output.
         """
         raise NotImplementedError
 
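For downstream code, the practical effect of this change is that custom `Task` subclasses now type their hooks against `SampleMetadata` and `RasterImage`, and `process_output` may additionally return a per-task dictionary. A minimal sketch of a hypothetical subclass against the 0.0.21 signatures (the class and its logic are invented for illustration, and `Task`'s other hooks such as visualization and metrics are omitted):

```python
from typing import Any

import numpy as np
import numpy.typing as npt

from rslearn.train.model_context import RasterImage, SampleMetadata
from rslearn.train.tasks.task import Task
from rslearn.utils import Feature


class ThresholdTask(Task):
    """Hypothetical task that binarizes a single-channel model output at 0.5."""

    def process_inputs(
        self,
        raw_inputs: dict[str, RasterImage | list[Feature]],
        metadata: SampleMetadata,
        load_targets: bool = True,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        # No inputs or targets derived here; a real task would read
        # raw_inputs["targets"] the way SegmentationTask does above.
        return {}, {}

    def process_output(
        self, raw_output: Any, metadata: SampleMetadata
    ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
        # Return raster data: a CHW uint8 array, one of the three forms
        # permitted by the widened return annotation.
        return (raw_output.cpu().numpy() > 0.5).astype(np.uint8)
```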
rslearn/train/transforms/transform.py
CHANGED

@@ -12,7 +12,7 @@ class Sequential(torch.nn.Module):
     tuple.
     """
 
-    def __init__(self, *args):
+    def __init__(self, *args: Any) -> None:
         """Initialize a new Sequential from a list of transforms."""
         super().__init__()
         self.transforms = torch.nn.ModuleList(args)
|