rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
|
@@ -6,6 +6,7 @@ import numpy.typing as npt
|
|
|
6
6
|
import torch
|
|
7
7
|
from torchmetrics import Metric, MetricCollection
|
|
8
8
|
|
|
9
|
+
from rslearn.train.model_context import SampleMetadata
|
|
9
10
|
from rslearn.utils import Feature
|
|
10
11
|
|
|
11
12
|
from .task import Task
|
|
@@ -30,7 +31,7 @@ class MultiTask(Task):
|
|
|
30
31
|
def process_inputs(
|
|
31
32
|
self,
|
|
32
33
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
33
|
-
metadata:
|
|
34
|
+
metadata: SampleMetadata,
|
|
34
35
|
load_targets: bool = True,
|
|
35
36
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
36
37
|
"""Processes the data into targets.
|
|
@@ -46,14 +47,12 @@ class MultiTask(Task):
|
|
|
46
47
|
"""
|
|
47
48
|
input_dict = {}
|
|
48
49
|
target_dict = {}
|
|
49
|
-
if metadata
|
|
50
|
+
if metadata.dataset_source is None:
|
|
50
51
|
# No multi-dataset, so always compute across all tasks
|
|
51
52
|
task_iter = list(self.tasks.items())
|
|
52
53
|
else:
|
|
53
54
|
# Multi-dataset, so only compute for the task in this dataset
|
|
54
|
-
task_iter = [
|
|
55
|
-
(metadata["dataset_source"], self.tasks[metadata["dataset_source"]])
|
|
56
|
-
]
|
|
55
|
+
task_iter = [(metadata.dataset_source, self.tasks[metadata.dataset_source])]
|
|
57
56
|
|
|
58
57
|
for task_name, task in task_iter:
|
|
59
58
|
cur_raw_inputs = {}
|
|
@@ -71,12 +70,13 @@ class MultiTask(Task):
|
|
|
71
70
|
return input_dict, target_dict
|
|
72
71
|
|
|
73
72
|
def process_output(
|
|
74
|
-
self, raw_output: Any, metadata:
|
|
73
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
75
74
|
) -> dict[str, Any]:
|
|
76
75
|
"""Processes an output into raster or vector data.
|
|
77
76
|
|
|
78
77
|
Args:
|
|
79
|
-
raw_output: the output from prediction head.
|
|
78
|
+
raw_output: the output from prediction head. It must be a dict mapping from
|
|
79
|
+
task name to per-task output for this sample.
|
|
80
80
|
metadata: metadata about the patch being read
|
|
81
81
|
|
|
82
82
|
Returns:
|
|
@@ -8,6 +8,8 @@ import torch
|
|
|
8
8
|
import torchmetrics
|
|
9
9
|
from torchmetrics import Metric, MetricCollection
|
|
10
10
|
|
|
11
|
+
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
11
13
|
from rslearn.utils.feature import Feature
|
|
12
14
|
|
|
13
15
|
from .task import BasicTask
|
|
@@ -41,7 +43,7 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
41
43
|
def process_inputs(
|
|
42
44
|
self,
|
|
43
45
|
raw_inputs: dict[str, torch.Tensor],
|
|
44
|
-
metadata:
|
|
46
|
+
metadata: SampleMetadata,
|
|
45
47
|
load_targets: bool = True,
|
|
46
48
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
47
49
|
"""Processes the data into targets.
|
|
@@ -72,20 +74,23 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
72
74
|
}
|
|
73
75
|
|
|
74
76
|
def process_output(
|
|
75
|
-
self, raw_output: Any, metadata:
|
|
77
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
76
78
|
) -> npt.NDArray[Any] | list[Feature]:
|
|
77
79
|
"""Processes an output into raster or vector data.
|
|
78
80
|
|
|
79
81
|
Args:
|
|
80
|
-
raw_output: the output from prediction head.
|
|
82
|
+
raw_output: the output from prediction head, which must be an HW tensor.
|
|
81
83
|
metadata: metadata about the patch being read
|
|
82
84
|
|
|
83
85
|
Returns:
|
|
84
86
|
either raster or vector data.
|
|
85
87
|
"""
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
88
|
+
if not isinstance(raw_output, torch.Tensor):
|
|
89
|
+
raise ValueError("output for PerPixelRegressionTask must be a tensor")
|
|
90
|
+
if len(raw_output.shape) != 2:
|
|
91
|
+
raise ValueError(
|
|
92
|
+
f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
|
|
93
|
+
)
|
|
89
94
|
return (raw_output / self.scale_factor).cpu().numpy()
|
|
90
95
|
|
|
91
96
|
def visualize(
|
|
@@ -133,7 +138,7 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
133
138
|
return MetricCollection(metric_dict)
|
|
134
139
|
|
|
135
140
|
|
|
136
|
-
class PerPixelRegressionHead(
|
|
141
|
+
class PerPixelRegressionHead(Predictor):
|
|
137
142
|
"""Head for per-pixel regression task."""
|
|
138
143
|
|
|
139
144
|
def __init__(
|
|
@@ -156,24 +161,38 @@ class PerPixelRegressionHead(torch.nn.Module):
|
|
|
156
161
|
|
|
157
162
|
def forward(
|
|
158
163
|
self,
|
|
159
|
-
|
|
160
|
-
|
|
164
|
+
intermediates: Any,
|
|
165
|
+
context: ModelContext,
|
|
161
166
|
targets: list[dict[str, Any]] | None = None,
|
|
162
|
-
) ->
|
|
167
|
+
) -> ModelOutput:
|
|
163
168
|
"""Compute the regression outputs and loss from logits and targets.
|
|
164
169
|
|
|
165
170
|
Args:
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
171
|
+
intermediates: output from previous component, which must be a FeatureMaps
|
|
172
|
+
with one feature map corresponding to the logits. The channel dimension
|
|
173
|
+
size must be 1.
|
|
174
|
+
context: the model context.
|
|
175
|
+
targets: must contain values key that stores the regression labels, and
|
|
176
|
+
valid key containing mask image indicating where the labels are valid.
|
|
169
177
|
|
|
170
178
|
Returns:
|
|
171
|
-
tuple of outputs and loss dict
|
|
179
|
+
tuple of outputs and loss dict. The output is a BHW tensor so that the
|
|
180
|
+
per-sample output is an HW tensor.
|
|
172
181
|
"""
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
182
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
183
|
+
raise ValueError(
|
|
184
|
+
"the input to PerPixelRegressionHead must be a FeatureMaps"
|
|
185
|
+
)
|
|
186
|
+
if len(intermediates.feature_maps) != 1:
|
|
187
|
+
raise ValueError(
|
|
188
|
+
"the input to PerPixelRegressionHead must have one feature map"
|
|
189
|
+
)
|
|
190
|
+
if intermediates.feature_maps[0].shape[1] != 1:
|
|
191
|
+
raise ValueError(
|
|
192
|
+
f"the input to PerPixelRegressionHead must have channel dimension size 1, but got {intermediates.feature_maps[0].shape}"
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
logits = intermediates.feature_maps[0][:, 0, :, :]
|
|
177
196
|
|
|
178
197
|
if self.use_sigmoid:
|
|
179
198
|
outputs = torch.nn.functional.sigmoid(logits)
|
|
@@ -200,7 +219,10 @@ class PerPixelRegressionHead(torch.nn.Module):
|
|
|
200
219
|
else:
|
|
201
220
|
losses["regress"] = (scores * mask).sum() / mask_total
|
|
202
221
|
|
|
203
|
-
return
|
|
222
|
+
return ModelOutput(
|
|
223
|
+
outputs=outputs,
|
|
224
|
+
loss_dict=losses,
|
|
225
|
+
)
|
|
204
226
|
|
|
205
227
|
|
|
206
228
|
class PerPixelRegressionMetricWrapper(Metric):
|
|
@@ -10,6 +10,8 @@ import torchmetrics
|
|
|
10
10
|
from PIL import Image, ImageDraw
|
|
11
11
|
from torchmetrics import Metric, MetricCollection
|
|
12
12
|
|
|
13
|
+
from rslearn.models.component import FeatureVector, Predictor
|
|
14
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
13
15
|
from rslearn.utils.feature import Feature
|
|
14
16
|
from rslearn.utils.geometry import STGeometry
|
|
15
17
|
|
|
@@ -62,7 +64,7 @@ class RegressionTask(BasicTask):
|
|
|
62
64
|
def process_inputs(
|
|
63
65
|
self,
|
|
64
66
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
65
|
-
metadata:
|
|
67
|
+
metadata: SampleMetadata,
|
|
66
68
|
load_targets: bool = True,
|
|
67
69
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
68
70
|
"""Processes the data into targets.
|
|
@@ -103,22 +105,26 @@ class RegressionTask(BasicTask):
|
|
|
103
105
|
}
|
|
104
106
|
|
|
105
107
|
def process_output(
|
|
106
|
-
self, raw_output: Any, metadata:
|
|
107
|
-
) ->
|
|
108
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
109
|
+
) -> list[Feature]:
|
|
108
110
|
"""Processes an output into raster or vector data.
|
|
109
111
|
|
|
110
112
|
Args:
|
|
111
|
-
raw_output: the output from prediction head.
|
|
113
|
+
raw_output: the output from prediction head, which must be a scalar tensor.
|
|
112
114
|
metadata: metadata about the patch being read
|
|
113
115
|
|
|
114
116
|
Returns:
|
|
115
|
-
|
|
117
|
+
a list with a single Feature corresponding to the patch extent and with a
|
|
118
|
+
property containing the predicted value.
|
|
116
119
|
"""
|
|
120
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 0:
|
|
121
|
+
raise ValueError("output for RegressionTask must be a scalar Tensor")
|
|
122
|
+
|
|
117
123
|
output = raw_output.item() / self.scale_factor
|
|
118
124
|
feature = Feature(
|
|
119
125
|
STGeometry(
|
|
120
|
-
metadata
|
|
121
|
-
shapely.Point(metadata[
|
|
126
|
+
metadata.projection,
|
|
127
|
+
shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
|
|
122
128
|
None,
|
|
123
129
|
),
|
|
124
130
|
{
|
|
@@ -180,7 +186,7 @@ class RegressionTask(BasicTask):
|
|
|
180
186
|
return MetricCollection(metric_dict)
|
|
181
187
|
|
|
182
188
|
|
|
183
|
-
class RegressionHead(
|
|
189
|
+
class RegressionHead(Predictor):
|
|
184
190
|
"""Head for regression task."""
|
|
185
191
|
|
|
186
192
|
def __init__(
|
|
@@ -199,24 +205,32 @@ class RegressionHead(torch.nn.Module):
|
|
|
199
205
|
|
|
200
206
|
def forward(
|
|
201
207
|
self,
|
|
202
|
-
|
|
203
|
-
|
|
208
|
+
intermediates: Any,
|
|
209
|
+
context: ModelContext,
|
|
204
210
|
targets: list[dict[str, Any]] | None = None,
|
|
205
|
-
) ->
|
|
211
|
+
) -> ModelOutput:
|
|
206
212
|
"""Compute the regression outputs and loss from logits and targets.
|
|
207
213
|
|
|
208
214
|
Args:
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
215
|
+
intermediates: output from previous model component, which must be a
|
|
216
|
+
FeatureVector with channel dimension size 1 (Bx1).
|
|
217
|
+
context: the model context.
|
|
218
|
+
targets: target dicts, which each must contain a "value" key containing the
|
|
219
|
+
regression label, along with a "valid" key containing a flag indicating
|
|
220
|
+
whether each example is valid for this task.
|
|
212
221
|
|
|
213
222
|
Returns:
|
|
214
|
-
|
|
223
|
+
the model outputs. The output is a B tensor so that it is split up into a
|
|
224
|
+
scalar for each example.
|
|
215
225
|
"""
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
226
|
+
if not isinstance(intermediates, FeatureVector):
|
|
227
|
+
raise ValueError("the input to RegressionHead must be a FeatureVector")
|
|
228
|
+
if intermediates.feature_vector.shape[1] != 1:
|
|
229
|
+
raise ValueError(
|
|
230
|
+
f"the input to RegressionHead must have channel dimension size 1, but got shape {intermediates.feature_vector.shape}"
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
logits = intermediates.feature_vector[:, 0]
|
|
220
234
|
|
|
221
235
|
if self.use_sigmoid:
|
|
222
236
|
outputs = torch.nn.functional.sigmoid(logits)
|
|
@@ -232,9 +246,12 @@ class RegressionHead(torch.nn.Module):
|
|
|
232
246
|
elif self.loss_mode == "l1":
|
|
233
247
|
losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
|
|
234
248
|
else:
|
|
235
|
-
|
|
249
|
+
raise ValueError(f"unknown loss mode {self.loss_mode}")
|
|
236
250
|
|
|
237
|
-
return
|
|
251
|
+
return ModelOutput(
|
|
252
|
+
outputs=outputs,
|
|
253
|
+
loss_dict=losses,
|
|
254
|
+
)
|
|
238
255
|
|
|
239
256
|
|
|
240
257
|
class RegressionMetricWrapper(Metric):
|
|
@@ -8,7 +8,8 @@ import torch
|
|
|
8
8
|
import torchmetrics.classification
|
|
9
9
|
from torchmetrics import Metric, MetricCollection
|
|
10
10
|
|
|
11
|
-
from rslearn.
|
|
11
|
+
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
12
13
|
|
|
13
14
|
from .task import BasicTask
|
|
14
15
|
|
|
@@ -108,7 +109,7 @@ class SegmentationTask(BasicTask):
|
|
|
108
109
|
def process_inputs(
|
|
109
110
|
self,
|
|
110
111
|
raw_inputs: dict[str, torch.Tensor],
|
|
111
|
-
metadata:
|
|
112
|
+
metadata: SampleMetadata,
|
|
112
113
|
load_targets: bool = True,
|
|
113
114
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
114
115
|
"""Processes the data into targets.
|
|
@@ -148,17 +149,20 @@ class SegmentationTask(BasicTask):
|
|
|
148
149
|
}
|
|
149
150
|
|
|
150
151
|
def process_output(
|
|
151
|
-
self, raw_output: Any, metadata:
|
|
152
|
-
) -> npt.NDArray[Any]
|
|
152
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
153
|
+
) -> npt.NDArray[Any]:
|
|
153
154
|
"""Processes an output into raster or vector data.
|
|
154
155
|
|
|
155
156
|
Args:
|
|
156
|
-
raw_output: the output from prediction head.
|
|
157
|
+
raw_output: the output from prediction head, which must be a CHW tensor.
|
|
157
158
|
metadata: metadata about the patch being read
|
|
158
159
|
|
|
159
160
|
Returns:
|
|
160
|
-
|
|
161
|
+
CHW numpy array with one channel, containing the predicted class IDs.
|
|
161
162
|
"""
|
|
163
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
|
|
164
|
+
raise ValueError("the output for SegmentationTask must be a CHW tensor")
|
|
165
|
+
|
|
162
166
|
if self.prob_scales is not None:
|
|
163
167
|
raw_output = (
|
|
164
168
|
raw_output
|
|
@@ -166,7 +170,7 @@ class SegmentationTask(BasicTask):
|
|
|
166
170
|
self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
|
|
167
171
|
)[:, None, None]
|
|
168
172
|
)
|
|
169
|
-
classes = raw_output.argmax(dim=0).cpu().numpy()
|
|
173
|
+
classes = raw_output.argmax(dim=0).cpu().numpy()
|
|
170
174
|
return classes[None, :, :]
|
|
171
175
|
|
|
172
176
|
def visualize(
|
|
@@ -258,25 +262,36 @@ class SegmentationTask(BasicTask):
|
|
|
258
262
|
return MetricCollection(metrics)
|
|
259
263
|
|
|
260
264
|
|
|
261
|
-
class SegmentationHead(
|
|
265
|
+
class SegmentationHead(Predictor):
|
|
262
266
|
"""Head for segmentation task."""
|
|
263
267
|
|
|
264
268
|
def forward(
|
|
265
269
|
self,
|
|
266
|
-
|
|
267
|
-
|
|
270
|
+
intermediates: Any,
|
|
271
|
+
context: ModelContext,
|
|
268
272
|
targets: list[dict[str, Any]] | None = None,
|
|
269
|
-
) ->
|
|
273
|
+
) -> ModelOutput:
|
|
270
274
|
"""Compute the segmentation outputs from logits and targets.
|
|
271
275
|
|
|
272
276
|
Args:
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
277
|
+
intermediates: a FeatureMaps with a single feature map containing the
|
|
278
|
+
segmentation logits.
|
|
279
|
+
context: the model context
|
|
280
|
+
targets: list of target dicts, where each target dict must contain a key
|
|
281
|
+
"classes" containing the per-pixel class labels, along with "valid"
|
|
282
|
+
containing a mask indicating where the example is valid.
|
|
276
283
|
|
|
277
284
|
Returns:
|
|
278
285
|
tuple of outputs and loss dict
|
|
279
286
|
"""
|
|
287
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
288
|
+
raise ValueError("input to SegmentationHead must be a FeatureMaps")
|
|
289
|
+
if len(intermediates.feature_maps) != 1:
|
|
290
|
+
raise ValueError(
|
|
291
|
+
f"input to SegmentationHead must have one feature map, but got {len(intermediates.feature_maps)}"
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
logits = intermediates.feature_maps[0]
|
|
280
295
|
outputs = torch.nn.functional.softmax(logits, dim=1)
|
|
281
296
|
|
|
282
297
|
losses = {}
|
|
@@ -295,7 +310,10 @@ class SegmentationHead(torch.nn.Module):
|
|
|
295
310
|
# the summed mask loss be zero.
|
|
296
311
|
losses["cls"] = torch.sum(per_pixel_loss * mask)
|
|
297
312
|
|
|
298
|
-
return
|
|
313
|
+
return ModelOutput(
|
|
314
|
+
outputs=outputs,
|
|
315
|
+
loss_dict=losses,
|
|
316
|
+
)
|
|
299
317
|
|
|
300
318
|
|
|
301
319
|
class SegmentationMetric(Metric):
|
rslearn/train/tasks/task.py
CHANGED
|
@@ -7,6 +7,7 @@ import numpy.typing as npt
|
|
|
7
7
|
import torch
|
|
8
8
|
from torchmetrics import MetricCollection
|
|
9
9
|
|
|
10
|
+
from rslearn.train.model_context import SampleMetadata
|
|
10
11
|
from rslearn.utils import Feature
|
|
11
12
|
|
|
12
13
|
|
|
@@ -21,7 +22,7 @@ class Task:
|
|
|
21
22
|
def process_inputs(
|
|
22
23
|
self,
|
|
23
24
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
24
|
-
metadata:
|
|
25
|
+
metadata: SampleMetadata,
|
|
25
26
|
load_targets: bool = True,
|
|
26
27
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
27
28
|
"""Processes the data into targets.
|
|
@@ -38,7 +39,7 @@ class Task:
|
|
|
38
39
|
raise NotImplementedError
|
|
39
40
|
|
|
40
41
|
def process_output(
|
|
41
|
-
self, raw_output: Any, metadata:
|
|
42
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
42
43
|
) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
|
|
43
44
|
"""Processes an output into raster or vector data.
|
|
44
45
|
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Resize transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torchvision
|
|
7
|
+
from torchvision.transforms import InterpolationMode
|
|
8
|
+
|
|
9
|
+
from .transform import Transform
|
|
10
|
+
|
|
11
|
+
# Maps the user-facing interpolation names accepted by Resize to the
# corresponding torchvision InterpolationMode members.
INTERPOLATION_MODES = {
    "nearest": InterpolationMode.NEAREST,
    "nearest_exact": InterpolationMode.NEAREST_EXACT,
    "bilinear": InterpolationMode.BILINEAR,
    "bicubic": InterpolationMode.BICUBIC,
}


class Resize(Transform):
    """Resizes inputs to a target size."""

    def __init__(
        self,
        target_size: tuple[int, int],
        selectors: list[str] = [],
        interpolation: str = "nearest",
    ):
        """Initialize a resize transform.

        Args:
            target_size: the (height, width) to resize to.
            selectors: items to transform.
            interpolation: the interpolation mode to use for resizing.
                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".

        Raises:
            KeyError: if interpolation is not a key of INTERPOLATION_MODES.
        """
        super().__init__()
        self.target_size = target_size
        # Copy the list: the default argument is a single shared object, so
        # storing it by reference would let one instance's mutation leak into
        # every other instance constructed with the default.
        self.selectors = list(selectors)
        self.interpolation = INTERPOLATION_MODES[interpolation]

    def apply_resize(self, image: torch.Tensor) -> torch.Tensor:
        """Apply resizing on the specified image.

        If the image is 2D, it is unsqueezed to 3D and then squeezed
        back after resizing.

        Args:
            image: the image to transform.

        Returns:
            the resized image, with the same number of dimensions as the input.
        """
        is_2d = image.dim() == 2
        if is_2d:
            image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
        result = torchvision.transforms.functional.resize(
            image, self.target_size, self.interpolation
        )
        if is_2d:
            result = result.squeeze(0)  # (1, H, W) -> (H, W)
        return result

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply transform over the inputs and targets.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            transformed (input_dicts, target_dicts) tuple
        """
        # apply_fn is not defined in this class (presumably inherited from
        # Transform); it is handed the resize callback, both dicts, and the
        # configured selectors. Its return value is ignored here.
        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
        return input_dict, target_dict
|
rslearn/utils/geometry.py
CHANGED
|
@@ -116,6 +116,79 @@ class Projection:
|
|
|
116
116
|
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
class ResolutionFactor:
    """Multiplier for the resolution in a Projection.

    The multiplier is either a positive integer x, or the inverse of one (1/x).

    Factors greater than 1 make the grid finer: the projection units per pixel
    are divided by x and pixel bounds are multiplied by x (more pixels cover the
    same extent). Factors less than 1 make it coarser: the projection units per
    pixel are multiplied by x and pixel bounds are divided by x (fewer pixels).
    """

    def __init__(self, numerator: int = 1, denominator: int = 1):
        """Create a new ResolutionFactor.

        Args:
            numerator: the numerator of the fraction.
            denominator: the denominator of the fraction. If set, numerator must be 1.

        Raises:
            ValueError: if either value is not an integer, if both differ from
                1, or if either is less than 1.
        """
        # Validate types first so that non-integer input is reported as such
        # rather than tripping one of the value checks below.
        if not isinstance(numerator, int) or not isinstance(denominator, int):
            raise ValueError("numerator and denominator must be integers")
        if numerator != 1 and denominator != 1:
            raise ValueError("one of numerator or denominator must be 1")
        if numerator < 1 or denominator < 1:
            raise ValueError("numerator and denominator must be >= 1")
        self.numerator = numerator
        self.denominator = denominator

    def _scale_resolution(self, resolution: float) -> float:
        """Scale a single projection-units-per-pixel value by this factor."""
        if self.denominator > 1:
            return resolution * self.denominator
        if self.numerator == 1:
            # Identity factor: hand the value back untouched (same type, no
            # rounding).
            return resolution
        # True division: floor division would truncate non-integer results
        # (10 // 4 == 2) and round negative resolutions away from zero
        # (-10 // 4 == -3).
        return resolution / self.numerator

    def multiply_projection(self, projection: Projection) -> Projection:
        """Multiply the projection by this factor.

        Args:
            projection: the projection whose resolution should be scaled.

        Returns:
            a new Projection with the same CRS and with both the x and y
            resolution scaled by this factor.
        """
        return Projection(
            projection.crs,
            self._scale_resolution(projection.x_resolution),
            self._scale_resolution(projection.y_resolution),
        )

    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
        """Multiply the bounds by this factor.

        When coarsening, the width and height of the given bounds must be a multiple of
        the denominator.

        Args:
            bounds: the bounds to scale; indices 0/2 are differenced to get the
                width and indices 1/3 to get the height.

        Returns:
            the scaled bounds as a 4-tuple.

        Raises:
            ValueError: if coarsening and the width or height is not a multiple
                of the denominator.
        """
        if self.denominator > 1:
            # Verify the width and height are multiples of the denominator.
            # Otherwise the new width and height is not an integer.
            width = bounds[2] - bounds[0]
            height = bounds[3] - bounds[1]
            if width % self.denominator != 0 or height % self.denominator != 0:
                raise ValueError(
                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
                )
            # TODO: an offset could be introduced by bounds not being a multiple
            # of the denominator -> will need to decide how to handle that.
            return (
                bounds[0] // self.denominator,
                bounds[1] // self.denominator,
                bounds[2] // self.denominator,
                bounds[3] // self.denominator,
            )
        else:
            return (
                bounds[0] * self.numerator,
                bounds[1] * self.numerator,
                bounds[2] * self.numerator,
                bounds[3] * self.numerator,
            )
|
|
190
|
+
|
|
191
|
+
|
|
119
192
|
class STGeometry:
|
|
120
193
|
"""A spatiotemporal geometry.
|
|
121
194
|
|
rslearn/utils/jsonargparse.py
CHANGED
|
@@ -8,6 +8,7 @@ from rasterio.crs import CRS
|
|
|
8
8
|
from upath import UPath
|
|
9
9
|
|
|
10
10
|
from rslearn.config.dataset import LayerConfig
|
|
11
|
+
from rslearn.utils.geometry import ResolutionFactor
|
|
11
12
|
|
|
12
13
|
if TYPE_CHECKING:
|
|
13
14
|
from rslearn.data_sources.data_source import DataSourceContext
|
|
@@ -91,6 +92,68 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
|
|
|
91
92
|
)
|
|
92
93
|
|
|
93
94
|
|
|
95
|
+
def resolution_factor_serializer(v: ResolutionFactor) -> str:
    """Serialize ResolutionFactor for jsonargparse.

    Args:
        v: the ResolutionFactor object, or an object exposing an ``init_args``
            attribute that carries ``numerator`` and ``denominator``.

    Returns:
        the factor encoded as the string "numerator/denominator"
    """
    # When the value carries init_args (not yet instantiated), read the
    # fraction from there; otherwise read it from the object itself.
    source = getattr(v, "init_args", v)
    return f"{source.numerator}/{source.denominator}"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
    """Deserialize ResolutionFactor for jsonargparse.

    Accepted encodings:
        - an existing ResolutionFactor, returned unchanged;
        - an object with an ``init_args`` attribute, or a dict with an
          "init_args" key, holding ``numerator`` and/or ``denominator``
          (a missing one defaults to 1);
        - an int x, meaning x/1;
        - a string "x" or "x/y".

    Args:
        v: the encoded ResolutionFactor.

    Returns:
        the decoded ResolutionFactor object

    Raises:
        ValueError: if v is of an unsupported type, if a string has more than
            one "/" or non-integer parts, or if ResolutionFactor rejects the
            values.
    """
    # Handle already-instantiated ResolutionFactor
    if isinstance(v, ResolutionFactor):
        return v

    # Handle Namespace from class_path syntax (used during config save/validation).
    # Missing arguments fall back to 1, matching the dict branch below.
    if hasattr(v, "init_args"):
        init_args = v.init_args
        return ResolutionFactor(
            numerator=getattr(init_args, "numerator", 1),
            denominator=getattr(init_args, "denominator", 1),
        )

    # Handle dict from class_path syntax in YAML config
    if isinstance(v, dict) and "init_args" in v:
        init_args = v["init_args"]
        return ResolutionFactor(
            numerator=init_args.get("numerator", 1),
            denominator=init_args.get("denominator", 1),
        )

    if isinstance(v, int):
        return ResolutionFactor(numerator=v)

    if isinstance(v, str):
        parts = v.split("/")
        if len(parts) > 2:
            raise ValueError("expected resolution factor to be of the form x or 1/x")
        # int() raises ValueError itself for non-numeric parts.
        numerator = int(parts[0])
        denominator = int(parts[1]) if len(parts) == 2 else 1
        return ResolutionFactor(numerator=numerator, denominator=denominator)

    raise ValueError("expected resolution factor to be str or int")
|
|
155
|
+
|
|
156
|
+
|
|
94
157
|
def init_jsonargparse() -> None:
|
|
95
158
|
"""Initialize custom jsonargparse serializers."""
|
|
96
159
|
global INITIALIZED
|
|
@@ -100,6 +163,9 @@ def init_jsonargparse() -> None:
|
|
|
100
163
|
jsonargparse.typing.register_type(
|
|
101
164
|
datetime, datetime_serializer, datetime_deserializer
|
|
102
165
|
)
|
|
166
|
+
jsonargparse.typing.register_type(
|
|
167
|
+
ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
|
|
168
|
+
)
|
|
103
169
|
|
|
104
170
|
from rslearn.data_sources.data_source import DataSourceContext
|
|
105
171
|
|