rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +49 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +18 -12
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
|
@@ -12,6 +12,7 @@ from upath import UPath
|
|
|
12
12
|
|
|
13
13
|
from rslearn.log_utils import get_logger
|
|
14
14
|
|
|
15
|
+
from .model_context import ModelContext, ModelOutput
|
|
15
16
|
from .optimizer import AdamW, OptimizerFactory
|
|
16
17
|
from .scheduler import PlateauScheduler, SchedulerFactory
|
|
17
18
|
from .tasks import Task
|
|
@@ -231,12 +232,16 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
231
232
|
Returns:
|
|
232
233
|
The loss tensor.
|
|
233
234
|
"""
|
|
234
|
-
inputs, targets,
|
|
235
|
+
inputs, targets, metadatas = batch
|
|
236
|
+
context = ModelContext(
|
|
237
|
+
inputs=inputs,
|
|
238
|
+
metadatas=metadatas,
|
|
239
|
+
)
|
|
235
240
|
batch_size = len(inputs)
|
|
236
|
-
model_outputs = self(
|
|
237
|
-
self.on_train_forward(
|
|
241
|
+
model_outputs = self(context, targets)
|
|
242
|
+
self.on_train_forward(context, targets, model_outputs)
|
|
238
243
|
|
|
239
|
-
loss_dict = model_outputs
|
|
244
|
+
loss_dict = model_outputs.loss_dict
|
|
240
245
|
train_loss = sum(loss_dict.values())
|
|
241
246
|
self.log_dict(
|
|
242
247
|
{"train_" + k: v for k, v in loss_dict.items()},
|
|
@@ -266,13 +271,17 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
266
271
|
batch_idx: Integer displaying index of this batch.
|
|
267
272
|
dataloader_idx: Index of the current dataloader.
|
|
268
273
|
"""
|
|
269
|
-
inputs, targets,
|
|
274
|
+
inputs, targets, metadatas = batch
|
|
275
|
+
context = ModelContext(
|
|
276
|
+
inputs=inputs,
|
|
277
|
+
metadatas=metadatas,
|
|
278
|
+
)
|
|
270
279
|
batch_size = len(inputs)
|
|
271
|
-
model_outputs = self(
|
|
272
|
-
self.on_val_forward(
|
|
280
|
+
model_outputs = self(context, targets)
|
|
281
|
+
self.on_val_forward(context, targets, model_outputs)
|
|
273
282
|
|
|
274
|
-
loss_dict = model_outputs
|
|
275
|
-
outputs = model_outputs
|
|
283
|
+
loss_dict = model_outputs.loss_dict
|
|
284
|
+
outputs = model_outputs.outputs
|
|
276
285
|
val_loss = sum(loss_dict.values())
|
|
277
286
|
self.log_dict(
|
|
278
287
|
{"val_" + k: v for k, v in loss_dict.items()},
|
|
@@ -304,12 +313,16 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
304
313
|
dataloader_idx: Index of the current dataloader.
|
|
305
314
|
"""
|
|
306
315
|
inputs, targets, metadatas = batch
|
|
316
|
+
context = ModelContext(
|
|
317
|
+
inputs=inputs,
|
|
318
|
+
metadatas=metadatas,
|
|
319
|
+
)
|
|
307
320
|
batch_size = len(inputs)
|
|
308
|
-
model_outputs = self(
|
|
309
|
-
self.on_test_forward(
|
|
321
|
+
model_outputs = self(context, targets)
|
|
322
|
+
self.on_test_forward(context, targets, model_outputs)
|
|
310
323
|
|
|
311
|
-
loss_dict = model_outputs
|
|
312
|
-
outputs = model_outputs
|
|
324
|
+
loss_dict = model_outputs.loss_dict
|
|
325
|
+
outputs = model_outputs.outputs
|
|
313
326
|
test_loss = sum(loss_dict.values())
|
|
314
327
|
self.log_dict(
|
|
315
328
|
{"test_" + k: v for k, v in loss_dict.items()},
|
|
@@ -345,7 +358,7 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
345
358
|
|
|
346
359
|
def predict_step(
|
|
347
360
|
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
|
|
348
|
-
) ->
|
|
361
|
+
) -> ModelOutput:
|
|
349
362
|
"""Compute the predicted class probabilities.
|
|
350
363
|
|
|
351
364
|
Args:
|
|
@@ -356,63 +369,69 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
356
369
|
Returns:
|
|
357
370
|
Output predicted probabilities.
|
|
358
371
|
"""
|
|
359
|
-
inputs, _,
|
|
360
|
-
|
|
372
|
+
inputs, _, metadatas = batch
|
|
373
|
+
context = ModelContext(
|
|
374
|
+
inputs=inputs,
|
|
375
|
+
metadatas=metadatas,
|
|
376
|
+
)
|
|
377
|
+
model_outputs = self(context)
|
|
361
378
|
return model_outputs
|
|
362
379
|
|
|
363
|
-
def forward(
|
|
380
|
+
def forward(
|
|
381
|
+
self, context: ModelContext, targets: list[dict[str, Any]] | None = None
|
|
382
|
+
) -> ModelOutput:
|
|
364
383
|
"""Forward pass of the model.
|
|
365
384
|
|
|
366
385
|
Args:
|
|
367
|
-
|
|
368
|
-
|
|
386
|
+
context: the model context.
|
|
387
|
+
targets: the target dicts.
|
|
369
388
|
|
|
370
389
|
Returns:
|
|
371
390
|
Output of the model.
|
|
372
391
|
"""
|
|
373
|
-
return self.model(
|
|
392
|
+
return self.model(context, targets)
|
|
374
393
|
|
|
375
394
|
def on_train_forward(
|
|
376
395
|
self,
|
|
377
|
-
|
|
396
|
+
context: ModelContext,
|
|
378
397
|
targets: list[dict[str, Any]],
|
|
379
|
-
model_outputs:
|
|
398
|
+
model_outputs: ModelOutput,
|
|
380
399
|
) -> None:
|
|
381
400
|
"""Hook to run after the forward pass of the model during training.
|
|
382
401
|
|
|
383
402
|
Args:
|
|
384
|
-
|
|
403
|
+
context: The model context.
|
|
385
404
|
targets: The target batch.
|
|
386
|
-
model_outputs: The output of the model
|
|
405
|
+
model_outputs: The output of the model.
|
|
387
406
|
"""
|
|
388
407
|
pass
|
|
389
408
|
|
|
390
409
|
def on_val_forward(
|
|
391
410
|
self,
|
|
392
|
-
|
|
411
|
+
context: ModelContext,
|
|
393
412
|
targets: list[dict[str, Any]],
|
|
394
|
-
model_outputs:
|
|
413
|
+
model_outputs: ModelOutput,
|
|
395
414
|
) -> None:
|
|
396
415
|
"""Hook to run after the forward pass of the model during validation.
|
|
397
416
|
|
|
398
417
|
Args:
|
|
399
|
-
|
|
418
|
+
context: The model context.
|
|
400
419
|
targets: The target batch.
|
|
401
|
-
model_outputs: The output of the model
|
|
420
|
+
model_outputs: The output of the model.
|
|
402
421
|
"""
|
|
403
422
|
pass
|
|
404
423
|
|
|
405
424
|
def on_test_forward(
|
|
406
425
|
self,
|
|
407
|
-
|
|
426
|
+
context: ModelContext,
|
|
408
427
|
targets: list[dict[str, Any]],
|
|
409
|
-
model_outputs:
|
|
428
|
+
model_outputs: ModelOutput,
|
|
410
429
|
) -> None:
|
|
411
430
|
"""Hook to run after the forward pass of the model during testing.
|
|
412
431
|
|
|
413
432
|
Args:
|
|
414
|
-
|
|
433
|
+
context: The model context.
|
|
415
434
|
targets: The target batch.
|
|
416
|
-
model_outputs: The output of the model
|
|
435
|
+
model_outputs: The output of the model.
|
|
417
436
|
"""
|
|
418
437
|
pass
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Data classes to provide various context to models."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from rslearn.utils.geometry import PixelBounds, Projection
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class SampleMetadata:
|
|
15
|
+
"""Metadata pertaining to an example."""
|
|
16
|
+
|
|
17
|
+
window_group: str
|
|
18
|
+
window_name: str
|
|
19
|
+
window_bounds: PixelBounds
|
|
20
|
+
patch_bounds: PixelBounds
|
|
21
|
+
patch_idx: int
|
|
22
|
+
num_patches_in_window: int
|
|
23
|
+
time_range: tuple[datetime, datetime] | None
|
|
24
|
+
projection: Projection
|
|
25
|
+
|
|
26
|
+
# Task name to differentiate different tasks.
|
|
27
|
+
dataset_source: str | None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
|
|
31
|
+
class ModelContext:
|
|
32
|
+
"""Context to pass to all model components."""
|
|
33
|
+
|
|
34
|
+
# One input dict per example in the batch.
|
|
35
|
+
inputs: list[dict[str, torch.Tensor]]
|
|
36
|
+
# One SampleMetadata per example in the batch.
|
|
37
|
+
metadatas: list[SampleMetadata]
|
|
38
|
+
# Arbitrary dict that components can add to.
|
|
39
|
+
context_dict: dict[str, Any] = field(default_factory=lambda: {})
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
|
|
43
|
+
class ModelOutput:
|
|
44
|
+
"""The output from the Predictor.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
outputs: output compatible with the configured Task.
|
|
48
|
+
loss_dict: map from loss names to scalar tensors.
|
|
49
|
+
metadata: arbitrary dict that can be used to store other outputs.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
outputs: Iterable[Any]
|
|
53
|
+
loss_dict: dict[str, torch.Tensor]
|
|
54
|
+
metadata: dict[str, Any] = field(default_factory=lambda: {})
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""rslearn PredictionWriter implementation."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from typing import Any
|
|
@@ -12,11 +13,15 @@ from lightning.pytorch.callbacks import BasePredictionWriter
|
|
|
12
13
|
from upath import UPath
|
|
13
14
|
|
|
14
15
|
from rslearn.config import (
|
|
16
|
+
DatasetConfig,
|
|
15
17
|
LayerConfig,
|
|
16
18
|
LayerType,
|
|
19
|
+
StorageConfig,
|
|
17
20
|
)
|
|
18
|
-
from rslearn.dataset import
|
|
21
|
+
from rslearn.dataset import Window
|
|
22
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
19
23
|
from rslearn.log_utils import get_logger
|
|
24
|
+
from rslearn.train.model_context import SampleMetadata
|
|
20
25
|
from rslearn.utils.array import copy_spatial_array
|
|
21
26
|
from rslearn.utils.feature import Feature
|
|
22
27
|
from rslearn.utils.geometry import PixelBounds
|
|
@@ -27,6 +32,7 @@ from rslearn.utils.raster_format import (
|
|
|
27
32
|
from rslearn.utils.vector_format import VectorFormat
|
|
28
33
|
|
|
29
34
|
from .lightning_module import RslearnLightningModule
|
|
35
|
+
from .model_context import ModelOutput
|
|
30
36
|
from .tasks.task import Task
|
|
31
37
|
|
|
32
38
|
logger = get_logger(__name__)
|
|
@@ -43,12 +49,18 @@ class PendingPatchOutput:
|
|
|
43
49
|
class PatchPredictionMerger:
|
|
44
50
|
"""Base class for merging predictions from multiple patches."""
|
|
45
51
|
|
|
46
|
-
def merge(
|
|
52
|
+
def merge(
|
|
53
|
+
self,
|
|
54
|
+
window: Window,
|
|
55
|
+
outputs: Sequence[PendingPatchOutput],
|
|
56
|
+
layer_config: LayerConfig,
|
|
57
|
+
) -> Any:
|
|
47
58
|
"""Merge the outputs.
|
|
48
59
|
|
|
49
60
|
Args:
|
|
50
61
|
window: the window we are merging the outputs for.
|
|
51
62
|
outputs: the outputs to process.
|
|
63
|
+
layer_config: the output layer configuration.
|
|
52
64
|
|
|
53
65
|
Returns:
|
|
54
66
|
the merged outputs.
|
|
@@ -60,7 +72,10 @@ class VectorMerger(PatchPredictionMerger):
|
|
|
60
72
|
"""Merger for vector data that simply concatenates the features."""
|
|
61
73
|
|
|
62
74
|
def merge(
|
|
63
|
-
self,
|
|
75
|
+
self,
|
|
76
|
+
window: Window,
|
|
77
|
+
outputs: Sequence[PendingPatchOutput],
|
|
78
|
+
layer_config: LayerConfig,
|
|
64
79
|
) -> list[Feature]:
|
|
65
80
|
"""Concatenate the vector features."""
|
|
66
81
|
return [feat for output in outputs for feat in output.output]
|
|
@@ -83,18 +98,20 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
83
98
|
self.downsample_factor = downsample_factor
|
|
84
99
|
|
|
85
100
|
def merge(
|
|
86
|
-
self,
|
|
101
|
+
self,
|
|
102
|
+
window: Window,
|
|
103
|
+
outputs: Sequence[PendingPatchOutput],
|
|
104
|
+
layer_config: LayerConfig,
|
|
87
105
|
) -> npt.NDArray:
|
|
88
106
|
"""Merge the raster outputs."""
|
|
89
107
|
num_channels = outputs[0].output.shape[0]
|
|
90
|
-
dtype = outputs[0].output.dtype
|
|
91
108
|
merged_image = np.zeros(
|
|
92
109
|
(
|
|
93
110
|
num_channels,
|
|
94
111
|
(window.bounds[3] - window.bounds[1]) // self.downsample_factor,
|
|
95
112
|
(window.bounds[2] - window.bounds[0]) // self.downsample_factor,
|
|
96
113
|
),
|
|
97
|
-
dtype=dtype,
|
|
114
|
+
dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
|
|
98
115
|
)
|
|
99
116
|
|
|
100
117
|
# Ensure the outputs are sorted by height then width.
|
|
@@ -148,6 +165,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
148
165
|
merger: PatchPredictionMerger | None = None,
|
|
149
166
|
output_path: str | Path | None = None,
|
|
150
167
|
layer_config: LayerConfig | None = None,
|
|
168
|
+
storage_config: StorageConfig | None = None,
|
|
151
169
|
):
|
|
152
170
|
"""Create a new RslearnWriter.
|
|
153
171
|
|
|
@@ -163,28 +181,24 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
163
181
|
layer_config: optional layer configuration. If provided, this config will be
|
|
164
182
|
used instead of reading from the dataset config, allowing usage without
|
|
165
183
|
requiring dataset config at the output path.
|
|
184
|
+
storage_config: optional storage configuration, needed similar to layer_config
|
|
185
|
+
if there is no dataset config.
|
|
166
186
|
"""
|
|
167
187
|
super().__init__(write_interval="batch")
|
|
168
188
|
self.output_layer = output_layer
|
|
169
189
|
self.selector = selector or []
|
|
170
|
-
|
|
171
|
-
|
|
190
|
+
ds_upath = UPath(path, **path_options or {})
|
|
191
|
+
output_upath = (
|
|
172
192
|
UPath(output_path, **path_options or {})
|
|
173
193
|
if output_path is not None
|
|
174
|
-
else
|
|
194
|
+
else ds_upath
|
|
175
195
|
)
|
|
176
196
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
dataset = Dataset(self.path)
|
|
183
|
-
if self.output_layer not in dataset.layers:
|
|
184
|
-
raise KeyError(
|
|
185
|
-
f"Output layer '{self.output_layer}' not found in dataset layers."
|
|
186
|
-
)
|
|
187
|
-
self.layer_config = dataset.layers[self.output_layer]
|
|
197
|
+
self.layer_config, self.dataset_storage = (
|
|
198
|
+
self._get_layer_config_and_dataset_storage(
|
|
199
|
+
ds_upath, output_upath, layer_config, storage_config
|
|
200
|
+
)
|
|
201
|
+
)
|
|
188
202
|
|
|
189
203
|
self.format: RasterFormat | VectorFormat
|
|
190
204
|
if self.layer_config.type == LayerType.RASTER:
|
|
@@ -207,11 +221,73 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
207
221
|
# patches of each window need to be reconstituted.
|
|
208
222
|
self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
|
|
209
223
|
|
|
224
|
+
def _get_layer_config_and_dataset_storage(
|
|
225
|
+
self,
|
|
226
|
+
ds_upath: UPath,
|
|
227
|
+
output_upath: UPath,
|
|
228
|
+
layer_config: LayerConfig | None,
|
|
229
|
+
storage_config: StorageConfig | None,
|
|
230
|
+
) -> tuple[LayerConfig, WindowStorage]:
|
|
231
|
+
"""Get the layer config and dataset storage to use.
|
|
232
|
+
|
|
233
|
+
This is a helper function for the init method.
|
|
234
|
+
|
|
235
|
+
If layer_config is set, we use that. If storage_config is set, we use it to
|
|
236
|
+
instantiate a WindowStorage using the output_upath.
|
|
237
|
+
|
|
238
|
+
If one of them is not set, we load the config from the ds_upath. Otherwise, we
|
|
239
|
+
avoid reading the dataset config; this way, RslearnWriter can be used with
|
|
240
|
+
output directories that do not contain the dataset config, as long as
|
|
241
|
+
layer_config and storage_config are both provided.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
ds_upath: the dataset path, where a dataset config can be loaded from if
|
|
245
|
+
layer_config or storage_config is not provided.
|
|
246
|
+
output_upath: the output directory, which could be different from the
|
|
247
|
+
dataset path.
|
|
248
|
+
layer_config: optional LayerConfig to provide.
|
|
249
|
+
storage_config: optional StorageConfig to provide.
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
a tuple (layer_config, dataset_storage)
|
|
253
|
+
"""
|
|
254
|
+
dataset_storage: WindowStorage | None = None
|
|
255
|
+
|
|
256
|
+
# Instantiate the WindowStorage from the storage_config if provided.
|
|
257
|
+
if storage_config:
|
|
258
|
+
dataset_storage = (
|
|
259
|
+
storage_config.instantiate_window_storage_factory().get_storage(
|
|
260
|
+
output_upath
|
|
261
|
+
)
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
if not layer_config or not dataset_storage:
|
|
265
|
+
# Need to load dataset config since one of LayerConfig/StorageConfig is missing.
|
|
266
|
+
# We use DatasetConfig.model_validate instead of initializing the Dataset
|
|
267
|
+
# because we want to get a WindowStorage that has the dataset path set to
|
|
268
|
+
# output_upath instead of ds_upath.
|
|
269
|
+
with (ds_upath / "config.json").open() as f:
|
|
270
|
+
dataset_config = DatasetConfig.model_validate(json.load(f))
|
|
271
|
+
|
|
272
|
+
if not layer_config:
|
|
273
|
+
if self.output_layer not in dataset_config.layers:
|
|
274
|
+
raise KeyError(
|
|
275
|
+
f"Output layer '{self.output_layer}' not found in dataset layers."
|
|
276
|
+
)
|
|
277
|
+
layer_config = dataset_config.layers[self.output_layer]
|
|
278
|
+
|
|
279
|
+
if not dataset_storage:
|
|
280
|
+
dataset_storage = dataset_config.storage.instantiate_window_storage_factory().get_storage(
|
|
281
|
+
output_upath
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
return (layer_config, dataset_storage)
|
|
285
|
+
|
|
210
286
|
def write_on_batch_end(
|
|
211
287
|
self,
|
|
212
288
|
trainer: Trainer,
|
|
213
289
|
pl_module: LightningModule,
|
|
214
|
-
prediction:
|
|
290
|
+
prediction: ModelOutput,
|
|
215
291
|
batch_indices: Sequence[int] | None,
|
|
216
292
|
batch: tuple[list, list, list],
|
|
217
293
|
batch_idx: int,
|
|
@@ -232,13 +308,13 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
232
308
|
assert isinstance(pl_module, RslearnLightningModule)
|
|
233
309
|
task = pl_module.task
|
|
234
310
|
_, _, metadatas = batch
|
|
235
|
-
self.process_output_batch(task, prediction
|
|
311
|
+
self.process_output_batch(task, prediction.outputs, metadatas)
|
|
236
312
|
|
|
237
313
|
def process_output_batch(
|
|
238
314
|
self,
|
|
239
315
|
task: Task,
|
|
240
|
-
prediction:
|
|
241
|
-
metadatas:
|
|
316
|
+
prediction: Iterable[Any],
|
|
317
|
+
metadatas: Iterable[SampleMetadata],
|
|
242
318
|
) -> None:
|
|
243
319
|
"""Write a prediction batch with simplified API.
|
|
244
320
|
|
|
@@ -263,25 +339,19 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
263
339
|
for k in self.selector:
|
|
264
340
|
output = output[k]
|
|
265
341
|
|
|
266
|
-
# Use custom output_path if provided, otherwise use dataset path
|
|
267
|
-
window_base_path = (
|
|
268
|
-
self.output_path if self.output_path is not None else self.path
|
|
269
|
-
)
|
|
270
342
|
window = Window(
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
bounds=metadata["window_bounds"],
|
|
278
|
-
time_range=metadata["time_range"],
|
|
343
|
+
storage=self.dataset_storage,
|
|
344
|
+
group=metadata.window_group,
|
|
345
|
+
name=metadata.window_name,
|
|
346
|
+
projection=metadata.projection,
|
|
347
|
+
bounds=metadata.window_bounds,
|
|
348
|
+
time_range=metadata.time_range,
|
|
279
349
|
)
|
|
280
350
|
self.process_output(
|
|
281
351
|
window,
|
|
282
|
-
metadata
|
|
283
|
-
metadata
|
|
284
|
-
metadata
|
|
352
|
+
metadata.patch_idx,
|
|
353
|
+
metadata.num_patches_in_window,
|
|
354
|
+
metadata.patch_bounds,
|
|
285
355
|
output,
|
|
286
356
|
)
|
|
287
357
|
|
|
@@ -320,7 +390,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
320
390
|
|
|
321
391
|
# Merge outputs from overlapped patches if merger is set.
|
|
322
392
|
logger.debug(f"Merging and writing for window {window.name}")
|
|
323
|
-
merged_output = self.merger.merge(window, pending_output)
|
|
393
|
+
merged_output = self.merger.merge(window, pending_output, self.layer_config)
|
|
324
394
|
|
|
325
395
|
if self.layer_config.type == LayerType.RASTER:
|
|
326
396
|
raster_dir = window.get_raster_dir(
|
|
@@ -15,6 +15,8 @@ from torchmetrics.classification import (
|
|
|
15
15
|
MulticlassRecall,
|
|
16
16
|
)
|
|
17
17
|
|
|
18
|
+
from rslearn.models.component import FeatureVector, Predictor
|
|
19
|
+
from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
|
|
18
20
|
from rslearn.utils import Feature, STGeometry
|
|
19
21
|
|
|
20
22
|
from .task import BasicTask
|
|
@@ -98,7 +100,7 @@ class ClassificationTask(BasicTask):
|
|
|
98
100
|
def process_inputs(
|
|
99
101
|
self,
|
|
100
102
|
raw_inputs: dict[str, torch.Tensor | list[Feature]],
|
|
101
|
-
metadata:
|
|
103
|
+
metadata: SampleMetadata,
|
|
102
104
|
load_targets: bool = True,
|
|
103
105
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
104
106
|
"""Processes the data into targets.
|
|
@@ -154,17 +156,25 @@ class ClassificationTask(BasicTask):
|
|
|
154
156
|
}
|
|
155
157
|
|
|
156
158
|
def process_output(
|
|
157
|
-
self, raw_output: Any, metadata:
|
|
158
|
-
) ->
|
|
159
|
+
self, raw_output: Any, metadata: SampleMetadata
|
|
160
|
+
) -> list[Feature]:
|
|
159
161
|
"""Processes an output into raster or vector data.
|
|
160
162
|
|
|
161
163
|
Args:
|
|
162
|
-
raw_output: the output from prediction head
|
|
164
|
+
raw_output: the output from prediction head, which must be a tensor
|
|
165
|
+
containing output probabilities (one dimension).
|
|
163
166
|
metadata: metadata about the patch being read
|
|
164
167
|
|
|
165
168
|
Returns:
|
|
166
|
-
|
|
169
|
+
a list with one Feature corresponding to the input patch extent with a
|
|
170
|
+
property name containing the predicted class. It will have another
|
|
171
|
+
property containing the probabilities if prob_property was set.
|
|
167
172
|
"""
|
|
173
|
+
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 1:
|
|
174
|
+
raise ValueError(
|
|
175
|
+
"expected output for ClassificationTask to be a Tensor with one dimension"
|
|
176
|
+
)
|
|
177
|
+
|
|
168
178
|
probs = raw_output.cpu().numpy()
|
|
169
179
|
if len(self.classes) == 2 and self.positive_class_threshold != 0.5:
|
|
170
180
|
positive_class_prob = probs[self.positive_class_id]
|
|
@@ -184,8 +194,8 @@ class ClassificationTask(BasicTask):
|
|
|
184
194
|
|
|
185
195
|
feature = Feature(
|
|
186
196
|
STGeometry(
|
|
187
|
-
metadata
|
|
188
|
-
shapely.Point(metadata[
|
|
197
|
+
metadata.projection,
|
|
198
|
+
shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
|
|
189
199
|
None,
|
|
190
200
|
),
|
|
191
201
|
{
|
|
@@ -265,25 +275,31 @@ class ClassificationTask(BasicTask):
|
|
|
265
275
|
return MetricCollection(metrics)
|
|
266
276
|
|
|
267
277
|
|
|
268
|
-
class ClassificationHead(
|
|
278
|
+
class ClassificationHead(Predictor):
|
|
269
279
|
"""Head for classification task."""
|
|
270
280
|
|
|
271
281
|
def forward(
|
|
272
282
|
self,
|
|
273
|
-
|
|
274
|
-
|
|
283
|
+
intermediates: Any,
|
|
284
|
+
context: ModelContext,
|
|
275
285
|
targets: list[dict[str, Any]] | None = None,
|
|
276
|
-
) ->
|
|
286
|
+
) -> ModelOutput:
|
|
277
287
|
"""Compute the classification outputs and loss from logits and targets.
|
|
278
288
|
|
|
279
289
|
Args:
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
290
|
+
intermediates: output from the previous model component, it should be a
|
|
291
|
+
FeatureVector with a tensor that is (BatchSize, NumClasses) in shape.
|
|
292
|
+
context: the model context.
|
|
293
|
+
targets: must contain "class" key that stores the class label, along with
|
|
294
|
+
"valid" key indicating whether the label is valid for each example.
|
|
283
295
|
|
|
284
296
|
Returns:
|
|
285
297
|
tuple of outputs and loss dict
|
|
286
298
|
"""
|
|
299
|
+
if not isinstance(intermediates, FeatureVector):
|
|
300
|
+
raise ValueError("the input to ClassificationHead must be a FeatureVector")
|
|
301
|
+
|
|
302
|
+
logits = intermediates.feature_vector
|
|
287
303
|
outputs = torch.nn.functional.softmax(logits, dim=1)
|
|
288
304
|
|
|
289
305
|
losses = {}
|
|
@@ -298,7 +314,10 @@ class ClassificationHead(torch.nn.Module):
|
|
|
298
314
|
)
|
|
299
315
|
losses["cls"] = torch.mean(loss)
|
|
300
316
|
|
|
301
|
-
return
|
|
317
|
+
return ModelOutput(
|
|
318
|
+
outputs=outputs,
|
|
319
|
+
loss_dict=losses,
|
|
320
|
+
)
|
|
302
321
|
|
|
303
322
|
|
|
304
323
|
class ClassificationMetric(Metric):
|