rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
import torch
|
|
7
6
|
from torchmetrics import Metric, MetricCollection
|
|
8
7
|
|
|
8
|
+
from rslearn.train.model_context import RasterImage, SampleMetadata
|
|
9
9
|
from rslearn.utils import Feature
|
|
10
10
|
|
|
11
11
|
from .task import Task
|
|
@@ -29,8 +29,8 @@ class MultiTask(Task):
|
|
|
29
29
|
|
|
30
30
|
def process_inputs(
|
|
31
31
|
self,
|
|
32
|
-
raw_inputs: dict[str,
|
|
33
|
-
metadata:
|
|
32
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
33
|
+
metadata: SampleMetadata,
|
|
34
34
|
load_targets: bool = True,
|
|
35
35
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
36
36
|
"""Processes the data into targets.
|
|
@@ -46,7 +46,14 @@ class MultiTask(Task):
|
|
|
46
46
|
"""
|
|
47
47
|
input_dict = {}
|
|
48
48
|
target_dict = {}
|
|
49
|
-
|
|
49
|
+
if metadata.dataset_source is None:
|
|
50
|
+
# No multi-dataset, so always compute across all tasks
|
|
51
|
+
task_iter = list(self.tasks.items())
|
|
52
|
+
else:
|
|
53
|
+
# Multi-dataset, so only compute for the task in this dataset
|
|
54
|
+
task_iter = [(metadata.dataset_source, self.tasks[metadata.dataset_source])]
|
|
55
|
+
|
|
56
|
+
for task_name, task in task_iter:
|
|
50
57
|
cur_raw_inputs = {}
|
|
51
58
|
for k, v in self.input_mapping[task_name].items():
|
|
52
59
|
if k not in raw_inputs:
|
|
@@ -62,12 +69,13 @@ class MultiTask(Task):
|
|
|
62
69
|
return input_dict, target_dict
|
|
63
70
|
|
|
64
71
|
def process_output(
|
|
65
|
-
self, raw_output: Any, metadata:
|
|
66
|
-
) ->
|
|
72
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
73
|
+
) -> dict[str, Any]:
|
|
67
74
|
"""Processes an output into raster or vector data.
|
|
68
75
|
|
|
69
76
|
Args:
|
|
70
|
-
raw_output: the output from prediction head.
|
|
77
|
+
raw_output: the output from prediction head. It must be a dict mapping from
|
|
78
|
+
task name to per-task output for this sample.
|
|
71
79
|
metadata: metadata about the patch being read
|
|
72
80
|
|
|
73
81
|
Returns:
|
|
@@ -75,9 +83,11 @@ class MultiTask(Task):
|
|
|
75
83
|
"""
|
|
76
84
|
processed_output = {}
|
|
77
85
|
for task_name, task in self.tasks.items():
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
86
|
+
if task_name in raw_output:
|
|
87
|
+
# In multi-dataset training, we may not have all datasets in the batch
|
|
88
|
+
processed_output[task_name] = task.process_output(
|
|
89
|
+
raw_output[task_name], metadata
|
|
90
|
+
)
|
|
81
91
|
return processed_output
|
|
82
92
|
|
|
83
93
|
def visualize(
|
|
@@ -146,10 +156,14 @@ class MetricWrapper(Metric):
|
|
|
146
156
|
preds: the predictions
|
|
147
157
|
targets: the targets
|
|
148
158
|
"""
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
159
|
+
try:
|
|
160
|
+
self.metric.update(
|
|
161
|
+
[pred[self.task_name] for pred in preds],
|
|
162
|
+
[target[self.task_name] for target in targets],
|
|
163
|
+
)
|
|
164
|
+
except KeyError:
|
|
165
|
+
# In multi-dataset training, we may not have all datasets in the batch
|
|
166
|
+
pass
|
|
153
167
|
|
|
154
168
|
def compute(self) -> Any:
|
|
155
169
|
"""Returns the computed metric."""
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
"""Per-pixel regression task."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import numpy.typing as npt
|
|
7
|
+
import torch
|
|
8
|
+
import torchmetrics
|
|
9
|
+
from torchmetrics import Metric, MetricCollection
|
|
10
|
+
|
|
11
|
+
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
+
from rslearn.train.model_context import (
|
|
13
|
+
ModelContext,
|
|
14
|
+
ModelOutput,
|
|
15
|
+
RasterImage,
|
|
16
|
+
SampleMetadata,
|
|
17
|
+
)
|
|
18
|
+
from rslearn.utils.feature import Feature
|
|
19
|
+
|
|
20
|
+
from .task import BasicTask
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PerPixelRegressionTask(BasicTask):
    """A per-pixel regression task.

    The target is read from the "targets" raster input, scaled by
    ``scale_factor``, and paired with a validity mask so that pixels equal to
    ``nodata_value`` can be excluded from the loss and the metrics.
    """

    def __init__(
        self,
        scale_factor: float = 1,
        metric_mode: Literal["mse", "l1"] = "mse",
        nodata_value: float | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a new PerPixelRegressionTask.

        Args:
            scale_factor: multiply ground truth values by this factor before using it for
                training.
            metric_mode: what metric to use, either "mse" (default) or "l1"
            nodata_value: optional value to treat as invalid. The loss will be masked
                at pixels where the ground truth value is equal to nodata_value.
            kwargs: other arguments to pass to BasicTask
        """
        super().__init__(**kwargs)
        self.scale_factor = scale_factor
        self.metric_mode = metric_mode
        self.nodata_value = nodata_value

    def process_inputs(
        self,
        raw_inputs: dict[str, RasterImage | list[Feature]],
        metadata: SampleMetadata,
        load_targets: bool = True,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Processes the data into targets.

        Args:
            raw_inputs: raster or vector data to process. When targets are loaded,
                the "targets" entry must be a RasterImage whose two leading
                dimensions both have size 1.
            metadata: metadata about the patch being read
            load_targets: whether to load the targets or only inputs

        Returns:
            tuple (input_dict, target_dict) containing the processed inputs and targets
                that are compatible with both metrics and loss functions. The
                input dict is always empty. The target dict is empty when
                load_targets is False; otherwise it has a "values" key (the
                scaled labels) and a "valid" key (a float mask that is 1 where
                the label is usable and 0 elsewhere).
        """
        if not load_targets:
            return {}, {}

        target_raster = raw_inputs["targets"]
        assert isinstance(target_raster, RasterImage)
        assert target_raster.image.shape[0] == 1
        assert target_raster.image.shape[1] == 1
        # Drop the two singleton leading dimensions to get a 2-D label image.
        raw_labels = target_raster.image[0, 0, :, :]
        labels = raw_labels.float() * self.scale_factor

        if self.nodata_value is not None:
            # Compare against the unscaled values: nodata_value is expressed in
            # the units of the raw target raster.
            valid = (raw_labels != self.nodata_value).float()
        else:
            valid = torch.ones(labels.shape, dtype=torch.float32)

        return {}, {
            "values": labels,
            "valid": valid,
        }

    def process_output(
        self, raw_output: Any, metadata: SampleMetadata
    ) -> npt.NDArray[Any] | list[Feature]:
        """Processes an output into raster or vector data.

        Args:
            raw_output: the output from prediction head, which must be an HW tensor.
            metadata: metadata about the patch being read

        Returns:
            the prediction as a numpy array, with the training scale factor undone.

        Raises:
            ValueError: if raw_output is not a tensor or is not two-dimensional.
        """
        if not isinstance(raw_output, torch.Tensor):
            raise ValueError("output for PerPixelRegressionTask must be a tensor")
        if len(raw_output.shape) != 2:
            raise ValueError(
                f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
            )
        return (raw_output / self.scale_factor).cpu().numpy()

    def visualize(
        self,
        input_dict: dict[str, Any],
        target_dict: dict[str, Any] | None,
        output: Any,
    ) -> dict[str, npt.NDArray[Any]]:
        """Visualize the outputs and targets.

        Args:
            input_dict: the input dict from process_inputs
            target_dict: the target dict from process_inputs
            output: the prediction, either an HW tensor (the per-sample output
                accepted by process_output) or a tensor with one extra leading
                dimension, of which only the first slice is visualized.

        Returns:
            a dictionary mapping image name to visualization image

        Raises:
            ValueError: if target_dict is None.
        """
        image = super().visualize(input_dict, target_dict, output)["image"]
        if target_dict is None:
            raise ValueError("target_dict is required for visualization")
        # process_inputs stores the labels under "values"; there is no "classes"
        # key for this task.
        gt_values = target_dict["values"].cpu().numpy()
        pred_values = output.cpu().numpy()
        if pred_values.ndim == 3:
            # Tolerate a leading singleton dimension in addition to plain HW.
            pred_values = pred_values[0, :, :]
        # NOTE(review): the * 255 scaling presumably assumes values in [0, 1];
        # anything outside that range is clipped.
        gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
        pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
        return {
            "image": np.array(image),
            "gt": gt_vis,
            "pred": pred_vis,
        }

    def get_metrics(self) -> MetricCollection:
        """Get the metrics for this task.

        Returns:
            a MetricCollection holding a single wrapped metric, keyed "mse" or
                "l1" according to metric_mode. The collection is empty if
                metric_mode is neither of those.
        """
        metric_dict: dict[str, Metric] = {}

        if self.metric_mode == "mse":
            metric_dict["mse"] = PerPixelRegressionMetricWrapper(
                metric=torchmetrics.MeanSquaredError(), scale_factor=self.scale_factor
            )
        elif self.metric_mode == "l1":
            metric_dict["l1"] = PerPixelRegressionMetricWrapper(
                metric=torchmetrics.MeanAbsoluteError(), scale_factor=self.scale_factor
            )

        return MetricCollection(metric_dict)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class PerPixelRegressionHead(Predictor):
    """Prediction head producing one regression value per pixel.

    It consumes a single one-channel feature map, optionally squashes it with
    a sigmoid, and computes a masked MSE or L1 loss when targets are given.
    """

    def __init__(
        self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
    ):
        """Initialize a new PerPixelRegressionHead.

        Args:
            loss_mode: the loss function to use, either "mse" (default) or "l1".
            use_sigmoid: whether to apply a sigmoid activation on the output. This
                requires targets to be between 0-1.

        Raises:
            ValueError: if loss_mode is not "mse" or "l1".
        """
        super().__init__()

        if loss_mode not in ("mse", "l1"):
            raise ValueError("invalid loss mode")

        self.use_sigmoid = use_sigmoid
        self.loss_mode = loss_mode

    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        """Compute the regression outputs and loss from logits and targets.

        Args:
            intermediates: output from previous component, which must be a FeatureMaps
                with one feature map corresponding to the logits. The channel dimension
                size must be 1.
            context: the model context.
            targets: must contain values key that stores the regression labels, and
                valid key containing mask image indicating where the labels are valid.

        Returns:
            a ModelOutput holding the outputs and the loss dict. The output is a
                BHW tensor so that the per-sample output is an HW tensor. The
                loss dict has a "regress" entry when targets are provided and is
                empty otherwise.

        Raises:
            ValueError: if intermediates is not a FeatureMaps, does not hold
                exactly one feature map, or that map's channel dimension is not 1.
        """
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError(
                "the input to PerPixelRegressionHead must be a FeatureMaps"
            )
        if len(intermediates.feature_maps) != 1:
            raise ValueError(
                "the input to PerPixelRegressionHead must have one feature map"
            )
        feature_map = intermediates.feature_maps[0]
        if feature_map.shape[1] != 1:
            raise ValueError(
                f"the input to PerPixelRegressionHead must have channel dimension size 1, but got {intermediates.feature_maps[0].shape}"
            )

        # Drop the singleton channel dimension.
        logits = feature_map[:, 0, :, :]
        outputs = torch.nn.functional.sigmoid(logits) if self.use_sigmoid else logits

        loss_dict = {}
        if targets:
            labels = torch.stack([sample["values"] for sample in targets])
            valid_mask = torch.stack([sample["valid"] for sample in targets])

            if self.loss_mode == "l1":
                per_pixel = torch.abs(outputs - labels)
            elif self.loss_mode == "mse":
                per_pixel = torch.square(outputs - labels)
            else:
                assert False

            # Average the error over valid pixels only.
            masked = per_pixel * valid_mask
            num_valid = valid_mask.sum()
            if num_valid != 0:
                loss_dict["regress"] = masked.sum() / num_valid
            else:
                # No valid pixel at all: the masked mean is zero, which avoids
                # dividing by zero.
                loss_dict["regress"] = masked.mean()

        return ModelOutput(outputs=outputs, loss_dict=loss_dict)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class PerPixelRegressionMetricWrapper(Metric):
    """Metric for per-pixel regression task.

    Wraps another metric: predictions and labels are restricted to valid
    pixels, flattened, and rescaled to original units before being forwarded.
    """

    def __init__(self, metric: Metric, scale_factor: float, **kwargs: Any) -> None:
        """Initialize a new PerPixelRegressionMetricWrapper.

        Args:
            metric: the underlying torchmetric to apply, which should accept a flat
                tensor of predicted values followed by a flat tensor of target values
            scale_factor: scale factor to undo so that metric is based on original
                values
            kwargs: other arguments to pass to super constructor
        """
        super().__init__(**kwargs)
        self.metric = metric
        self.scale_factor = scale_factor

    def update(
        self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
    ) -> None:
        """Update metric.

        Args:
            preds: the predictions, either a tensor or a list of per-sample
                tensors that is stacked into one. A 4-D tensor must have size 1
                in its second dimension, which is then dropped.
            targets: the targets; each dict must have a "values" tensor and a
                "valid" mask tensor.
        """
        if not isinstance(preds, torch.Tensor):
            preds = torch.stack(preds)
        labels = torch.stack([target["values"] for target in targets])

        # Sub-select the valid labels.
        # We flatten the prediction and label images at valid pixels.
        if len(preds.shape) == 4:
            assert preds.shape[1] == 1
            preds = preds[:, 0, :, :]
        mask = torch.stack([target["valid"] > 0 for target in targets])
        preds = preds[mask]
        labels = labels[mask]
        if len(preds) == 0:
            # Nothing valid in this batch, so leave the wrapped metric untouched.
            return

        # Undo the training scale factor so the metric is reported in the units
        # of the original ground truth values, as the constructor documents.
        self.metric.update(preds / self.scale_factor, labels / self.scale_factor)

    def compute(self) -> Any:
        """Returns the computed metric."""
        return self.metric.compute()

    def reset(self) -> None:
        """Reset metric."""
        super().reset()
        self.metric.reset()

    def plot(self, *args: Any, **kwargs: Any) -> Any:
        """Returns a plot of the metric."""
        return self.metric.plot(*args, **kwargs)
|