rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/lightning_module.py
CHANGED

```diff
@@ -365,7 +365,7 @@ class RslearnLightningModule(L.LightningModule):
         for image_suffix, image in images.items():
             out_fname = os.path.join(
                 self.visualize_dir,
-                f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
+                f"{metadata.window_name}_{metadata.crop_bounds[0]}_{metadata.crop_bounds[1]}_{image_suffix}.png",
             )
             Image.fromarray(image).save(out_fname)
```
rslearn/train/model_context.py
CHANGED

```diff
@@ -43,6 +43,22 @@ class RasterImage:
             raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
         return self.image[:, 0]
 
+    def get_hw_tensor(self) -> torch.Tensor:
+        """Get a 2D HW tensor from a single-channel, single-timestep RasterImage.
+
+        This function checks that C=1 and T=1, then returns the HW tensor.
+        Useful for per-pixel labels like segmentation masks.
+        """
+        if self.image.shape[0] != 1:
+            raise ValueError(
+                f"Expected single channel (C=1), got {self.image.shape[0]}"
+            )
+        if self.image.shape[1] != 1:
+            raise ValueError(
+                f"Expected single timestep (T=1), got {self.image.shape[1]}"
+            )
+        return self.image[0, 0]
+
 
 @dataclass
 class SampleMetadata:
@@ -51,9 +67,9 @@ class SampleMetadata:
     window_group: str
     window_name: str
     window_bounds: PixelBounds
-    patch_bounds: PixelBounds
-    patch_idx: int
-    num_patches_in_window: int
+    crop_bounds: PixelBounds
+    crop_idx: int
+    num_crops_in_window: int
     time_range: tuple[datetime, datetime] | None
     projection: Projection
```
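For orientation, here is a minimal standalone sketch of the contract the new `get_hw_tensor` helper enforces (a re-implementation for illustration, not the packaged class; the CTHW layout and shape checks mirror the diff above):

```python
import torch

def get_hw_tensor(image: torch.Tensor) -> torch.Tensor:
    """Reduce a CTHW tensor with C=1 and T=1 to its 2D HW plane."""
    if image.shape[0] != 1:
        raise ValueError(f"Expected single channel (C=1), got {image.shape[0]}")
    if image.shape[1] != 1:
        raise ValueError(f"Expected single timestep (T=1), got {image.shape[1]}")
    return image[0, 0]

# A segmentation mask stored as CTHW with singleton channel/time axes:
mask = torch.zeros((1, 1, 64, 64))
assert get_hw_tensor(mask).shape == (64, 64)
```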
rslearn/train/prediction_writer.py
CHANGED

```diff
@@ -1,6 +1,7 @@
 """rslearn PredictionWriter implementation."""
 
 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from pathlib import Path
@@ -39,20 +40,20 @@ logger = get_logger(__name__)
 
 
 @dataclass
-class PendingPatchOutput:
-    """A patch output that hasn't been merged yet."""
+class PendingCropOutput:
+    """A crop output that hasn't been merged yet."""
 
     bounds: PixelBounds
     output: Any
 
 
-class PatchPredictionMerger:
-    """Base class for merging predictions from multiple patches."""
+class CropPredictionMerger:
+    """Base class for merging predictions from multiple crops."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> Any:
         """Merge the outputs.
```
```diff
@@ -68,39 +69,60 @@ class PatchPredictionMerger:
         raise NotImplementedError
 
 
-class VectorMerger(PatchPredictionMerger):
+class VectorMerger(CropPredictionMerger):
     """Merger for vector data that simply concatenates the features."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> list[Feature]:
         """Concatenate the vector features."""
         return [feat for output in outputs for feat in output.output]
 
 
-class RasterMerger(PatchPredictionMerger):
+class RasterMerger(CropPredictionMerger):
     """Merger for raster data that copies the rasters to the output."""
 
-    def __init__(self, padding: int | None = None, downsample_factor: int = 1):
+    def __init__(
+        self,
+        overlap_pixels: int | None = None,
+        downsample_factor: int = 1,
+        # Deprecated parameter (for backwards compatibility)
+        padding: int | None = None,
+    ):
         """Create a new RasterMerger.
 
         Args:
-            padding: …
-            …
-            …
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference. Half of this overlap is removed from each
+                crop during merging (except at window boundaries where the full crop
+                is retained).
             downsample_factor: the factor by which the rasters output by the task are
                 lower in resolution relative to the window resolution.
+            padding: deprecated, use overlap_pixels instead. The old padding value
+                equals overlap_pixels // 2.
         """
-        self.padding = padding
+        # Handle deprecated padding parameter
+        if padding is not None:
+            warnings.warn(
+                "padding is deprecated, use overlap_pixels instead. "
+                "Note: overlap_pixels = padding * 2",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both padding and overlap_pixels")
+            overlap_pixels = padding * 2
+
+        self.overlap_pixels = overlap_pixels
         self.downsample_factor = downsample_factor
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> npt.NDArray:
         """Merge the raster outputs."""
```
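The deprecation shim implies a simple migration rule: `overlap_pixels = padding * 2`. A hedged usage sketch (assuming `RasterMerger` is importable from `rslearn.train.prediction_writer`, the module this diff modifies):

```python
from rslearn.train.prediction_writer import RasterMerger

# Old spelling: still works, but now emits a FutureWarning.
merger_old = RasterMerger(padding=16)
# Equivalent new spelling.
merger_new = RasterMerger(overlap_pixels=32)
assert merger_old.overlap_pixels == merger_new.overlap_pixels == 32
```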
```diff
@@ -114,6 +136,12 @@ class RasterMerger(PatchPredictionMerger):
             dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
         )
 
+        # Compute how many pixels to trim from each side.
+        # We remove half of the overlap from each side (not at window boundaries).
+        trim_pixels = (
+            self.overlap_pixels // 2 if self.overlap_pixels is not None else None
+        )
+
         # Ensure the outputs are sorted by height then width.
         # This way when we merge we can be sure that outputs that are lower or further
         # to the right will overwrite earlier outputs.
@@ -123,18 +151,18 @@ class RasterMerger(PatchPredictionMerger):
         for output in sorted_outputs:
             # So now we just need to compute the src_offset to copy.
             # If the output is not on the left or top boundary, then we should apply
-            # the padding.
+            # the trim (if set).
             src = output.output
             src_offset = (
                 output.bounds[0] // self.downsample_factor,
                 output.bounds[1] // self.downsample_factor,
             )
-            if self.padding is not None and output.bounds[0] != window.bounds[0]:
-                src = src[:, :, self.padding:]
-                src_offset = (src_offset[0] + self.padding, src_offset[1])
-            if self.padding is not None and output.bounds[1] != window.bounds[1]:
-                src = src[:, self.padding:, :]
-                src_offset = (src_offset[0], src_offset[1] + self.padding)
+            if trim_pixels is not None and output.bounds[0] != window.bounds[0]:
+                src = src[:, :, trim_pixels:]
+                src_offset = (src_offset[0] + trim_pixels, src_offset[1])
+            if trim_pixels is not None and output.bounds[1] != window.bounds[1]:
+                src = src[:, trim_pixels:, :]
+                src_offset = (src_offset[0], src_offset[1] + trim_pixels)
 
             copy_spatial_array(
                 src=src,
```
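A small worked example of the trimming arithmetic in this hunk (plain Python, with made-up crop coordinates): half of the overlap is cut from the leading edge of any crop that does not start at the window boundary.

```python
overlap_pixels = 32
trim_pixels = overlap_pixels // 2  # 16

window_x0 = 0
# (x0, y0, x1, y1) crop bounds; the second crop starts at 128 - 32 = 96.
crops = [(0, 0, 128, 128), (96, 0, 224, 128)]

for x0, _, _, _ in crops:
    trimmed = 0 if x0 == window_x0 else trim_pixels
    print(f"crop at x0={x0}: drop {trimmed} leading columns before merging")
# crop at x0=0: drop 0 leading columns before merging
# crop at x0=96: drop 16 leading columns before merging
```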
```diff
@@ -162,7 +190,7 @@ class RslearnWriter(BasePredictionWriter):
         output_layer: str,
         path_options: dict[str, Any] | None = None,
         selector: list[str] | None = None,
-        merger: PatchPredictionMerger | None = None,
+        merger: CropPredictionMerger | None = None,
         output_path: str | Path | None = None,
         layer_config: LayerConfig | None = None,
         storage_config: StorageConfig | None = None,
@@ -175,7 +203,7 @@ class RslearnWriter(BasePredictionWriter):
             path_options: additional options for path to pass to fsspec
             selector: keys to access the desired output in the output dict if needed.
                 e.g ["key1", "key2"] gets output["key1"]["key2"]
-            merger: merger to use to merge outputs from overlapped patches.
+            merger: merger to use to merge outputs from overlapped crops.
             output_path: optional custom path for writing predictions. If provided,
                 predictions will be written to this path instead of deriving from dataset path.
             layer_config: optional layer configuration. If provided, this config will be
@@ -217,9 +245,9 @@ class RslearnWriter(BasePredictionWriter):
             self.merger = VectorMerger()
 
         # Map from window name to pending data to write.
-        # This is used when windows are split up into patches, so the data from all the
-        # patches of each window need to be reconstituted.
-        self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
+        # This is used when windows are split up into crops, so the data from all the
+        # crops of each window need to be reconstituted.
+        self.pending_outputs: dict[str, list[PendingCropOutput]] = {}
 
     def _get_layer_config_and_dataset_storage(
         self,
@@ -327,7 +355,7 @@ class RslearnWriter(BasePredictionWriter):
                 will be processed by the task to obtain a vector (list[Feature]) or
                 raster (npt.NDArray) output.
             metadatas: corresponding list of metadatas from the batch describing the
-                patches that were processed.
+                crops that were processed.
         """
         # Process the predictions into outputs that can be written.
         outputs: list = [
@@ -349,17 +377,17 @@ class RslearnWriter(BasePredictionWriter):
             )
             self.process_output(
                 window,
-                metadata.patch_idx,
-                metadata.num_patches_in_window,
-                metadata.patch_bounds,
+                metadata.crop_idx,
+                metadata.num_crops_in_window,
+                metadata.crop_bounds,
                 output,
             )
 
     def process_output(
         self,
         window: Window,
-        patch_idx: int,
-        num_patches: int,
+        crop_idx: int,
+        num_crops: int,
         cur_bounds: PixelBounds,
         output: npt.NDArray | list[Feature],
     ) -> None:
@@ -367,28 +395,28 @@ class RslearnWriter(BasePredictionWriter):
 
         Args:
             window: the window that the output pertains to.
-            patch_idx: the index of this patch for the window.
-            num_patches: the total number of patches to be processed for the window.
-            cur_bounds: the bounds of the current patch.
+            crop_idx: the index of this crop for the window.
+            num_crops: the total number of crops to be processed for the window.
+            cur_bounds: the bounds of the current crop.
             output: the output data.
         """
-        # Incorporate the output into our list of pending patch outputs.
+        # Incorporate the output into our list of pending crop outputs.
         if window.name not in self.pending_outputs:
             self.pending_outputs[window.name] = []
-        self.pending_outputs[window.name].append(PendingPatchOutput(cur_bounds, output))
+        self.pending_outputs[window.name].append(PendingCropOutput(cur_bounds, output))
         logger.debug(
-            f"Stored PendingPatchOutput for patch #{patch_idx}/{num_patches} at window {window.name}"
+            f"Stored PendingCropOutput for crop #{crop_idx}/{num_crops} at window {window.name}"
         )
 
-        if patch_idx < num_patches - 1:
+        if crop_idx < num_crops - 1:
             return
 
-        # This is the last patch so it's time to write it.
+        # This is the last crop so it's time to write it.
         # First get the pending output and clear it.
        pending_output = self.pending_outputs[window.name]
         del self.pending_outputs[window.name]
 
-        # Merge outputs from overlapped patches if merger is set.
+        # Merge outputs from overlapped crops if merger is set.
         logger.debug(f"Merging and writing for window {window.name}")
         merged_output = self.merger.merge(window, pending_output, self.layer_config)
```
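The accumulate-then-flush pattern `process_output` follows can be summarized with a simplified sketch (not the packaged implementation; the real method goes on to merge and write the result):

```python
pending: dict[str, list] = {}

def process_output(window_name: str, crop_idx: int, num_crops: int, output) -> None:
    # Buffer this crop's output under its window.
    pending.setdefault(window_name, []).append(output)
    if crop_idx < num_crops - 1:
        return  # more crops of this window are still coming
    # Last crop: take and clear the buffer, then merge/write.
    outputs = pending.pop(window_name)
    print(f"merging {len(outputs)} crop outputs for {window_name}")

for i in range(4):
    process_output("window_0", crop_idx=i, num_crops=4, output=object())
```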
rslearn/train/tasks/classification.py
CHANGED

```diff
@@ -201,7 +201,7 @@ class ClassificationTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {
```
rslearn/train/tasks/detection.py
CHANGED

```diff
@@ -128,7 +128,7 @@ class DetectionTask(BasicTask):
         if not load_targets:
             return {}, {}
 
-        bounds = metadata.patch_bounds
+        bounds = metadata.crop_bounds
 
         boxes = []
         class_labels = []
@@ -244,10 +244,10 @@ class DetectionTask(BasicTask):
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata.patch_bounds[0] + float(box[0]),
-                metadata.patch_bounds[1] + float(box[1]),
-                metadata.patch_bounds[0] + float(box[2]),
-                metadata.patch_bounds[1] + float(box[3]),
+                metadata.crop_bounds[0] + float(box[0]),
+                metadata.crop_bounds[1] + float(box[1]),
+                metadata.crop_bounds[0] + float(box[2]),
+                metadata.crop_bounds[1] + float(box[3]),
             )
             geom = STGeometry(metadata.projection, shp, None)
             properties: dict[str, Any] = {
```
rslearn/train/tasks/per_pixel_regression.py
CHANGED

```diff
@@ -66,20 +66,18 @@ class PerPixelRegressionTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-        assert raw_inputs["targets"].image.shape[0] == 1
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
+        labels = raw_inputs["targets"].get_hw_tensor().float() * self.scale_factor
 
         if self.nodata_value is not None:
-            valid = (
-                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
-            ).float()
+            valid = (raw_inputs["targets"].get_hw_tensor() != self.nodata_value).float()
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so values and valid can be
+        # used in image transforms.
         return {}, {
-            "values": labels,
-            "valid": valid,
+            "values": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(
```
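The `[None, None, :, :]` indexing used here is the standard trick for promoting an HW label plane to the CTHW layout; a quick sketch with plain tensors:

```python
import torch

labels = torch.rand(64, 64)          # HW per-pixel targets
cthw = labels[None, None, :, :]      # -> (1, 1, 64, 64): C=1, T=1
assert cthw.shape == (1, 1, 64, 64)
assert torch.equal(cthw[0, 0], labels)  # get_hw_tensor() recovers the plane
```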
```diff
@@ -121,7 +119,7 @@ class PerPixelRegressionTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_values = target_dict["values"].cpu().numpy()
+        gt_values = target_dict["values"].get_hw_tensor().cpu().numpy()
         pred_values = output.cpu().numpy()[0, :, :]
         gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
         pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
@@ -210,8 +208,10 @@ class PerPixelRegressionHead(Predictor):
 
         losses = {}
         if targets:
-            labels = torch.stack([target["values"] for target in targets])
-            mask = torch.stack([target["valid"] for target in targets])
+            labels = torch.stack(
+                [target["values"].get_hw_tensor() for target in targets]
+            )
+            mask = torch.stack([target["valid"].get_hw_tensor() for target in targets])
 
             if self.loss_mode == "mse":
                 scores = torch.square(outputs - labels)
@@ -262,14 +262,14 @@ class PerPixelRegressionMetricWrapper(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["values"] for target in targets])
+        labels = torch.stack([target["values"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         if len(preds.shape) == 4:
             assert preds.shape[1] == 1
             preds = preds[:, 0, :, :]
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds[mask]
         labels = labels[mask]
         if len(preds) == 0:
```
rslearn/train/tasks/regression.py
CHANGED

```diff
@@ -130,7 +130,7 @@ class RegressionTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {
```
rslearn/train/tasks/segmentation.py
CHANGED

```diff
@@ -128,9 +128,7 @@ class SegmentationTask(BasicTask):
             return {}, {}
 
         assert isinstance(raw_inputs["targets"], RasterImage)
-        assert raw_inputs["targets"].image.shape[0] == 1
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].long()
+        labels = raw_inputs["targets"].get_hw_tensor().long()
 
         if self.class_id_mapping is not None:
             new_labels = labels.clone()
@@ -146,9 +144,11 @@ class SegmentationTask(BasicTask):
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
 
+        # Wrap in RasterImage with CTHW format (C=1, T=1) so classes and valid can be
+        # used in image transforms.
         return {}, {
-            "classes": labels,
-            "valid": valid,
+            "classes": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }
 
     def process_output(
@@ -206,7 +206,7 @@ class SegmentationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_classes = target_dict["classes"].cpu().numpy()
+        gt_classes = target_dict["classes"].get_hw_tensor().cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
         pred_vis = np.zeros(
@@ -291,12 +291,19 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
-    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+    def __init__(
+        self,
+        weights: list[float] | None = None,
+        dice_loss: bool = False,
+        temperature: float = 1.0,
+    ):
         """Initialize a new SegmentationTask.
 
         Args:
             weights: weights for cross entropy loss (Tensor of size C)
             dice_loss: weather to add dice loss to cross entropy
+            temperature: temperature scaling for softmax, does not affect the loss,
+                only the predictor outputs
         """
         super().__init__()
         if weights is not None:
@@ -304,6 +311,7 @@ class SegmentationHead(Predictor):
         else:
             self.weights = None
         self.dice_loss = dice_loss
+        self.temperature = temperature
 
     def forward(
         self,
@@ -332,12 +340,16 @@ class SegmentationHead(Predictor):
         )
 
         logits = intermediates.feature_maps[0]
-        outputs = torch.nn.functional.softmax(logits, dim=1)
+        outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)
 
         losses = {}
         if targets:
-            labels = torch.stack([target["classes"] for target in targets], dim=0)
-            mask = torch.stack([target["valid"] for target in targets], dim=0)
+            labels = torch.stack(
+                [target["classes"].get_hw_tensor() for target in targets], dim=0
+            )
+            mask = torch.stack(
+                [target["valid"].get_hw_tensor() for target in targets], dim=0
+            )
             per_pixel_loss = torch.nn.functional.cross_entropy(
                 logits, labels, weight=self.weights, reduction="none"
             )
```
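To illustrate the new `temperature` parameter: dividing logits by T > 1 flattens the softmax output without changing the argmax, and (per the docstring) the loss is still computed from the raw logits. A quick demonstration:

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])
for temperature in (1.0, 2.0):
    probs = torch.nn.functional.softmax(logits / temperature, dim=1)
    print(temperature, probs)
# T=1.0 -> roughly [0.79, 0.18, 0.04]
# T=2.0 -> a flatter distribution with the same ranking
```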
```diff
@@ -350,7 +362,8 @@ class SegmentationHead(Predictor):
             # the summed mask loss be zero.
             losses["cls"] = torch.sum(per_pixel_loss * mask)
             if self.dice_loss:
-                dice_loss = DiceLoss()(outputs, labels, mask)
+                softmax_woT = torch.nn.functional.softmax(logits, dim=1)
+                dice_loss = DiceLoss()(softmax_woT, labels, mask)
                 losses["dice"] = dice_loss
 
         return ModelOutput(
@@ -401,12 +414,12 @@ class SegmentationMetric(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["classes"] for target in targets])
+        labels = torch.stack([target["classes"].get_hw_tensor() for target in targets])
 
         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         # Prediction is changed from BCHW to BHWC so we can select the valid BHW mask.
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds.permute(0, 2, 3, 1)[mask]
         labels = labels[mask]
         if len(preds) == 0:
```
rslearn/train/transforms/concatenate.py
CHANGED

```diff
@@ -54,36 +54,26 @@ class Concatenate(Transform):
             target_dict: the target
 
         Returns:
-            …
-            …
-                Otherwise it will be a torch.Tensor.
+            (input_dicts, target_dicts) where the entry corresponding to
+                output_selector contains the concatenated RasterImage.
         """
-        …
-        return_raster_image: bool = False
+        tensors: list[torch.Tensor] = []
         timestamps: list[tuple[datetime, datetime]] | None = None
+
         for selector, wanted_bands in self.selections.items():
             image = read_selector(input_dict, target_dict, selector)
-            if …
-                …
-                # number of timestamps
-                timestamps = image.timestamps
-        if return_raster_image:
-            result = RasterImage(
-                torch.concatenate(images, dim=self.concatenate_dim),
-                timestamps=timestamps,
-            )
-        else:
-            result = torch.concatenate(images, dim=self.concatenate_dim)
+            if wanted_bands:
+                tensors.append(image.image[wanted_bands, :, :])
+            else:
+                tensors.append(image.image)
+            if timestamps is None and image.timestamps is not None:
+                # assume all concatenated modalities have the same
+                # number of timestamps
+                timestamps = image.timestamps
+
+        result = RasterImage(
+            torch.concatenate(tensors, dim=self.concatenate_dim),
+            timestamps=timestamps,
+        )
         write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
```
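The rewritten transform now always returns a `RasterImage`; a standalone sketch of the band-selection concatenation it performs (hypothetical shapes, plain tensors in CTHW layout):

```python
import torch

s2 = torch.rand(12, 1, 64, 64)   # e.g. a 12-band modality (CTHW)
s1 = torch.rand(2, 1, 64, 64)    # e.g. a 2-band modality (CTHW)

tensors = [
    s2[[3, 2, 1], :, :],  # a wanted_bands selection picks channels 3, 2, 1
    s1,                   # an empty selection keeps all channels
]
result = torch.concatenate(tensors, dim=0)  # concatenate_dim=0: channels
assert result.shape == (5, 1, 64, 64)
```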
rslearn/train/transforms/crop.py
CHANGED

```diff
@@ -71,9 +71,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, Any]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -84,22 +82,13 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        if isinstance(image, RasterImage):
-            image.image = torchvision.transforms.functional.crop(
-                image.image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
-        else:
-            image = torchvision.transforms.functional.crop(
-                image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
+        image.image = torchvision.transforms.functional.crop(
+            image.image,
+            top=remove_from_top,
+            left=remove_from_left,
+            height=crop_size,
+            width=crop_size,
+        )
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
```
rslearn/train/transforms/flip.py
CHANGED

```diff
@@ -57,16 +57,10 @@ class Flip(Transform):
             image: the image to transform.
             state: the sampled state.
         """
-        if isinstance(image, RasterImage):
-            if state["horizontal"]:
-                image.image = torch.flip(image.image, dims=[-1])
-            if state["vertical"]:
-                image.image = torch.flip(image.image, dims=[-2])
-        elif isinstance(image, torch.Tensor):
-            if state["horizontal"]:
-                image = torch.flip(image, dims=[-1])
-            if state["vertical"]:
-                image = torch.flip(image, dims=[-2])
+        if state["horizontal"]:
+            image.image = torch.flip(image.image, dims=[-1])
+        if state["vertical"]:
+            image.image = torch.flip(image.image, dims=[-2])
         return image
 
     def apply_boxes(
```
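Since all images now arrive as CTHW `RasterImage`s, the flip reduces to `torch.flip` over the trailing spatial axes; a minimal sketch with a plain tensor:

```python
import torch

image = torch.rand(3, 2, 64, 64)            # CTHW
flipped_h = torch.flip(image, dims=[-1])    # horizontal: width axis
flipped_v = torch.flip(image, dims=[-2])    # vertical: height axis
assert flipped_h.shape == flipped_v.shape == image.shape
```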
|