rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""rslearn PredictionWriter implementation."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from typing import Any
|
|
@@ -12,11 +13,15 @@ from lightning.pytorch.callbacks import BasePredictionWriter
|
|
|
12
13
|
from upath import UPath
|
|
13
14
|
|
|
14
15
|
from rslearn.config import (
|
|
16
|
+
DatasetConfig,
|
|
15
17
|
LayerConfig,
|
|
16
18
|
LayerType,
|
|
19
|
+
StorageConfig,
|
|
17
20
|
)
|
|
18
|
-
from rslearn.dataset import
|
|
21
|
+
from rslearn.dataset import Window
|
|
22
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
19
23
|
from rslearn.log_utils import get_logger
|
|
24
|
+
from rslearn.train.model_context import SampleMetadata
|
|
20
25
|
from rslearn.utils.array import copy_spatial_array
|
|
21
26
|
from rslearn.utils.feature import Feature
|
|
22
27
|
from rslearn.utils.geometry import PixelBounds
|
|
@@ -27,6 +32,7 @@ from rslearn.utils.raster_format import (
|
|
|
27
32
|
from rslearn.utils.vector_format import VectorFormat
|
|
28
33
|
|
|
29
34
|
from .lightning_module import RslearnLightningModule
|
|
35
|
+
from .model_context import ModelOutput
|
|
30
36
|
from .tasks.task import Task
|
|
31
37
|
|
|
32
38
|
logger = get_logger(__name__)
|
|
@@ -43,12 +49,18 @@ class PendingPatchOutput:
|
|
|
43
49
|
class PatchPredictionMerger:
|
|
44
50
|
"""Base class for merging predictions from multiple patches."""
|
|
45
51
|
|
|
46
|
-
def merge(
|
|
52
|
+
def merge(
|
|
53
|
+
self,
|
|
54
|
+
window: Window,
|
|
55
|
+
outputs: Sequence[PendingPatchOutput],
|
|
56
|
+
layer_config: LayerConfig,
|
|
57
|
+
) -> Any:
|
|
47
58
|
"""Merge the outputs.
|
|
48
59
|
|
|
49
60
|
Args:
|
|
50
61
|
window: the window we are merging the outputs for.
|
|
51
62
|
outputs: the outputs to process.
|
|
63
|
+
layer_config: the output layer configuration.
|
|
52
64
|
|
|
53
65
|
Returns:
|
|
54
66
|
the merged outputs.
|
|
@@ -60,7 +72,10 @@ class VectorMerger(PatchPredictionMerger):
|
|
|
60
72
|
"""Merger for vector data that simply concatenates the features."""
|
|
61
73
|
|
|
62
74
|
def merge(
|
|
63
|
-
self,
|
|
75
|
+
self,
|
|
76
|
+
window: Window,
|
|
77
|
+
outputs: Sequence[PendingPatchOutput],
|
|
78
|
+
layer_config: LayerConfig,
|
|
64
79
|
) -> list[Feature]:
|
|
65
80
|
"""Concatenate the vector features."""
|
|
66
81
|
return [feat for output in outputs for feat in output.output]
|
|
@@ -83,18 +98,20 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
83
98
|
self.downsample_factor = downsample_factor
|
|
84
99
|
|
|
85
100
|
def merge(
|
|
86
|
-
self,
|
|
101
|
+
self,
|
|
102
|
+
window: Window,
|
|
103
|
+
outputs: Sequence[PendingPatchOutput],
|
|
104
|
+
layer_config: LayerConfig,
|
|
87
105
|
) -> npt.NDArray:
|
|
88
106
|
"""Merge the raster outputs."""
|
|
89
107
|
num_channels = outputs[0].output.shape[0]
|
|
90
|
-
dtype = outputs[0].output.dtype
|
|
91
108
|
merged_image = np.zeros(
|
|
92
109
|
(
|
|
93
110
|
num_channels,
|
|
94
111
|
(window.bounds[3] - window.bounds[1]) // self.downsample_factor,
|
|
95
112
|
(window.bounds[2] - window.bounds[0]) // self.downsample_factor,
|
|
96
113
|
),
|
|
97
|
-
dtype=dtype,
|
|
114
|
+
dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
|
|
98
115
|
)
|
|
99
116
|
|
|
100
117
|
# Ensure the outputs are sorted by height then width.
|
|
@@ -148,6 +165,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
148
165
|
merger: PatchPredictionMerger | None = None,
|
|
149
166
|
output_path: str | Path | None = None,
|
|
150
167
|
layer_config: LayerConfig | None = None,
|
|
168
|
+
storage_config: StorageConfig | None = None,
|
|
151
169
|
):
|
|
152
170
|
"""Create a new RslearnWriter.
|
|
153
171
|
|
|
@@ -163,28 +181,24 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
163
181
|
layer_config: optional layer configuration. If provided, this config will be
|
|
164
182
|
used instead of reading from the dataset config, allowing usage without
|
|
165
183
|
requiring dataset config at the output path.
|
|
184
|
+
storage_config: optional storage configuration, needed similar to layer_config
|
|
185
|
+
if there is no dataset config.
|
|
166
186
|
"""
|
|
167
187
|
super().__init__(write_interval="batch")
|
|
168
188
|
self.output_layer = output_layer
|
|
169
189
|
self.selector = selector or []
|
|
170
|
-
|
|
171
|
-
|
|
190
|
+
ds_upath = UPath(path, **path_options or {})
|
|
191
|
+
output_upath = (
|
|
172
192
|
UPath(output_path, **path_options or {})
|
|
173
193
|
if output_path is not None
|
|
174
|
-
else
|
|
194
|
+
else ds_upath
|
|
175
195
|
)
|
|
176
196
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
dataset = Dataset(self.path)
|
|
183
|
-
if self.output_layer not in dataset.layers:
|
|
184
|
-
raise KeyError(
|
|
185
|
-
f"Output layer '{self.output_layer}' not found in dataset layers."
|
|
186
|
-
)
|
|
187
|
-
self.layer_config = dataset.layers[self.output_layer]
|
|
197
|
+
self.layer_config, self.dataset_storage = (
|
|
198
|
+
self._get_layer_config_and_dataset_storage(
|
|
199
|
+
ds_upath, output_upath, layer_config, storage_config
|
|
200
|
+
)
|
|
201
|
+
)
|
|
188
202
|
|
|
189
203
|
self.format: RasterFormat | VectorFormat
|
|
190
204
|
if self.layer_config.type == LayerType.RASTER:
|
|
@@ -207,11 +221,73 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
207
221
|
# patches of each window need to be reconstituted.
|
|
208
222
|
self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
|
|
209
223
|
|
|
224
|
+
def _get_layer_config_and_dataset_storage(
|
|
225
|
+
self,
|
|
226
|
+
ds_upath: UPath,
|
|
227
|
+
output_upath: UPath,
|
|
228
|
+
layer_config: LayerConfig | None,
|
|
229
|
+
storage_config: StorageConfig | None,
|
|
230
|
+
) -> tuple[LayerConfig, WindowStorage]:
|
|
231
|
+
"""Get the layer config and dataset storage to use.
|
|
232
|
+
|
|
233
|
+
This is a helper function for the init method.
|
|
234
|
+
|
|
235
|
+
If layer_config is set, we use that. If storage_config is set, we use it to
|
|
236
|
+
instantiate a WindowStorage using the output_upath.
|
|
237
|
+
|
|
238
|
+
If one of them is not set, we load the config from the ds_upath. Otherwise, we
|
|
239
|
+
avoid reading the dataset config; this way, RslearnWriter can be used with
|
|
240
|
+
output directories that do not contain the dataset config, as long as
|
|
241
|
+
layer_config and storage_config are both provided.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
ds_upath: the dataset path, where a dataset config can be loaded from if
|
|
245
|
+
layer_config or storage_config is not provided.
|
|
246
|
+
output_upath: the output directory, which could be different from the
|
|
247
|
+
dataset path.
|
|
248
|
+
layer_config: optional LayerConfig to provide.
|
|
249
|
+
storage_config: optional StorageConfig to provide.
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
a tuple (layer_config, dataset_storage)
|
|
253
|
+
"""
|
|
254
|
+
dataset_storage: WindowStorage | None = None
|
|
255
|
+
|
|
256
|
+
# Instantiate the WindowStorage from the storage_config if provided.
|
|
257
|
+
if storage_config:
|
|
258
|
+
dataset_storage = (
|
|
259
|
+
storage_config.instantiate_window_storage_factory().get_storage(
|
|
260
|
+
output_upath
|
|
261
|
+
)
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
if not layer_config or not dataset_storage:
|
|
265
|
+
# Need to load dataset config since one of LayerConfig/StorageConfig is missing.
|
|
266
|
+
# We use DatasetConfig.model_validate instead of initializing the Dataset
|
|
267
|
+
# because we want to get a WindowStorage that has the dataset path set to
|
|
268
|
+
# output_upath instead of ds_upath.
|
|
269
|
+
with (ds_upath / "config.json").open() as f:
|
|
270
|
+
dataset_config = DatasetConfig.model_validate(json.load(f))
|
|
271
|
+
|
|
272
|
+
if not layer_config:
|
|
273
|
+
if self.output_layer not in dataset_config.layers:
|
|
274
|
+
raise KeyError(
|
|
275
|
+
f"Output layer '{self.output_layer}' not found in dataset layers."
|
|
276
|
+
)
|
|
277
|
+
layer_config = dataset_config.layers[self.output_layer]
|
|
278
|
+
|
|
279
|
+
if not dataset_storage:
|
|
280
|
+
dataset_storage = dataset_config.storage.instantiate_window_storage_factory().get_storage(
|
|
281
|
+
output_upath
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
return (layer_config, dataset_storage)
|
|
285
|
+
|
|
210
286
|
def write_on_batch_end(
|
|
211
287
|
self,
|
|
212
288
|
trainer: Trainer,
|
|
213
289
|
pl_module: LightningModule,
|
|
214
|
-
prediction:
|
|
290
|
+
prediction: ModelOutput,
|
|
215
291
|
batch_indices: Sequence[int] | None,
|
|
216
292
|
batch: tuple[list, list, list],
|
|
217
293
|
batch_idx: int,
|
|
@@ -232,13 +308,13 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
232
308
|
assert isinstance(pl_module, RslearnLightningModule)
|
|
233
309
|
task = pl_module.task
|
|
234
310
|
_, _, metadatas = batch
|
|
235
|
-
self.process_output_batch(task, prediction
|
|
311
|
+
self.process_output_batch(task, prediction.outputs, metadatas)
|
|
236
312
|
|
|
237
313
|
def process_output_batch(
|
|
238
314
|
self,
|
|
239
315
|
task: Task,
|
|
240
|
-
prediction:
|
|
241
|
-
metadatas:
|
|
316
|
+
prediction: Iterable[Any],
|
|
317
|
+
metadatas: Iterable[SampleMetadata],
|
|
242
318
|
) -> None:
|
|
243
319
|
"""Write a prediction batch with simplified API.
|
|
244
320
|
|
|
@@ -263,25 +339,19 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
263
339
|
for k in self.selector:
|
|
264
340
|
output = output[k]
|
|
265
341
|
|
|
266
|
-
# Use custom output_path if provided, otherwise use dataset path
|
|
267
|
-
window_base_path = (
|
|
268
|
-
self.output_path if self.output_path is not None else self.path
|
|
269
|
-
)
|
|
270
342
|
window = Window(
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
bounds=metadata["window_bounds"],
|
|
278
|
-
time_range=metadata["time_range"],
|
|
343
|
+
storage=self.dataset_storage,
|
|
344
|
+
group=metadata.window_group,
|
|
345
|
+
name=metadata.window_name,
|
|
346
|
+
projection=metadata.projection,
|
|
347
|
+
bounds=metadata.window_bounds,
|
|
348
|
+
time_range=metadata.time_range,
|
|
279
349
|
)
|
|
280
350
|
self.process_output(
|
|
281
351
|
window,
|
|
282
|
-
metadata
|
|
283
|
-
metadata
|
|
284
|
-
metadata
|
|
352
|
+
metadata.patch_idx,
|
|
353
|
+
metadata.num_patches_in_window,
|
|
354
|
+
metadata.patch_bounds,
|
|
285
355
|
output,
|
|
286
356
|
)
|
|
287
357
|
|
|
@@ -320,7 +390,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
320
390
|
|
|
321
391
|
# Merge outputs from overlapped patches if merger is set.
|
|
322
392
|
logger.debug(f"Merging and writing for window {window.name}")
|
|
323
|
-
merged_output = self.merger.merge(window, pending_output)
|
|
393
|
+
merged_output = self.merger.merge(window, pending_output, self.layer_config)
|
|
324
394
|
|
|
325
395
|
if self.layer_config.type == LayerType.RASTER:
|
|
326
396
|
raster_dir = window.get_raster_dir(
|
rslearn/train/scheduler.py
CHANGED
|
@@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import (
|
|
|
8
8
|
CosineAnnealingLR,
|
|
9
9
|
CosineAnnealingWarmRestarts,
|
|
10
10
|
LRScheduler,
|
|
11
|
+
MultiStepLR,
|
|
11
12
|
ReduceLROnPlateau,
|
|
12
13
|
)
|
|
13
14
|
|
|
@@ -50,6 +51,20 @@ class PlateauScheduler(SchedulerFactory):
|
|
|
50
51
|
return ReduceLROnPlateau(optimizer, **self.get_kwargs())
|
|
51
52
|
|
|
52
53
|
|
|
54
|
+
@dataclass
|
|
55
|
+
class MultiStepScheduler(SchedulerFactory):
|
|
56
|
+
"""Step learning rate scheduler."""
|
|
57
|
+
|
|
58
|
+
milestones: list[int]
|
|
59
|
+
gamma: float | None = None
|
|
60
|
+
last_epoch: int | None = None
|
|
61
|
+
|
|
62
|
+
def build(self, optimizer: Optimizer) -> LRScheduler:
|
|
63
|
+
"""Build the ReduceLROnPlateau scheduler."""
|
|
64
|
+
super().build(optimizer)
|
|
65
|
+
return MultiStepLR(optimizer, **self.get_kwargs())
|
|
66
|
+
|
|
67
|
+
|
|
53
68
|
@dataclass
|
|
54
69
|
class CosineAnnealingScheduler(SchedulerFactory):
|
|
55
70
|
"""Cosine annealing learning rate scheduler."""
|
|
@@ -15,6 +15,8 @@ from torchmetrics.classification import (
|
|
|
15
15
|
MulticlassRecall,
|
|
16
16
|
)
|
|
17
17
|
|
|
18
|
+
from rslearn.models.component import FeatureVector, Predictor
|
|
19
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
18
20
|
from rslearn.utils import Feature, STGeometry
|
|
19
21
|
|
|
20
22
|
from .task import BasicTask
|
|
@@ -98,7 +100,7 @@ class ClassificationTask(BasicTask):
|
|
|
98
100
|
def process_inputs(
|
|
99
101
|
self,
|
|
100
102
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
101
|
-
metadata:
|
|
103
|
+
metadata: SampleMetadata,
|
|
102
104
|
load_targets: bool = True,
|
|
103
105
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
104
106
|
"""Processes the data into targets.
|
|
@@ -154,17 +156,25 @@ class ClassificationTask(BasicTask):
|
|
|
154
156
|
}
|
|
155
157
|
|
|
156
158
|
def process_output(
|
|
157
|
-
self, raw_output: Any, metadata:
|
|
158
|
-
) ->
|
|
159
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
160
|
+
) -> list[Feature]:
|
|
159
161
|
"""Processes an output into raster or vector data.
|
|
160
162
|
|
|
161
163
|
Args:
|
|
162
|
-
raw_output: the output from prediction head
|
|
164
|
+
raw_output: the output from prediction head, which must be a tensor
|
|
165
|
+
containing output probabilities (one dimension).
|
|
163
166
|
metadata: metadata about the patch being read
|
|
164
167
|
|
|
165
168
|
Returns:
|
|
166
|
-
|
|
169
|
+
a list with one Feature corresponding to the input patch extent with a
|
|
170
|
+
property name containing the predicted class. It will have another
|
|
171
|
+
property containing the probabilities if prob_property was set.
|
|
167
172
|
"""
|
|
173
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 1:
|
|
174
|
+
raise ValueError(
|
|
175
|
+
"expected output for ClassificationTask to be a Tensor with one dimension"
|
|
176
|
+
)
|
|
177
|
+
|
|
168
178
|
probs = raw_output.cpu().numpy()
|
|
169
179
|
if len(self.classes) == 2 and self.positive_class_threshold != 0.5:
|
|
170
180
|
positive_class_prob = probs[self.positive_class_id]
|
|
@@ -184,8 +194,8 @@ class ClassificationTask(BasicTask):
|
|
|
184
194
|
|
|
185
195
|
feature = Feature(
|
|
186
196
|
STGeometry(
|
|
187
|
-
metadata
|
|
188
|
-
shapely.Point(metadata[
|
|
197
|
+
metadata.projection,
|
|
198
|
+
shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
|
|
189
199
|
None,
|
|
190
200
|
),
|
|
191
201
|
{
|
|
@@ -265,25 +275,31 @@ class ClassificationTask(BasicTask):
|
|
|
265
275
|
return MetricCollection(metrics)
|
|
266
276
|
|
|
267
277
|
|
|
268
|
-
class ClassificationHead(
|
|
278
|
+
class ClassificationHead(Predictor):
|
|
269
279
|
"""Head for classification task."""
|
|
270
280
|
|
|
271
281
|
def forward(
|
|
272
282
|
self,
|
|
273
|
-
|
|
274
|
-
|
|
283
|
+
intermediates: Any,
|
|
284
|
+
context: ModelContext,
|
|
275
285
|
targets: list[dict[str, Any]] | None = None,
|
|
276
|
-
) ->
|
|
286
|
+
) -> ModelOutput:
|
|
277
287
|
"""Compute the classification outputs and loss from logits and targets.
|
|
278
288
|
|
|
279
289
|
Args:
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
290
|
+
intermediates: output from the previous model component, it should be a
|
|
291
|
+
FeatureVector with a tensor that is (BatchSize, NumClasses) in shape.
|
|
292
|
+
context: the model context.
|
|
293
|
+
targets: must contain "class" key that stores the class label, along with
|
|
294
|
+
"valid" key indicating whether the label is valid for each example.
|
|
283
295
|
|
|
284
296
|
Returns:
|
|
285
297
|
tuple of outputs and loss dict
|
|
286
298
|
"""
|
|
299
|
+
if not isinstance(intermediates, FeatureVector):
|
|
300
|
+
raise ValueError("the input to ClassificationHead must be a FeatureVector")
|
|
301
|
+
|
|
302
|
+
logits = intermediates.feature_vector
|
|
287
303
|
outputs = torch.nn.functional.softmax(logits, dim=1)
|
|
288
304
|
|
|
289
305
|
losses = {}
|
|
@@ -298,7 +314,10 @@ class ClassificationHead(torch.nn.Module):
|
|
|
298
314
|
)
|
|
299
315
|
losses["cls"] = torch.mean(loss)
|
|
300
316
|
|
|
301
|
-
return
|
|
317
|
+
return ModelOutput(
|
|
318
|
+
outputs=outputs,
|
|
319
|
+
loss_dict=losses,
|
|
320
|
+
)
|
|
302
321
|
|
|
303
322
|
|
|
304
323
|
class ClassificationMetric(Metric):
|
rslearn/train/tasks/detection.py
CHANGED
|
@@ -12,6 +12,7 @@ import torchmetrics.classification
|
|
|
12
12
|
import torchvision
|
|
13
13
|
from torchmetrics import Metric, MetricCollection
|
|
14
14
|
|
|
15
|
+
from rslearn.train.model_context import SampleMetadata
|
|
15
16
|
from rslearn.utils import Feature, STGeometry
|
|
16
17
|
|
|
17
18
|
from .task import BasicTask
|
|
@@ -127,7 +128,7 @@ class DetectionTask(BasicTask):
|
|
|
127
128
|
def process_inputs(
|
|
128
129
|
self,
|
|
129
130
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
130
|
-
metadata:
|
|
131
|
+
metadata: SampleMetadata,
|
|
131
132
|
load_targets: bool = True,
|
|
132
133
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
133
134
|
"""Processes the data into targets.
|
|
@@ -144,6 +145,8 @@ class DetectionTask(BasicTask):
|
|
|
144
145
|
if not load_targets:
|
|
145
146
|
return {}, {}
|
|
146
147
|
|
|
148
|
+
bounds = metadata.patch_bounds
|
|
149
|
+
|
|
147
150
|
boxes = []
|
|
148
151
|
class_labels = []
|
|
149
152
|
valid = 1
|
|
@@ -186,39 +189,33 @@ class DetectionTask(BasicTask):
|
|
|
186
189
|
else:
|
|
187
190
|
box = [int(val) for val in shp.bounds]
|
|
188
191
|
|
|
189
|
-
if box[0] >=
|
|
192
|
+
if box[0] >= bounds[2] or box[2] <= bounds[0]:
|
|
190
193
|
continue
|
|
191
|
-
if box[1] >=
|
|
194
|
+
if box[1] >= bounds[3] or box[3] <= bounds[1]:
|
|
192
195
|
continue
|
|
193
196
|
|
|
194
197
|
if self.exclude_by_center:
|
|
195
198
|
center_col = (box[0] + box[2]) // 2
|
|
196
199
|
center_row = (box[1] + box[3]) // 2
|
|
197
|
-
if
|
|
198
|
-
center_col <= metadata["bounds"][0]
|
|
199
|
-
or center_col >= metadata["bounds"][2]
|
|
200
|
-
):
|
|
200
|
+
if center_col <= bounds[0] or center_col >= bounds[2]:
|
|
201
201
|
continue
|
|
202
|
-
if
|
|
203
|
-
center_row <= metadata["bounds"][1]
|
|
204
|
-
or center_row >= metadata["bounds"][3]
|
|
205
|
-
):
|
|
202
|
+
if center_row <= bounds[1] or center_row >= bounds[3]:
|
|
206
203
|
continue
|
|
207
204
|
|
|
208
205
|
if self.clip_boxes:
|
|
209
206
|
box = [
|
|
210
|
-
np.clip(box[0],
|
|
211
|
-
np.clip(box[1],
|
|
212
|
-
np.clip(box[2],
|
|
213
|
-
np.clip(box[3],
|
|
207
|
+
np.clip(box[0], bounds[0], bounds[2]),
|
|
208
|
+
np.clip(box[1], bounds[1], bounds[3]),
|
|
209
|
+
np.clip(box[2], bounds[0], bounds[2]),
|
|
210
|
+
np.clip(box[3], bounds[1], bounds[3]),
|
|
214
211
|
]
|
|
215
212
|
|
|
216
213
|
# Convert to relative coordinates.
|
|
217
214
|
box = [
|
|
218
|
-
box[0] -
|
|
219
|
-
box[1] -
|
|
220
|
-
box[2] -
|
|
221
|
-
box[3] -
|
|
215
|
+
box[0] - bounds[0],
|
|
216
|
+
box[1] - bounds[1],
|
|
217
|
+
box[2] - bounds[0],
|
|
218
|
+
box[3] - bounds[1],
|
|
222
219
|
]
|
|
223
220
|
|
|
224
221
|
boxes.append(box)
|
|
@@ -238,16 +235,12 @@ class DetectionTask(BasicTask):
|
|
|
238
235
|
"valid": torch.tensor(valid, dtype=torch.int32),
|
|
239
236
|
"boxes": boxes,
|
|
240
237
|
"labels": class_labels,
|
|
241
|
-
"width": torch.tensor(
|
|
242
|
-
|
|
243
|
-
),
|
|
244
|
-
"height": torch.tensor(
|
|
245
|
-
metadata["bounds"][3] - metadata["bounds"][1], dtype=torch.float32
|
|
246
|
-
),
|
|
238
|
+
"width": torch.tensor(bounds[2] - bounds[0], dtype=torch.float32),
|
|
239
|
+
"height": torch.tensor(bounds[3] - bounds[1], dtype=torch.float32),
|
|
247
240
|
}
|
|
248
241
|
|
|
249
242
|
def process_output(
|
|
250
|
-
self, raw_output: Any, metadata:
|
|
243
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
251
244
|
) -> npt.NDArray[Any] | list[Feature]:
|
|
252
245
|
"""Processes an output into raster or vector data.
|
|
253
246
|
|
|
@@ -267,12 +260,12 @@ class DetectionTask(BasicTask):
|
|
|
267
260
|
features = []
|
|
268
261
|
for box, class_id, score in zip(boxes, class_ids, scores):
|
|
269
262
|
shp = shapely.box(
|
|
270
|
-
metadata[
|
|
271
|
-
metadata[
|
|
272
|
-
metadata[
|
|
273
|
-
metadata[
|
|
263
|
+
metadata.patch_bounds[0] + float(box[0]),
|
|
264
|
+
metadata.patch_bounds[1] + float(box[1]),
|
|
265
|
+
metadata.patch_bounds[0] + float(box[2]),
|
|
266
|
+
metadata.patch_bounds[1] + float(box[3]),
|
|
274
267
|
)
|
|
275
|
-
geom = STGeometry(metadata
|
|
268
|
+
geom = STGeometry(metadata.projection, shp, None)
|
|
276
269
|
properties: dict[str, Any] = {
|
|
277
270
|
"score": float(score),
|
|
278
271
|
}
|
rslearn/train/tasks/embedding.py
CHANGED
|
@@ -6,6 +6,8 @@ import numpy.typing as npt
|
|
|
6
6
|
import torch
|
|
7
7
|
from torchmetrics import MetricCollection
|
|
8
8
|
|
|
9
|
+
from rslearn.models.component import FeatureMaps
|
|
10
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
9
11
|
from rslearn.utils import Feature
|
|
10
12
|
|
|
11
13
|
from .task import Task
|
|
@@ -21,7 +23,7 @@ class EmbeddingTask(Task):
|
|
|
21
23
|
def process_inputs(
|
|
22
24
|
self,
|
|
23
25
|
raw_inputs: dict[str, torch.Tensor],
|
|
24
|
-
metadata:
|
|
26
|
+
metadata: SampleMetadata,
|
|
25
27
|
load_targets: bool = True,
|
|
26
28
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
27
29
|
"""Processes the data into targets.
|
|
@@ -38,17 +40,22 @@ class EmbeddingTask(Task):
|
|
|
38
40
|
return {}, {}
|
|
39
41
|
|
|
40
42
|
def process_output(
|
|
41
|
-
self, raw_output: Any, metadata:
|
|
43
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
42
44
|
) -> npt.NDArray[Any] | list[Feature]:
|
|
43
45
|
"""Processes an output into raster or vector data.
|
|
44
46
|
|
|
45
47
|
Args:
|
|
46
|
-
raw_output: the output from prediction head.
|
|
48
|
+
raw_output: the output from prediction head, which must be a CxHxW tensor.
|
|
47
49
|
metadata: metadata about the patch being read
|
|
48
50
|
|
|
49
51
|
Returns:
|
|
50
52
|
either raster or vector data.
|
|
51
53
|
"""
|
|
54
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
|
|
55
|
+
raise ValueError(
|
|
56
|
+
"output for EmbeddingTask must be a tensor with three dimensions"
|
|
57
|
+
)
|
|
58
|
+
|
|
52
59
|
# Just convert the raw output to numpy array that can be saved to GeoTIFF.
|
|
53
60
|
return raw_output.cpu().numpy()
|
|
54
61
|
|
|
@@ -76,41 +83,38 @@ class EmbeddingTask(Task):
|
|
|
76
83
|
return MetricCollection({})
|
|
77
84
|
|
|
78
85
|
|
|
79
|
-
class EmbeddingHead
|
|
86
|
+
class EmbeddingHead:
|
|
80
87
|
"""Head for embedding task.
|
|
81
88
|
|
|
82
|
-
|
|
83
|
-
returns a dummy loss.
|
|
89
|
+
It just adds a dummy loss to act as a Predictor.
|
|
84
90
|
"""
|
|
85
91
|
|
|
86
|
-
def __init__(self, feature_map_index: int | None = 0):
|
|
87
|
-
"""Create a new EmbeddingHead.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
feature_map_index: the index of the feature map to choose from the input
|
|
91
|
-
list of multi-scale feature maps (default 0). If the input is already
|
|
92
|
-
a single feature map, then set to None.
|
|
93
|
-
"""
|
|
94
|
-
super().__init__()
|
|
95
|
-
self.feature_map_index = feature_map_index
|
|
96
|
-
|
|
97
92
|
def forward(
|
|
98
93
|
self,
|
|
99
|
-
|
|
100
|
-
|
|
94
|
+
intermediates: Any,
|
|
95
|
+
context: ModelContext,
|
|
101
96
|
targets: list[dict[str, Any]] | None = None,
|
|
102
|
-
) ->
|
|
103
|
-
"""
|
|
97
|
+
) -> ModelOutput:
|
|
98
|
+
"""Return the feature map along with a dummy loss.
|
|
104
99
|
|
|
105
100
|
Args:
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
101
|
+
intermediates: output from the previous model component, which must be a
|
|
102
|
+
FeatureMaps consisting of a single feature map.
|
|
103
|
+
context: the model context.
|
|
104
|
+
targets: the targets (ignored).
|
|
109
105
|
|
|
110
106
|
Returns:
|
|
111
|
-
|
|
107
|
+
model output with the feature map that was input to this component along
|
|
108
|
+
with a dummy loss.
|
|
112
109
|
"""
|
|
113
|
-
if
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
110
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
111
|
+
raise ValueError("input to EmbeddingHead must be a FeatureMaps")
|
|
112
|
+
if len(intermediates.feature_maps) != 1:
|
|
113
|
+
raise ValueError(
|
|
114
|
+
f"input to EmbeddingHead must have one feature map, but got {len(intermediates.feature_maps)}"
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
return ModelOutput(
|
|
118
|
+
outputs=intermediates.feature_maps[0],
|
|
119
|
+
loss_dict={"loss": 0},
|
|
120
|
+
)
|