rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/anysat.py +35 -33
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clip.py +5 -2
- rslearn/models/component.py +12 -0
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +206 -51
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +43 -17
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +96 -28
- rslearn/train/dataset.py +102 -53
- rslearn/train/model_context.py +35 -1
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
|
@@ -8,6 +8,7 @@ import random
|
|
|
8
8
|
import tempfile
|
|
9
9
|
import time
|
|
10
10
|
import uuid
|
|
11
|
+
from datetime import datetime
|
|
11
12
|
from typing import Any
|
|
12
13
|
|
|
13
14
|
import torch
|
|
@@ -19,12 +20,18 @@ from rslearn.config import (
|
|
|
19
20
|
DType,
|
|
20
21
|
LayerConfig,
|
|
21
22
|
)
|
|
23
|
+
from rslearn.data_sources.data_source import Item
|
|
22
24
|
from rslearn.dataset.dataset import Dataset
|
|
23
25
|
from rslearn.dataset.storage.file import FileWindowStorage
|
|
24
|
-
from rslearn.dataset.window import
|
|
26
|
+
from rslearn.dataset.window import (
|
|
27
|
+
Window,
|
|
28
|
+
WindowLayerData,
|
|
29
|
+
get_layer_and_group_from_dir_name,
|
|
30
|
+
)
|
|
25
31
|
from rslearn.log_utils import get_logger
|
|
32
|
+
from rslearn.train.model_context import RasterImage
|
|
26
33
|
from rslearn.utils.feature import Feature
|
|
27
|
-
from rslearn.utils.geometry import PixelBounds
|
|
34
|
+
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
28
35
|
from rslearn.utils.mp import star_imap_unordered
|
|
29
36
|
|
|
30
37
|
from .model_context import SampleMetadata
|
|
@@ -130,6 +137,10 @@ class DataInput:
|
|
|
130
137
|
"""Specification of a piece of data from a window that is needed for training.
|
|
131
138
|
|
|
132
139
|
The DataInput includes which layer(s) the data can be obtained from for each window.
|
|
140
|
+
|
|
141
|
+
Note that this class is not a dataclass because jsonargparse does not play well
|
|
142
|
+
with dataclasses without enabling specialized options which we have not validated
|
|
143
|
+
will work with the rest of our code.
|
|
133
144
|
"""
|
|
134
145
|
|
|
135
146
|
def __init__(
|
|
@@ -143,7 +154,9 @@ class DataInput:
|
|
|
143
154
|
dtype: DType = DType.FLOAT32,
|
|
144
155
|
load_all_layers: bool = False,
|
|
145
156
|
load_all_item_groups: bool = False,
|
|
146
|
-
|
|
157
|
+
resolution_factor: ResolutionFactor = ResolutionFactor(),
|
|
158
|
+
resampling: Resampling = Resampling.nearest,
|
|
159
|
+
):
|
|
147
160
|
"""Initialize a new DataInput.
|
|
148
161
|
|
|
149
162
|
Args:
|
|
@@ -166,6 +179,11 @@ class DataInput:
|
|
|
166
179
|
are reading from. By default, we assume the specified layer name is of
|
|
167
180
|
the form "{layer_name}.{group_idx}" and read that item group only. With
|
|
168
181
|
this option enabled, we ignore the group_idx and read all item groups.
|
|
182
|
+
resolution_factor: controls the resolution at which raster data is loaded for training.
|
|
183
|
+
By default (factor=1), data is loaded at the window resolution.
|
|
184
|
+
E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
|
|
185
|
+
the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
|
|
186
|
+
resampling: resampling method (default nearest neighbor).
|
|
169
187
|
"""
|
|
170
188
|
self.data_type = data_type
|
|
171
189
|
self.layers = layers
|
|
@@ -176,6 +194,8 @@ class DataInput:
|
|
|
176
194
|
self.dtype = dtype
|
|
177
195
|
self.load_all_layers = load_all_layers
|
|
178
196
|
self.load_all_item_groups = load_all_item_groups
|
|
197
|
+
self.resolution_factor = resolution_factor
|
|
198
|
+
self.resampling = resampling
|
|
179
199
|
|
|
180
200
|
|
|
181
201
|
def read_raster_layer_for_data_input(
|
|
@@ -185,7 +205,8 @@ def read_raster_layer_for_data_input(
|
|
|
185
205
|
group_idx: int,
|
|
186
206
|
layer_config: LayerConfig,
|
|
187
207
|
data_input: DataInput,
|
|
188
|
-
|
|
208
|
+
layer_data: WindowLayerData | None,
|
|
209
|
+
) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
|
|
189
210
|
"""Read a raster layer for a DataInput.
|
|
190
211
|
|
|
191
212
|
This scans the available rasters for the layer at the window to determine which
|
|
@@ -198,9 +219,11 @@ def read_raster_layer_for_data_input(
|
|
|
198
219
|
group_idx: the item group.
|
|
199
220
|
layer_config: the layer configuration.
|
|
200
221
|
data_input: the DataInput that specifies the bands and dtype.
|
|
222
|
+
layer_data: the WindowLayerData associated with this layer and window.
|
|
201
223
|
|
|
202
224
|
Returns:
|
|
203
|
-
|
|
225
|
+
RasterImage containing raster data and the timestamp associated
|
|
226
|
+
with that data.
|
|
204
227
|
"""
|
|
205
228
|
# See what different sets of bands we need to read to get all the
|
|
206
229
|
# configured bands.
|
|
@@ -233,15 +256,23 @@ def read_raster_layer_for_data_input(
|
|
|
233
256
|
+ f"window {window.name} layer {layer_name} group {group_idx}"
|
|
234
257
|
)
|
|
235
258
|
|
|
259
|
+
# Get the projection and bounds to read under (multiply window resolution # by
|
|
260
|
+
# the specified resolution factor).
|
|
261
|
+
final_projection = data_input.resolution_factor.multiply_projection(
|
|
262
|
+
window.projection
|
|
263
|
+
)
|
|
264
|
+
final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
|
|
265
|
+
|
|
236
266
|
image = torch.zeros(
|
|
237
|
-
(
|
|
267
|
+
(
|
|
268
|
+
len(needed_bands),
|
|
269
|
+
final_bounds[3] - final_bounds[1],
|
|
270
|
+
final_bounds[2] - final_bounds[0],
|
|
271
|
+
),
|
|
238
272
|
dtype=get_torch_dtype(data_input.dtype),
|
|
239
273
|
)
|
|
240
274
|
|
|
241
275
|
for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
|
|
242
|
-
final_projection, final_bounds = band_set.get_final_projection_and_bounds(
|
|
243
|
-
window.projection, bounds
|
|
244
|
-
)
|
|
245
276
|
if band_set.format is None:
|
|
246
277
|
raise ValueError(f"No format specified for {layer_name}")
|
|
247
278
|
raster_format = band_set.instantiate_raster_format()
|
|
@@ -249,49 +280,48 @@ def read_raster_layer_for_data_input(
|
|
|
249
280
|
layer_name, band_set.bands, group_idx=group_idx
|
|
250
281
|
)
|
|
251
282
|
|
|
252
|
-
#
|
|
253
|
-
#
|
|
254
|
-
#
|
|
255
|
-
#
|
|
256
|
-
#
|
|
257
|
-
#
|
|
258
|
-
is_bounds_zoomable = True
|
|
259
|
-
if band_set.zoom_offset < 0:
|
|
260
|
-
zoom_factor = 2 ** (-band_set.zoom_offset)
|
|
261
|
-
is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
|
|
262
|
-
bounds[2] - bounds[0]
|
|
263
|
-
) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
|
|
264
|
-
bounds[3] - bounds[1]
|
|
265
|
-
)
|
|
266
|
-
|
|
267
|
-
if is_bounds_zoomable:
|
|
268
|
-
src = raster_format.decode_raster(
|
|
269
|
-
raster_dir, final_projection, final_bounds
|
|
270
|
-
)
|
|
271
|
-
|
|
272
|
-
# Resize to patch size if needed.
|
|
273
|
-
# This is for band sets that are stored at a lower resolution.
|
|
274
|
-
# Here we assume that it is a multiple.
|
|
275
|
-
if src.shape[1:3] != image.shape[1:3]:
|
|
276
|
-
if src.shape[1] < image.shape[1]:
|
|
277
|
-
factor = image.shape[1] // src.shape[1]
|
|
278
|
-
src = src.repeat(repeats=factor, axis=1).repeat(
|
|
279
|
-
repeats=factor, axis=2
|
|
280
|
-
)
|
|
281
|
-
else:
|
|
282
|
-
factor = src.shape[1] // image.shape[1]
|
|
283
|
-
src = src[:, ::factor, ::factor]
|
|
284
|
-
|
|
285
|
-
else:
|
|
286
|
-
src = raster_format.decode_raster(
|
|
287
|
-
raster_dir, window.projection, bounds, resampling=Resampling.nearest
|
|
288
|
-
)
|
|
283
|
+
# TODO: previously we try to read based on band_set.zoom_offset when possible,
|
|
284
|
+
# and handle zooming in with torch.repeat (if resampling method is nearest
|
|
285
|
+
# neighbor). However, we have not benchmarked whether this actually improves
|
|
286
|
+
# data loading speed, so for simplicity, for now we let rasterio handle the
|
|
287
|
+
# resampling. If it really is much faster to handle it via torch, then it may
|
|
288
|
+
# make sense to bring back that functionality.
|
|
289
289
|
|
|
290
|
+
src = raster_format.decode_raster(
|
|
291
|
+
raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
|
|
292
|
+
)
|
|
290
293
|
image[dst_indexes, :, :] = torch.as_tensor(
|
|
291
294
|
src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
|
|
292
295
|
)
|
|
293
296
|
|
|
294
|
-
|
|
297
|
+
# add the timestamp. this is a tuple defining the start and end of the time range.
|
|
298
|
+
time_range = None
|
|
299
|
+
if layer_data is not None:
|
|
300
|
+
item = Item.deserialize(layer_data.serialized_item_groups[group_idx][0])
|
|
301
|
+
if item.geometry.time_range is not None:
|
|
302
|
+
# we assume if one layer data has a geometry & time range, all of them do
|
|
303
|
+
time_ranges = [
|
|
304
|
+
(
|
|
305
|
+
datetime.fromisoformat(
|
|
306
|
+
Item.deserialize(
|
|
307
|
+
layer_data.serialized_item_groups[group_idx][idx]
|
|
308
|
+
).geometry.time_range[0] # type: ignore
|
|
309
|
+
),
|
|
310
|
+
datetime.fromisoformat(
|
|
311
|
+
Item.deserialize(
|
|
312
|
+
layer_data.serialized_item_groups[group_idx][idx]
|
|
313
|
+
).geometry.time_range[1] # type: ignore
|
|
314
|
+
),
|
|
315
|
+
)
|
|
316
|
+
for idx in range(len(layer_data.serialized_item_groups[group_idx]))
|
|
317
|
+
]
|
|
318
|
+
# take the min and max
|
|
319
|
+
time_range = (
|
|
320
|
+
min([t[0] for t in time_ranges]),
|
|
321
|
+
max([t[1] for t in time_ranges]),
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
return image, time_range
|
|
295
325
|
|
|
296
326
|
|
|
297
327
|
def read_data_input(
|
|
@@ -300,7 +330,7 @@ def read_data_input(
|
|
|
300
330
|
bounds: PixelBounds,
|
|
301
331
|
data_input: DataInput,
|
|
302
332
|
rng: random.Random,
|
|
303
|
-
) ->
|
|
333
|
+
) -> RasterImage | list[Feature]:
|
|
304
334
|
"""Read the data specified by the DataInput from the window.
|
|
305
335
|
|
|
306
336
|
Args:
|
|
@@ -342,15 +372,34 @@ def read_data_input(
|
|
|
342
372
|
layers_to_read = [rng.choice(layer_options)]
|
|
343
373
|
|
|
344
374
|
if data_input.data_type == "raster":
|
|
375
|
+
# load it once here
|
|
376
|
+
layer_datas = window.load_layer_datas()
|
|
345
377
|
images: list[torch.Tensor] = []
|
|
378
|
+
time_ranges: list[tuple[datetime, datetime] | None] = []
|
|
346
379
|
for layer_name, group_idx in layers_to_read:
|
|
347
380
|
layer_config = dataset.layers[layer_name]
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
381
|
+
image, time_range = read_raster_layer_for_data_input(
|
|
382
|
+
window,
|
|
383
|
+
bounds,
|
|
384
|
+
layer_name,
|
|
385
|
+
group_idx,
|
|
386
|
+
layer_config,
|
|
387
|
+
data_input,
|
|
388
|
+
# some layers (e.g. "label_raster") won't have associated
|
|
389
|
+
# layer datas
|
|
390
|
+
layer_datas[layer_name] if layer_name in layer_datas else None,
|
|
352
391
|
)
|
|
353
|
-
|
|
392
|
+
if len(time_ranges) > 0:
|
|
393
|
+
if type(time_ranges[-1]) is not type(time_range):
|
|
394
|
+
raise ValueError(
|
|
395
|
+
f"All time ranges should be datetime tuples or None. Got {type(time_range)} amd {type(time_ranges[-1])}"
|
|
396
|
+
)
|
|
397
|
+
images.append(image)
|
|
398
|
+
time_ranges.append(time_range)
|
|
399
|
+
return RasterImage(
|
|
400
|
+
torch.stack(images, dim=1),
|
|
401
|
+
time_ranges if time_ranges[0] is not None else None, # type: ignore
|
|
402
|
+
)
|
|
354
403
|
|
|
355
404
|
elif data_input.data_type == "vector":
|
|
356
405
|
# We don't really support time series for vector data currently, we just
|
rslearn/train/model_context.py
CHANGED
|
@@ -10,6 +10,40 @@ import torch
|
|
|
10
10
|
from rslearn.utils.geometry import PixelBounds, Projection
|
|
11
11
|
|
|
12
12
|
|
|
13
|
+
@dataclass
|
|
14
|
+
class RasterImage:
|
|
15
|
+
"""A raster image is a torch.tensor containing the images and their associated timestamps."""
|
|
16
|
+
|
|
17
|
+
# image is a 4D CTHW tensor
|
|
18
|
+
image: torch.Tensor
|
|
19
|
+
# if timestamps is not None, len(timestamps) must match the T dimension of the tensor
|
|
20
|
+
timestamps: list[tuple[datetime, datetime]] | None = None
|
|
21
|
+
|
|
22
|
+
@property
|
|
23
|
+
def shape(self) -> torch.Size:
|
|
24
|
+
"""The shape of the image."""
|
|
25
|
+
return self.image.shape
|
|
26
|
+
|
|
27
|
+
def dim(self) -> int:
|
|
28
|
+
"""The dim of the image."""
|
|
29
|
+
return self.image.dim()
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def dtype(self) -> torch.dtype:
|
|
33
|
+
"""The image dtype."""
|
|
34
|
+
return self.image.dtype
|
|
35
|
+
|
|
36
|
+
def single_ts_to_chw_tensor(self) -> torch.Tensor:
|
|
37
|
+
"""Single timestep models expect single timestep inputs.
|
|
38
|
+
|
|
39
|
+
This function (1) checks this raster image only has 1 timestep and
|
|
40
|
+
(2) returns the tensor for that (single) timestep (going from CTHW to CHW).
|
|
41
|
+
"""
|
|
42
|
+
if self.image.shape[1] != 1:
|
|
43
|
+
raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
|
|
44
|
+
return self.image[:, 0]
|
|
45
|
+
|
|
46
|
+
|
|
13
47
|
@dataclass
|
|
14
48
|
class SampleMetadata:
|
|
15
49
|
"""Metadata pertaining to an example."""
|
|
@@ -32,7 +66,7 @@ class ModelContext:
|
|
|
32
66
|
"""Context to pass to all model components."""
|
|
33
67
|
|
|
34
68
|
# One input dict per example in the batch.
|
|
35
|
-
inputs: list[dict[str, torch.Tensor]]
|
|
69
|
+
inputs: list[dict[str, torch.Tensor | RasterImage]]
|
|
36
70
|
# One SampleMetadata per example in the batch.
|
|
37
71
|
metadatas: list[SampleMetadata]
|
|
38
72
|
# Arbitrary dict that components can add to.
|
rslearn/train/scheduler.py
CHANGED
|
@@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import (
|
|
|
8
8
|
CosineAnnealingLR,
|
|
9
9
|
CosineAnnealingWarmRestarts,
|
|
10
10
|
LRScheduler,
|
|
11
|
+
MultiStepLR,
|
|
11
12
|
ReduceLROnPlateau,
|
|
12
13
|
)
|
|
13
14
|
|
|
@@ -50,6 +51,20 @@ class PlateauScheduler(SchedulerFactory):
|
|
|
50
51
|
return ReduceLROnPlateau(optimizer, **self.get_kwargs())
|
|
51
52
|
|
|
52
53
|
|
|
54
|
+
@dataclass
|
|
55
|
+
class MultiStepScheduler(SchedulerFactory):
|
|
56
|
+
"""Step learning rate scheduler."""
|
|
57
|
+
|
|
58
|
+
milestones: list[int]
|
|
59
|
+
gamma: float | None = None
|
|
60
|
+
last_epoch: int | None = None
|
|
61
|
+
|
|
62
|
+
def build(self, optimizer: Optimizer) -> LRScheduler:
|
|
63
|
+
"""Build the ReduceLROnPlateau scheduler."""
|
|
64
|
+
super().build(optimizer)
|
|
65
|
+
return MultiStepLR(optimizer, **self.get_kwargs())
|
|
66
|
+
|
|
67
|
+
|
|
53
68
|
@dataclass
|
|
54
69
|
class CosineAnnealingScheduler(SchedulerFactory):
|
|
55
70
|
"""Cosine annealing learning rate scheduler."""
|
|
@@ -16,7 +16,12 @@ from torchmetrics.classification import (
|
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
from rslearn.models.component import FeatureVector, Predictor
|
|
19
|
-
from rslearn.train.model_context import
|
|
19
|
+
from rslearn.train.model_context import (
|
|
20
|
+
ModelContext,
|
|
21
|
+
ModelOutput,
|
|
22
|
+
RasterImage,
|
|
23
|
+
SampleMetadata,
|
|
24
|
+
)
|
|
20
25
|
from rslearn.utils import Feature, STGeometry
|
|
21
26
|
|
|
22
27
|
from .task import BasicTask
|
|
@@ -99,7 +104,7 @@ class ClassificationTask(BasicTask):
|
|
|
99
104
|
|
|
100
105
|
def process_inputs(
|
|
101
106
|
self,
|
|
102
|
-
raw_inputs: dict[str,
|
|
107
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
103
108
|
metadata: SampleMetadata,
|
|
104
109
|
load_targets: bool = True,
|
|
105
110
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -118,6 +123,7 @@ class ClassificationTask(BasicTask):
|
|
|
118
123
|
return {}, {}
|
|
119
124
|
|
|
120
125
|
data = raw_inputs["targets"]
|
|
126
|
+
assert isinstance(data, list)
|
|
121
127
|
for feat in data:
|
|
122
128
|
if feat.properties is None:
|
|
123
129
|
continue
|
rslearn/train/tasks/detection.py
CHANGED
|
@@ -12,7 +12,7 @@ import torchmetrics.classification
|
|
|
12
12
|
import torchvision
|
|
13
13
|
from torchmetrics import Metric, MetricCollection
|
|
14
14
|
|
|
15
|
-
from rslearn.train.model_context import SampleMetadata
|
|
15
|
+
from rslearn.train.model_context import RasterImage, SampleMetadata
|
|
16
16
|
from rslearn.utils import Feature, STGeometry
|
|
17
17
|
|
|
18
18
|
from .task import BasicTask
|
|
@@ -127,7 +127,7 @@ class DetectionTask(BasicTask):
|
|
|
127
127
|
|
|
128
128
|
def process_inputs(
|
|
129
129
|
self,
|
|
130
|
-
raw_inputs: dict[str,
|
|
130
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
131
131
|
metadata: SampleMetadata,
|
|
132
132
|
load_targets: bool = True,
|
|
133
133
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -152,6 +152,7 @@ class DetectionTask(BasicTask):
|
|
|
152
152
|
valid = 1
|
|
153
153
|
|
|
154
154
|
data = raw_inputs["targets"]
|
|
155
|
+
assert isinstance(data, list)
|
|
155
156
|
for feat in data:
|
|
156
157
|
if feat.properties is None:
|
|
157
158
|
continue
|
|
@@ -3,10 +3,9 @@
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
5
|
import numpy.typing as npt
|
|
6
|
-
import torch
|
|
7
6
|
from torchmetrics import Metric, MetricCollection
|
|
8
7
|
|
|
9
|
-
from rslearn.train.model_context import SampleMetadata
|
|
8
|
+
from rslearn.train.model_context import RasterImage, SampleMetadata
|
|
10
9
|
from rslearn.utils import Feature
|
|
11
10
|
|
|
12
11
|
from .task import Task
|
|
@@ -30,7 +29,7 @@ class MultiTask(Task):
|
|
|
30
29
|
|
|
31
30
|
def process_inputs(
|
|
32
31
|
self,
|
|
33
|
-
raw_inputs: dict[str,
|
|
32
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
34
33
|
metadata: SampleMetadata,
|
|
35
34
|
load_targets: bool = True,
|
|
36
35
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -9,7 +9,12 @@ import torchmetrics
|
|
|
9
9
|
from torchmetrics import Metric, MetricCollection
|
|
10
10
|
|
|
11
11
|
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
-
from rslearn.train.model_context import
|
|
12
|
+
from rslearn.train.model_context import (
|
|
13
|
+
ModelContext,
|
|
14
|
+
ModelOutput,
|
|
15
|
+
RasterImage,
|
|
16
|
+
SampleMetadata,
|
|
17
|
+
)
|
|
13
18
|
from rslearn.utils.feature import Feature
|
|
14
19
|
|
|
15
20
|
from .task import BasicTask
|
|
@@ -42,7 +47,7 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
42
47
|
|
|
43
48
|
def process_inputs(
|
|
44
49
|
self,
|
|
45
|
-
raw_inputs: dict[str,
|
|
50
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
46
51
|
metadata: SampleMetadata,
|
|
47
52
|
load_targets: bool = True,
|
|
48
53
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -60,11 +65,15 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
60
65
|
if not load_targets:
|
|
61
66
|
return {}, {}
|
|
62
67
|
|
|
63
|
-
assert raw_inputs["targets"]
|
|
64
|
-
|
|
68
|
+
assert isinstance(raw_inputs["targets"], RasterImage)
|
|
69
|
+
assert raw_inputs["targets"].image.shape[0] == 1
|
|
70
|
+
assert raw_inputs["targets"].image.shape[1] == 1
|
|
71
|
+
labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
|
|
65
72
|
|
|
66
73
|
if self.nodata_value is not None:
|
|
67
|
-
valid = (
|
|
74
|
+
valid = (
|
|
75
|
+
raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
|
|
76
|
+
).float()
|
|
68
77
|
else:
|
|
69
78
|
valid = torch.ones(labels.shape, dtype=torch.float32)
|
|
70
79
|
|
|
@@ -11,7 +11,12 @@ from PIL import Image, ImageDraw
|
|
|
11
11
|
from torchmetrics import Metric, MetricCollection
|
|
12
12
|
|
|
13
13
|
from rslearn.models.component import FeatureVector, Predictor
|
|
14
|
-
from rslearn.train.model_context import
|
|
14
|
+
from rslearn.train.model_context import (
|
|
15
|
+
ModelContext,
|
|
16
|
+
ModelOutput,
|
|
17
|
+
RasterImage,
|
|
18
|
+
SampleMetadata,
|
|
19
|
+
)
|
|
15
20
|
from rslearn.utils.feature import Feature
|
|
16
21
|
from rslearn.utils.geometry import STGeometry
|
|
17
22
|
|
|
@@ -63,7 +68,7 @@ class RegressionTask(BasicTask):
|
|
|
63
68
|
|
|
64
69
|
def process_inputs(
|
|
65
70
|
self,
|
|
66
|
-
raw_inputs: dict[str,
|
|
71
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
67
72
|
metadata: SampleMetadata,
|
|
68
73
|
load_targets: bool = True,
|
|
69
74
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -82,6 +87,7 @@ class RegressionTask(BasicTask):
|
|
|
82
87
|
return {}, {}
|
|
83
88
|
|
|
84
89
|
data = raw_inputs["targets"]
|
|
90
|
+
assert isinstance(data, list)
|
|
85
91
|
for feat in data:
|
|
86
92
|
if feat.properties is None or self.filters is None:
|
|
87
93
|
continue
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Segmentation task."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Mapping
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
@@ -9,7 +10,13 @@ import torchmetrics.classification
|
|
|
9
10
|
from torchmetrics import Metric, MetricCollection
|
|
10
11
|
|
|
11
12
|
from rslearn.models.component import FeatureMaps, Predictor
|
|
12
|
-
from rslearn.train.model_context import
|
|
13
|
+
from rslearn.train.model_context import (
|
|
14
|
+
ModelContext,
|
|
15
|
+
ModelOutput,
|
|
16
|
+
RasterImage,
|
|
17
|
+
SampleMetadata,
|
|
18
|
+
)
|
|
19
|
+
from rslearn.utils import Feature
|
|
13
20
|
|
|
14
21
|
from .task import BasicTask
|
|
15
22
|
|
|
@@ -108,7 +115,7 @@ class SegmentationTask(BasicTask):
|
|
|
108
115
|
|
|
109
116
|
def process_inputs(
|
|
110
117
|
self,
|
|
111
|
-
raw_inputs:
|
|
118
|
+
raw_inputs: Mapping[str, RasterImage | list[Feature]],
|
|
112
119
|
metadata: SampleMetadata,
|
|
113
120
|
load_targets: bool = True,
|
|
114
121
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -126,8 +133,10 @@ class SegmentationTask(BasicTask):
|
|
|
126
133
|
if not load_targets:
|
|
127
134
|
return {}, {}
|
|
128
135
|
|
|
129
|
-
assert raw_inputs["targets"]
|
|
130
|
-
|
|
136
|
+
assert isinstance(raw_inputs["targets"], RasterImage)
|
|
137
|
+
assert raw_inputs["targets"].image.shape[0] == 1
|
|
138
|
+
assert raw_inputs["targets"].image.shape[1] == 1
|
|
139
|
+
labels = raw_inputs["targets"].image[0, 0, :, :].long()
|
|
131
140
|
|
|
132
141
|
if self.class_id_mapping is not None:
|
|
133
142
|
new_labels = labels.clone()
|
rslearn/train/tasks/task.py
CHANGED
|
@@ -7,7 +7,7 @@ import numpy.typing as npt
|
|
|
7
7
|
import torch
|
|
8
8
|
from torchmetrics import MetricCollection
|
|
9
9
|
|
|
10
|
-
from rslearn.train.model_context import SampleMetadata
|
|
10
|
+
from rslearn.train.model_context import RasterImage, SampleMetadata
|
|
11
11
|
from rslearn.utils import Feature
|
|
12
12
|
|
|
13
13
|
|
|
@@ -21,7 +21,7 @@ class Task:
|
|
|
21
21
|
|
|
22
22
|
def process_inputs(
|
|
23
23
|
self,
|
|
24
|
-
raw_inputs: dict[str,
|
|
24
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
25
25
|
metadata: SampleMetadata,
|
|
26
26
|
load_targets: bool = True,
|
|
27
27
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -1,12 +1,23 @@
|
|
|
1
1
|
"""Concatenate bands across multiple image inputs."""
|
|
2
2
|
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
3
5
|
from typing import Any
|
|
4
6
|
|
|
5
7
|
import torch
|
|
6
8
|
|
|
9
|
+
from rslearn.train.model_context import RasterImage
|
|
10
|
+
|
|
7
11
|
from .transform import Transform, read_selector, write_selector
|
|
8
12
|
|
|
9
13
|
|
|
14
|
+
class ConcatenateDim(Enum):
|
|
15
|
+
"""Enum for concatenation dimensions."""
|
|
16
|
+
|
|
17
|
+
CHANNEL = 0
|
|
18
|
+
TIME = 1
|
|
19
|
+
|
|
20
|
+
|
|
10
21
|
class Concatenate(Transform):
|
|
11
22
|
"""Concatenate bands across multiple image inputs."""
|
|
12
23
|
|
|
@@ -14,6 +25,7 @@ class Concatenate(Transform):
|
|
|
14
25
|
self,
|
|
15
26
|
selections: dict[str, list[int]],
|
|
16
27
|
output_selector: str,
|
|
28
|
+
concatenate_dim: ConcatenateDim | int = ConcatenateDim.TIME,
|
|
17
29
|
):
|
|
18
30
|
"""Initialize a new Concatenate.
|
|
19
31
|
|
|
@@ -21,10 +33,16 @@ class Concatenate(Transform):
|
|
|
21
33
|
selections: map from selector to list of band indices in that input to
|
|
22
34
|
retain, or empty list to use all bands.
|
|
23
35
|
output_selector: the output selector under which to save the concatenate image.
|
|
36
|
+
concatenate_dim: the dimension against which to concatenate the inputs
|
|
24
37
|
"""
|
|
25
38
|
super().__init__()
|
|
26
39
|
self.selections = selections
|
|
27
40
|
self.output_selector = output_selector
|
|
41
|
+
self.concatenate_dim = (
|
|
42
|
+
concatenate_dim.value
|
|
43
|
+
if isinstance(concatenate_dim, ConcatenateDim)
|
|
44
|
+
else concatenate_dim
|
|
45
|
+
)
|
|
28
46
|
|
|
29
47
|
def forward(
|
|
30
48
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
@@ -36,14 +54,36 @@ class Concatenate(Transform):
|
|
|
36
54
|
target_dict: the target
|
|
37
55
|
|
|
38
56
|
Returns:
|
|
39
|
-
|
|
57
|
+
concatenated (input_dicts, target_dicts) tuple. If one of the
|
|
58
|
+
specified inputs is a RasterImage, a RasterImage will be returned.
|
|
59
|
+
Otherwise it will be a torch.Tensor.
|
|
40
60
|
"""
|
|
41
61
|
images = []
|
|
62
|
+
return_raster_image: bool = False
|
|
63
|
+
timestamps: list[tuple[datetime, datetime]] | None = None
|
|
42
64
|
for selector, wanted_bands in self.selections.items():
|
|
43
65
|
image = read_selector(input_dict, target_dict, selector)
|
|
44
|
-
if
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
66
|
+
if isinstance(image, torch.Tensor):
|
|
67
|
+
if wanted_bands:
|
|
68
|
+
image = image[wanted_bands, :, :]
|
|
69
|
+
images.append(image)
|
|
70
|
+
elif isinstance(image, RasterImage):
|
|
71
|
+
return_raster_image = True
|
|
72
|
+
if wanted_bands:
|
|
73
|
+
images.append(image.image[wanted_bands, :, :])
|
|
74
|
+
else:
|
|
75
|
+
images.append(image.image)
|
|
76
|
+
if timestamps is None:
|
|
77
|
+
if image.timestamps is not None:
|
|
78
|
+
# assume all concatenated modalities have the same
|
|
79
|
+
# number of timestamps
|
|
80
|
+
timestamps = image.timestamps
|
|
81
|
+
if return_raster_image:
|
|
82
|
+
result = RasterImage(
|
|
83
|
+
torch.concatenate(images, dim=self.concatenate_dim),
|
|
84
|
+
timestamps=timestamps,
|
|
85
|
+
)
|
|
86
|
+
else:
|
|
87
|
+
result = torch.concatenate(images, dim=self.concatenate_dim)
|
|
48
88
|
write_selector(input_dict, target_dict, self.output_selector, result)
|
|
49
89
|
return input_dict, target_dict
|