rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +49 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +18 -12
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/detection.py
CHANGED
|
@@ -12,6 +12,7 @@ import torchmetrics.classification
|
|
|
12
12
|
import torchvision
|
|
13
13
|
from torchmetrics import Metric, MetricCollection
|
|
14
14
|
|
|
15
|
+
from rslearn.train.model_context import SampleMetadata
|
|
15
16
|
from rslearn.utils import Feature, STGeometry
|
|
16
17
|
|
|
17
18
|
from .task import BasicTask
|
|
@@ -127,7 +128,7 @@ class DetectionTask(BasicTask):
|
|
|
127
128
|
def process_inputs(
|
|
128
129
|
self,
|
|
129
130
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
130
|
-
metadata:
|
|
131
|
+
metadata: SampleMetadata,
|
|
131
132
|
load_targets: bool = True,
|
|
132
133
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
133
134
|
"""Processes the data into targets.
|
|
@@ -144,6 +145,8 @@ class DetectionTask(BasicTask):
|
|
|
144
145
|
if not load_targets:
|
|
145
146
|
return {}, {}
|
|
146
147
|
|
|
148
|
+
bounds = metadata.patch_bounds
|
|
149
|
+
|
|
147
150
|
boxes = []
|
|
148
151
|
class_labels = []
|
|
149
152
|
valid = 1
|
|
@@ -186,39 +189,33 @@ class DetectionTask(BasicTask):
|
|
|
186
189
|
else:
|
|
187
190
|
box = [int(val) for val in shp.bounds]
|
|
188
191
|
|
|
189
|
-
if box[0] >=
|
|
192
|
+
if box[0] >= bounds[2] or box[2] <= bounds[0]:
|
|
190
193
|
continue
|
|
191
|
-
if box[1] >=
|
|
194
|
+
if box[1] >= bounds[3] or box[3] <= bounds[1]:
|
|
192
195
|
continue
|
|
193
196
|
|
|
194
197
|
if self.exclude_by_center:
|
|
195
198
|
center_col = (box[0] + box[2]) // 2
|
|
196
199
|
center_row = (box[1] + box[3]) // 2
|
|
197
|
-
if
|
|
198
|
-
center_col <= metadata["bounds"][0]
|
|
199
|
-
or center_col >= metadata["bounds"][2]
|
|
200
|
-
):
|
|
200
|
+
if center_col <= bounds[0] or center_col >= bounds[2]:
|
|
201
201
|
continue
|
|
202
|
-
if
|
|
203
|
-
center_row <= metadata["bounds"][1]
|
|
204
|
-
or center_row >= metadata["bounds"][3]
|
|
205
|
-
):
|
|
202
|
+
if center_row <= bounds[1] or center_row >= bounds[3]:
|
|
206
203
|
continue
|
|
207
204
|
|
|
208
205
|
if self.clip_boxes:
|
|
209
206
|
box = [
|
|
210
|
-
np.clip(box[0],
|
|
211
|
-
np.clip(box[1],
|
|
212
|
-
np.clip(box[2],
|
|
213
|
-
np.clip(box[3],
|
|
207
|
+
np.clip(box[0], bounds[0], bounds[2]),
|
|
208
|
+
np.clip(box[1], bounds[1], bounds[3]),
|
|
209
|
+
np.clip(box[2], bounds[0], bounds[2]),
|
|
210
|
+
np.clip(box[3], bounds[1], bounds[3]),
|
|
214
211
|
]
|
|
215
212
|
|
|
216
213
|
# Convert to relative coordinates.
|
|
217
214
|
box = [
|
|
218
|
-
box[0] -
|
|
219
|
-
box[1] -
|
|
220
|
-
box[2] -
|
|
221
|
-
box[3] -
|
|
215
|
+
box[0] - bounds[0],
|
|
216
|
+
box[1] - bounds[1],
|
|
217
|
+
box[2] - bounds[0],
|
|
218
|
+
box[3] - bounds[1],
|
|
222
219
|
]
|
|
223
220
|
|
|
224
221
|
boxes.append(box)
|
|
@@ -238,16 +235,12 @@ class DetectionTask(BasicTask):
|
|
|
238
235
|
"valid": torch.tensor(valid, dtype=torch.int32),
|
|
239
236
|
"boxes": boxes,
|
|
240
237
|
"labels": class_labels,
|
|
241
|
-
"width": torch.tensor(
|
|
242
|
-
|
|
243
|
-
),
|
|
244
|
-
"height": torch.tensor(
|
|
245
|
-
metadata["bounds"][3] - metadata["bounds"][1], dtype=torch.float32
|
|
246
|
-
),
|
|
238
|
+
"width": torch.tensor(bounds[2] - bounds[0], dtype=torch.float32),
|
|
239
|
+
"height": torch.tensor(bounds[3] - bounds[1], dtype=torch.float32),
|
|
247
240
|
}
|
|
248
241
|
|
|
249
242
|
def process_output(
|
|
250
|
-
self, raw_output: Any, metadata:
|
|
243
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
251
244
|
) -> npt.NDArray[Any] | list[Feature]:
|
|
252
245
|
"""Processes an output into raster or vector data.
|
|
253
246
|
|
|
@@ -267,12 +260,12 @@ class DetectionTask(BasicTask):
|
|
|
267
260
|
features = []
|
|
268
261
|
for box, class_id, score in zip(boxes, class_ids, scores):
|
|
269
262
|
shp = shapely.box(
|
|
270
|
-
metadata[
|
|
271
|
-
metadata[
|
|
272
|
-
metadata[
|
|
273
|
-
metadata[
|
|
263
|
+
metadata.patch_bounds[0] + float(box[0]),
|
|
264
|
+
metadata.patch_bounds[1] + float(box[1]),
|
|
265
|
+
metadata.patch_bounds[0] + float(box[2]),
|
|
266
|
+
metadata.patch_bounds[1] + float(box[3]),
|
|
274
267
|
)
|
|
275
|
-
geom = STGeometry(metadata
|
|
268
|
+
geom = STGeometry(metadata.projection, shp, None)
|
|
276
269
|
properties: dict[str, Any] = {
|
|
277
270
|
"score": float(score),
|
|
278
271
|
}
|
rslearn/train/tasks/embedding.py
CHANGED
|
@@ -6,6 +6,8 @@ import numpy.typing as npt
|
|
|
6
6
|
import torch
|
|
7
7
|
from torchmetrics import MetricCollection
|
|
8
8
|
|
|
9
|
+
from rslearn.models.component import FeatureMaps
|
|
10
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
9
11
|
from rslearn.utils import Feature
|
|
10
12
|
|
|
11
13
|
from .task import Task
|
|
@@ -21,7 +23,7 @@ class EmbeddingTask(Task):
|
|
|
21
23
|
def process_inputs(
|
|
22
24
|
self,
|
|
23
25
|
raw_inputs: dict[str, torch.Tensor],
|
|
24
|
-
metadata:
|
|
26
|
+
metadata: SampleMetadata,
|
|
25
27
|
load_targets: bool = True,
|
|
26
28
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
27
29
|
"""Processes the data into targets.
|
|
@@ -38,17 +40,22 @@ class EmbeddingTask(Task):
|
|
|
38
40
|
return {}, {}
|
|
39
41
|
|
|
40
42
|
def process_output(
|
|
41
|
-
self, raw_output: Any, metadata:
|
|
43
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
42
44
|
) -> npt.NDArray[Any] | list[Feature]:
|
|
43
45
|
"""Processes an output into raster or vector data.
|
|
44
46
|
|
|
45
47
|
Args:
|
|
46
|
-
raw_output: the output from prediction head.
|
|
48
|
+
raw_output: the output from prediction head, which must be a CxHxW tensor.
|
|
47
49
|
metadata: metadata about the patch being read
|
|
48
50
|
|
|
49
51
|
Returns:
|
|
50
52
|
either raster or vector data.
|
|
51
53
|
"""
|
|
54
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
|
|
55
|
+
raise ValueError(
|
|
56
|
+
"output for EmbeddingTask must be a tensor with three dimensions"
|
|
57
|
+
)
|
|
58
|
+
|
|
52
59
|
# Just convert the raw output to numpy array that can be saved to GeoTIFF.
|
|
53
60
|
return raw_output.cpu().numpy()
|
|
54
61
|
|
|
@@ -76,41 +83,38 @@ class EmbeddingTask(Task):
|
|
|
76
83
|
return MetricCollection({})
|
|
77
84
|
|
|
78
85
|
|
|
79
|
-
class EmbeddingHead
|
|
86
|
+
class EmbeddingHead:
|
|
80
87
|
"""Head for embedding task.
|
|
81
88
|
|
|
82
|
-
|
|
83
|
-
returns a dummy loss.
|
|
89
|
+
It just adds a dummy loss to act as a Predictor.
|
|
84
90
|
"""
|
|
85
91
|
|
|
86
|
-
def __init__(self, feature_map_index: int | None = 0):
|
|
87
|
-
"""Create a new EmbeddingHead.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
feature_map_index: the index of the feature map to choose from the input
|
|
91
|
-
list of multi-scale feature maps (default 0). If the input is already
|
|
92
|
-
a single feature map, then set to None.
|
|
93
|
-
"""
|
|
94
|
-
super().__init__()
|
|
95
|
-
self.feature_map_index = feature_map_index
|
|
96
|
-
|
|
97
92
|
def forward(
|
|
98
93
|
self,
|
|
99
|
-
|
|
100
|
-
|
|
94
|
+
intermediates: Any,
|
|
95
|
+
context: ModelContext,
|
|
101
96
|
targets: list[dict[str, Any]] | None = None,
|
|
102
|
-
) ->
|
|
103
|
-
"""
|
|
97
|
+
) -> ModelOutput:
|
|
98
|
+
"""Return the feature map along with a dummy loss.
|
|
104
99
|
|
|
105
100
|
Args:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
101
|
+
intermediates: output from the previous model component, which must be a
|
|
102
|
+
FeatureMaps consisting of a single feature map.
|
|
103
|
+
context: the model context.
|
|
104
|
+
targets: the targets (ignored).
|
|
109
105
|
|
|
110
106
|
Returns:
|
|
111
|
-
|
|
107
|
+
model output with the feature map that was input to this component along
|
|
108
|
+
with a dummy loss.
|
|
112
109
|
"""
|
|
113
|
-
if
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
110
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
111
|
+
raise ValueError("input to EmbeddingHead must be a FeatureMaps")
|
|
112
|
+
if len(intermediates.feature_maps) != 1:
|
|
113
|
+
raise ValueError(
|
|
114
|
+
f"input to EmbeddingHead must have one feature map, but got {len(intermediates.feature_maps)}"
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
return ModelOutput(
|
|
118
|
+
outputs=intermediates.feature_maps[0],
|
|
119
|
+
loss_dict={"loss": 0},
|
|
120
|
+
)
|
|
@@ -6,6 +6,7 @@ import numpy.typing as npt
|
|
|
6
6
|
import torch
|
|
7
7
|
from torchmetrics import Metric, MetricCollection
|
|
8
8
|
|
|
9
|
+
from rslearn.train.model_context import SampleMetadata
|
|
9
10
|
from rslearn.utils import Feature
|
|
10
11
|
|
|
11
12
|
from .task import Task
|
|
@@ -30,7 +31,7 @@ class MultiTask(Task):
|
|
|
30
31
|
def process_inputs(
|
|
31
32
|
self,
|
|
32
33
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
33
|
-
metadata:
|
|
34
|
+
metadata: SampleMetadata,
|
|
34
35
|
load_targets: bool = True,
|
|
35
36
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
36
37
|
"""Processes the data into targets.
|
|
@@ -46,14 +47,12 @@ class MultiTask(Task):
|
|
|
46
47
|
"""
|
|
47
48
|
input_dict = {}
|
|
48
49
|
target_dict = {}
|
|
49
|
-
if metadata
|
|
50
|
+
if metadata.dataset_source is None:
|
|
50
51
|
# No multi-dataset, so always compute across all tasks
|
|
51
52
|
task_iter = list(self.tasks.items())
|
|
52
53
|
else:
|
|
53
54
|
# Multi-dataset, so only compute for the task in this dataset
|
|
54
|
-
task_iter = [
|
|
55
|
-
(metadata["dataset_source"], self.tasks[metadata["dataset_source"]])
|
|
56
|
-
]
|
|
55
|
+
task_iter = [(metadata.dataset_source, self.tasks[metadata.dataset_source])]
|
|
57
56
|
|
|
58
57
|
for task_name, task in task_iter:
|
|
59
58
|
cur_raw_inputs = {}
|
|
@@ -71,12 +70,13 @@ class MultiTask(Task):
|
|
|
71
70
|
return input_dict, target_dict
|
|
72
71
|
|
|
73
72
|
def process_output(
|
|
74
|
-
self, raw_output: Any, metadata:
|
|
73
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
75
74
|
) -> dict[str, Any]:
|
|
76
75
|
"""Processes an output into raster or vector data.
|
|
77
76
|
|
|
78
77
|
Args:
|
|
79
|
-
raw_output: the output from prediction head.
|
|
78
|
+
raw_output: the output from prediction head. It must be a dict mapping from
|
|
79
|
+
task name to per-task output for this sample.
|
|
80
80
|
metadata: metadata about the patch being read
|
|
81
81
|
|
|
82
82
|
Returns:
|
|
@@ -8,6 +8,8 @@ import torch
|
|
|
8
8
|
import torchmetrics
|
|
9
9
|
from torchmetrics import Metric, MetricCollection
|
|
10
10
|
|
|
11
|
+
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
11
13
|
from rslearn.utils.feature import Feature
|
|
12
14
|
|
|
13
15
|
from .task import BasicTask
|
|
@@ -41,7 +43,7 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
41
43
|
def process_inputs(
|
|
42
44
|
self,
|
|
43
45
|
raw_inputs: dict[str, torch.Tensor],
|
|
44
|
-
metadata:
|
|
46
|
+
metadata: SampleMetadata,
|
|
45
47
|
load_targets: bool = True,
|
|
46
48
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
47
49
|
"""Processes the data into targets.
|
|
@@ -72,20 +74,23 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
72
74
|
}
|
|
73
75
|
|
|
74
76
|
def process_output(
|
|
75
|
-
self, raw_output: Any, metadata:
|
|
77
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
76
78
|
) -> npt.NDArray[Any] | list[Feature]:
|
|
77
79
|
"""Processes an output into raster or vector data.
|
|
78
80
|
|
|
79
81
|
Args:
|
|
80
|
-
raw_output: the output from prediction head.
|
|
82
|
+
raw_output: the output from prediction head, which must be an HW tensor.
|
|
81
83
|
metadata: metadata about the patch being read
|
|
82
84
|
|
|
83
85
|
Returns:
|
|
84
86
|
either raster or vector data.
|
|
85
87
|
"""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
88
|
+
if not isinstance(raw_output, torch.Tensor):
|
|
89
|
+
raise ValueError("output for PerPixelRegressionTask must be a tensor")
|
|
90
|
+
if len(raw_output.shape) != 2:
|
|
91
|
+
raise ValueError(
|
|
92
|
+
f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
|
|
93
|
+
)
|
|
89
94
|
return (raw_output / self.scale_factor).cpu().numpy()
|
|
90
95
|
|
|
91
96
|
def visualize(
|
|
@@ -133,7 +138,7 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
133
138
|
return MetricCollection(metric_dict)
|
|
134
139
|
|
|
135
140
|
|
|
136
|
-
class PerPixelRegressionHead(
|
|
141
|
+
class PerPixelRegressionHead(Predictor):
|
|
137
142
|
"""Head for per-pixel regression task."""
|
|
138
143
|
|
|
139
144
|
def __init__(
|
|
@@ -156,24 +161,38 @@ class PerPixelRegressionHead(torch.nn.Module):
|
|
|
156
161
|
|
|
157
162
|
def forward(
|
|
158
163
|
self,
|
|
159
|
-
|
|
160
|
-
|
|
164
|
+
intermediates: Any,
|
|
165
|
+
context: ModelContext,
|
|
161
166
|
targets: list[dict[str, Any]] | None = None,
|
|
162
|
-
) ->
|
|
167
|
+
) -> ModelOutput:
|
|
163
168
|
"""Compute the regression outputs and loss from logits and targets.
|
|
164
169
|
|
|
165
170
|
Args:
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
171
|
+
intermediates: output from previous component, which must be a FeatureMaps
|
|
172
|
+
with one feature map corresponding to the logits. The channel dimension
|
|
173
|
+
size must be 1.
|
|
174
|
+
context: the model context.
|
|
175
|
+
targets: must contain values key that stores the regression labels, and
|
|
176
|
+
valid key containing mask image indicating where the labels are valid.
|
|
169
177
|
|
|
170
178
|
Returns:
|
|
171
|
-
tuple of outputs and loss dict
|
|
179
|
+
tuple of outputs and loss dict. The output is a BHW tensor so that the
|
|
180
|
+
per-sample output is an HW tensor.
|
|
172
181
|
"""
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
182
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
183
|
+
raise ValueError(
|
|
184
|
+
"the input to PerPixelRegressionHead must be a FeatureMaps"
|
|
185
|
+
)
|
|
186
|
+
if len(intermediates.feature_maps) != 1:
|
|
187
|
+
raise ValueError(
|
|
188
|
+
"the input to PerPixelRegressionHead must have one feature map"
|
|
189
|
+
)
|
|
190
|
+
if intermediates.feature_maps[0].shape[1] != 1:
|
|
191
|
+
raise ValueError(
|
|
192
|
+
f"the input to PerPixelRegressionHead must have channel dimension size 1, but got {intermediates.feature_maps[0].shape}"
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
logits = intermediates.feature_maps[0][:, 0, :, :]
|
|
177
196
|
|
|
178
197
|
if self.use_sigmoid:
|
|
179
198
|
outputs = torch.nn.functional.sigmoid(logits)
|
|
@@ -200,7 +219,10 @@ class PerPixelRegressionHead(torch.nn.Module):
|
|
|
200
219
|
else:
|
|
201
220
|
losses["regress"] = (scores * mask).sum() / mask_total
|
|
202
221
|
|
|
203
|
-
return
|
|
222
|
+
return ModelOutput(
|
|
223
|
+
outputs=outputs,
|
|
224
|
+
loss_dict=losses,
|
|
225
|
+
)
|
|
204
226
|
|
|
205
227
|
|
|
206
228
|
class PerPixelRegressionMetricWrapper(Metric):
|
|
@@ -10,6 +10,8 @@ import torchmetrics
|
|
|
10
10
|
from PIL import Image, ImageDraw
|
|
11
11
|
from torchmetrics import Metric, MetricCollection
|
|
12
12
|
|
|
13
|
+
from rslearn.models.component import FeatureVector, Predictor
|
|
14
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
13
15
|
from rslearn.utils.feature import Feature
|
|
14
16
|
from rslearn.utils.geometry import STGeometry
|
|
15
17
|
|
|
@@ -62,7 +64,7 @@ class RegressionTask(BasicTask):
|
|
|
62
64
|
def process_inputs(
|
|
63
65
|
self,
|
|
64
66
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
65
|
-
metadata:
|
|
67
|
+
metadata: SampleMetadata,
|
|
66
68
|
load_targets: bool = True,
|
|
67
69
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
68
70
|
"""Processes the data into targets.
|
|
@@ -103,22 +105,26 @@ class RegressionTask(BasicTask):
|
|
|
103
105
|
}
|
|
104
106
|
|
|
105
107
|
def process_output(
|
|
106
|
-
self, raw_output: Any, metadata:
|
|
107
|
-
) ->
|
|
108
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
109
|
+
) -> list[Feature]:
|
|
108
110
|
"""Processes an output into raster or vector data.
|
|
109
111
|
|
|
110
112
|
Args:
|
|
111
|
-
raw_output: the output from prediction head.
|
|
113
|
+
raw_output: the output from prediction head, which must be a scalar tensor.
|
|
112
114
|
metadata: metadata about the patch being read
|
|
113
115
|
|
|
114
116
|
Returns:
|
|
115
|
-
|
|
117
|
+
a list with a single Feature corresponding to the patch extent and with a
|
|
118
|
+
property containing the predicted value.
|
|
116
119
|
"""
|
|
120
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 0:
|
|
121
|
+
raise ValueError("output for RegressionTask must be a scalar Tensor")
|
|
122
|
+
|
|
117
123
|
output = raw_output.item() / self.scale_factor
|
|
118
124
|
feature = Feature(
|
|
119
125
|
STGeometry(
|
|
120
|
-
metadata
|
|
121
|
-
shapely.Point(metadata[
|
|
126
|
+
metadata.projection,
|
|
127
|
+
shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
|
|
122
128
|
None,
|
|
123
129
|
),
|
|
124
130
|
{
|
|
@@ -180,7 +186,7 @@ class RegressionTask(BasicTask):
|
|
|
180
186
|
return MetricCollection(metric_dict)
|
|
181
187
|
|
|
182
188
|
|
|
183
|
-
class RegressionHead(
|
|
189
|
+
class RegressionHead(Predictor):
|
|
184
190
|
"""Head for regression task."""
|
|
185
191
|
|
|
186
192
|
def __init__(
|
|
@@ -199,24 +205,32 @@ class RegressionHead(torch.nn.Module):
|
|
|
199
205
|
|
|
200
206
|
def forward(
|
|
201
207
|
self,
|
|
202
|
-
|
|
203
|
-
|
|
208
|
+
intermediates: Any,
|
|
209
|
+
context: ModelContext,
|
|
204
210
|
targets: list[dict[str, Any]] | None = None,
|
|
205
|
-
) ->
|
|
211
|
+
) -> ModelOutput:
|
|
206
212
|
"""Compute the regression outputs and loss from logits and targets.
|
|
207
213
|
|
|
208
214
|
Args:
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
215
|
+
intermediates: output from previous model component, which must be a
|
|
216
|
+
FeatureVector with channel dimension size 1 (Bx1).
|
|
217
|
+
context: the model context.
|
|
218
|
+
targets: target dicts, which each must contain a "value" key containing the
|
|
219
|
+
regression label, along with a "valid" key containing a flag indicating
|
|
220
|
+
whether each example is valid for this task.
|
|
212
221
|
|
|
213
222
|
Returns:
|
|
214
|
-
|
|
223
|
+
the model outputs. The output is a B tensor so that it is split up into a
|
|
224
|
+
scalar for each example.
|
|
215
225
|
"""
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
226
|
+
if not isinstance(intermediates, FeatureVector):
|
|
227
|
+
raise ValueError("the input to RegressionHead must be a FeatureVector")
|
|
228
|
+
if intermediates.feature_vector.shape[1] != 1:
|
|
229
|
+
raise ValueError(
|
|
230
|
+
f"the input to RegressionHead must have channel dimension size 1, but got shape {intermediates.feature_vector.shape}"
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
logits = intermediates.feature_vector[:, 0]
|
|
220
234
|
|
|
221
235
|
if self.use_sigmoid:
|
|
222
236
|
outputs = torch.nn.functional.sigmoid(logits)
|
|
@@ -232,9 +246,12 @@ class RegressionHead(torch.nn.Module):
|
|
|
232
246
|
elif self.loss_mode == "l1":
|
|
233
247
|
losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
|
|
234
248
|
else:
|
|
235
|
-
|
|
249
|
+
raise ValueError(f"unknown loss mode {self.loss_mode}")
|
|
236
250
|
|
|
237
|
-
return
|
|
251
|
+
return ModelOutput(
|
|
252
|
+
outputs=outputs,
|
|
253
|
+
loss_dict=losses,
|
|
254
|
+
)
|
|
238
255
|
|
|
239
256
|
|
|
240
257
|
class RegressionMetricWrapper(Metric):
|
|
@@ -8,7 +8,8 @@ import torch
|
|
|
8
8
|
import torchmetrics.classification
|
|
9
9
|
from torchmetrics import Metric, MetricCollection
|
|
10
10
|
|
|
11
|
-
from rslearn.
|
|
11
|
+
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
12
13
|
|
|
13
14
|
from .task import BasicTask
|
|
14
15
|
|
|
@@ -108,7 +109,7 @@ class SegmentationTask(BasicTask):
|
|
|
108
109
|
def process_inputs(
|
|
109
110
|
self,
|
|
110
111
|
raw_inputs: dict[str, torch.Tensor],
|
|
111
|
-
metadata:
|
|
112
|
+
metadata: SampleMetadata,
|
|
112
113
|
load_targets: bool = True,
|
|
113
114
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
114
115
|
"""Processes the data into targets.
|
|
@@ -148,17 +149,20 @@ class SegmentationTask(BasicTask):
|
|
|
148
149
|
}
|
|
149
150
|
|
|
150
151
|
def process_output(
|
|
151
|
-
self, raw_output: Any, metadata:
|
|
152
|
-
) -> npt.NDArray[Any]
|
|
152
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
153
|
+
) -> npt.NDArray[Any]:
|
|
153
154
|
"""Processes an output into raster or vector data.
|
|
154
155
|
|
|
155
156
|
Args:
|
|
156
|
-
raw_output: the output from prediction head.
|
|
157
|
+
raw_output: the output from prediction head, which must be a CHW tensor.
|
|
157
158
|
metadata: metadata about the patch being read
|
|
158
159
|
|
|
159
160
|
Returns:
|
|
160
|
-
|
|
161
|
+
CHW numpy array with one channel, containing the predicted class IDs.
|
|
161
162
|
"""
|
|
163
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
|
|
164
|
+
raise ValueError("the output for SegmentationTask must be a CHW tensor")
|
|
165
|
+
|
|
162
166
|
if self.prob_scales is not None:
|
|
163
167
|
raw_output = (
|
|
164
168
|
raw_output
|
|
@@ -166,7 +170,7 @@ class SegmentationTask(BasicTask):
|
|
|
166
170
|
self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
|
|
167
171
|
)[:, None, None]
|
|
168
172
|
)
|
|
169
|
-
classes = raw_output.argmax(dim=0).cpu().numpy()
|
|
173
|
+
classes = raw_output.argmax(dim=0).cpu().numpy()
|
|
170
174
|
return classes[None, :, :]
|
|
171
175
|
|
|
172
176
|
def visualize(
|
|
@@ -258,25 +262,36 @@ class SegmentationTask(BasicTask):
|
|
|
258
262
|
return MetricCollection(metrics)
|
|
259
263
|
|
|
260
264
|
|
|
261
|
-
class SegmentationHead(
|
|
265
|
+
class SegmentationHead(Predictor):
|
|
262
266
|
"""Head for segmentation task."""
|
|
263
267
|
|
|
264
268
|
def forward(
|
|
265
269
|
self,
|
|
266
|
-
|
|
267
|
-
|
|
270
|
+
intermediates: Any,
|
|
271
|
+
context: ModelContext,
|
|
268
272
|
targets: list[dict[str, Any]] | None = None,
|
|
269
|
-
) ->
|
|
273
|
+
) -> ModelOutput:
|
|
270
274
|
"""Compute the segmentation outputs from logits and targets.
|
|
271
275
|
|
|
272
276
|
Args:
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
277
|
+
intermediates: a FeatureMaps with a single feature map containing the
|
|
278
|
+
segmentation logits.
|
|
279
|
+
context: the model context
|
|
280
|
+
targets: list of target dicts, where each target dict must contain a key
|
|
281
|
+
"classes" containing the per-pixel class labels, along with "valid"
|
|
282
|
+
containing a mask indicating where the example is valid.
|
|
276
283
|
|
|
277
284
|
Returns:
|
|
278
285
|
tuple of outputs and loss dict
|
|
279
286
|
"""
|
|
287
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
288
|
+
raise ValueError("input to SegmentationHead must be a FeatureMaps")
|
|
289
|
+
if len(intermediates.feature_maps) != 1:
|
|
290
|
+
raise ValueError(
|
|
291
|
+
f"input to SegmentationHead must have one feature map, but got {len(intermediates.feature_maps)}"
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
logits = intermediates.feature_maps[0]
|
|
280
295
|
outputs = torch.nn.functional.softmax(logits, dim=1)
|
|
281
296
|
|
|
282
297
|
losses = {}
|
|
@@ -295,7 +310,10 @@ class SegmentationHead(torch.nn.Module):
|
|
|
295
310
|
# the summed mask loss be zero.
|
|
296
311
|
losses["cls"] = torch.sum(per_pixel_loss * mask)
|
|
297
312
|
|
|
298
|
-
return
|
|
313
|
+
return ModelOutput(
|
|
314
|
+
outputs=outputs,
|
|
315
|
+
loss_dict=losses,
|
|
316
|
+
)
|
|
299
317
|
|
|
300
318
|
|
|
301
319
|
class SegmentationMetric(Metric):
|
rslearn/train/tasks/task.py
CHANGED
|
@@ -7,6 +7,7 @@ import numpy.typing as npt
|
|
|
7
7
|
import torch
|
|
8
8
|
from torchmetrics import MetricCollection
|
|
9
9
|
|
|
10
|
+
from rslearn.train.model_context import SampleMetadata
|
|
10
11
|
from rslearn.utils import Feature
|
|
11
12
|
|
|
12
13
|
|
|
@@ -21,7 +22,7 @@ class Task:
|
|
|
21
22
|
def process_inputs(
|
|
22
23
|
self,
|
|
23
24
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
24
|
-
metadata:
|
|
25
|
+
metadata: SampleMetadata,
|
|
25
26
|
load_targets: bool = True,
|
|
26
27
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
27
28
|
"""Processes the data into targets.
|
|
@@ -38,7 +39,7 @@ class Task:
|
|
|
38
39
|
raise NotImplementedError
|
|
39
40
|
|
|
40
41
|
def process_output(
|
|
41
|
-
self, raw_output: Any, metadata:
|
|
42
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
42
43
|
) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
|
|
43
44
|
"""Processes an output into raster or vector data.
|
|
44
45
|
|