rslearn 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +151 -55
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/METADATA +1 -1
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/RECORD +33 -32
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/per_pixel_regression.py
CHANGED

@@ -66,20 +66,18 @@ class PerPixelRegressionTask(BasicTask):
             return {}, {}

         assert isinstance(raw_inputs["targets"], RasterImage)
-
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
+        labels = raw_inputs["targets"].get_hw_tensor().float() * self.scale_factor

         if self.nodata_value is not None:
-            valid = (
-                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
-            ).float()
+            valid = (raw_inputs["targets"].get_hw_tensor() != self.nodata_value).float()
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)

+        # Wrap in RasterImage with CTHW format (C=1, T=1) so values and valid can be
+        # used in image transforms.
         return {}, {
-            "values": labels,
-            "valid": valid,
+            "values": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }

     def process_output(
@@ -121,7 +119,7 @@ class PerPixelRegressionTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_values = target_dict["values"].cpu().numpy()
+        gt_values = target_dict["values"].get_hw_tensor().cpu().numpy()
         pred_values = output.cpu().numpy()[0, :, :]
         gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
         pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
@@ -210,8 +208,10 @@ class PerPixelRegressionHead(Predictor):

         losses = {}
         if targets:
-            labels = torch.stack([target["values"] for target in targets])
-            mask = torch.stack([target["valid"] for target in targets])
+            labels = torch.stack(
+                [target["values"].get_hw_tensor() for target in targets]
+            )
+            mask = torch.stack([target["valid"].get_hw_tensor() for target in targets])

         if self.loss_mode == "mse":
             scores = torch.square(outputs - labels)
@@ -262,14 +262,14 @@ class PerPixelRegressionMetricWrapper(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["values"] for target in targets])
+        labels = torch.stack([target["values"].get_hw_tensor() for target in targets])

         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         if len(preds.shape) == 4:
             assert preds.shape[1] == 1
             preds = preds[:, 0, :, :]
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds[mask]
         labels = labels[mask]
         if len(preds) == 0:
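Note: the recurring pattern in this release is that per-pixel targets, which used to be bare HW tensors, are now wrapped in RasterImage with a CTHW layout (C=1, T=1), and consumers call get_hw_tensor() to recover the HW view. A minimal sketch of the convention, using a simplified stand-in class rather than rslearn's actual RasterImage (which lives in rslearn.train.model_context and carries more state):

    # Minimal stand-in for rslearn.train.model_context.RasterImage, for
    # illustration only; the real class has more fields and methods.
    import torch


    class FakeRasterImage:
        def __init__(self, image: torch.Tensor, timestamps=None):
            # image is CTHW: (channels, timesteps, height, width).
            assert image.dim() == 4
            self.image = image
            self.timestamps = timestamps

        def get_hw_tensor(self) -> torch.Tensor:
            # Collapse a single-channel, single-timestep image back to HW.
            assert self.image.shape[0] == 1 and self.image.shape[1] == 1
            return self.image[0, 0, :, :]


    labels = torch.zeros(64, 64)  # HW target produced by a task
    wrapped = FakeRasterImage(labels[None, None, :, :], timestamps=None)  # -> CTHW
    assert wrapped.get_hw_tensor().shape == (64, 64)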
rslearn/train/tasks/segmentation.py
CHANGED

@@ -128,9 +128,7 @@ class SegmentationTask(BasicTask):
             return {}, {}

         assert isinstance(raw_inputs["targets"], RasterImage)
-
-        assert raw_inputs["targets"].image.shape[1] == 1
-        labels = raw_inputs["targets"].image[0, 0, :, :].long()
+        labels = raw_inputs["targets"].get_hw_tensor().long()

         if self.class_id_mapping is not None:
             new_labels = labels.clone()
@@ -146,9 +144,11 @@ class SegmentationTask(BasicTask):
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)

+        # Wrap in RasterImage with CTHW format (C=1, T=1) so classes and valid can be
+        # used in image transforms.
         return {}, {
-            "classes": labels,
-            "valid": valid,
+            "classes": RasterImage(labels[None, None, :, :], timestamps=None),
+            "valid": RasterImage(valid[None, None, :, :], timestamps=None),
         }

     def process_output(
@@ -206,7 +206,7 @@ class SegmentationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         if target_dict is None:
             raise ValueError("target_dict is required for visualization")
-        gt_classes = target_dict["classes"].cpu().numpy()
+        gt_classes = target_dict["classes"].get_hw_tensor().cpu().numpy()
         pred_classes = output.cpu().numpy().argmax(axis=0)
         gt_vis = np.zeros((gt_classes.shape[0], gt_classes.shape[1], 3), dtype=np.uint8)
         pred_vis = np.zeros(
@@ -291,12 +291,19 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""

-    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+    def __init__(
+        self,
+        weights: list[float] | None = None,
+        dice_loss: bool = False,
+        temperature: float = 1.0,
+    ):
         """Initialize a new SegmentationTask.

         Args:
             weights: weights for cross entropy loss (Tensor of size C)
             dice_loss: weather to add dice loss to cross entropy
+            temperature: temperature scaling for softmax, does not affect the loss,
+                only the predictor outputs
         """
         super().__init__()
         if weights is not None:
@@ -304,6 +311,7 @@ class SegmentationHead(Predictor):
         else:
             self.weights = None
         self.dice_loss = dice_loss
+        self.temperature = temperature

     def forward(
         self,
@@ -332,12 +340,16 @@ class SegmentationHead(Predictor):
         )

         logits = intermediates.feature_maps[0]
-        outputs = torch.nn.functional.softmax(logits, dim=1)
+        outputs = torch.nn.functional.softmax(logits / self.temperature, dim=1)

         losses = {}
         if targets:
-            labels = torch.stack([target["classes"] for target in targets], dim=0)
-            mask = torch.stack([target["valid"] for target in targets], dim=0)
+            labels = torch.stack(
+                [target["classes"].get_hw_tensor() for target in targets], dim=0
+            )
+            mask = torch.stack(
+                [target["valid"].get_hw_tensor() for target in targets], dim=0
+            )
             per_pixel_loss = torch.nn.functional.cross_entropy(
                 logits, labels, weight=self.weights, reduction="none"
             )
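Note: temperature only divides the logits feeding the softmax that produces predictor outputs; cross entropy still consumes the raw logits, so the loss and the argmax prediction are unchanged and only the output confidence shifts. A quick standalone check (plain torch, not rslearn code):

    import torch

    logits = torch.randn(2, 5, 8, 8)         # BCHW logits
    labels = torch.randint(0, 5, (2, 8, 8))  # BHW class labels

    for temperature in (1.0, 2.0):
        outputs = torch.softmax(logits / temperature, dim=1)
        # The loss uses raw logits, so it is identical for any temperature.
        loss = torch.nn.functional.cross_entropy(logits, labels)
        print(temperature, outputs.max().item(), loss.item())

    # Higher temperature flattens the output distribution (lower max
    # probability) without changing the loss or the argmax prediction.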
@@ -350,7 +362,8 @@ class SegmentationHead(Predictor):
         # the summed mask loss be zero.
         losses["cls"] = torch.sum(per_pixel_loss * mask)
         if self.dice_loss:
-            dice_loss = DiceLoss()(outputs, labels, mask)
+            softmax_woT = torch.nn.functional.softmax(logits, dim=1)
+            dice_loss = DiceLoss()(softmax_woT, labels, mask)
             losses["dice"] = dice_loss

         return ModelOutput(
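Note: the dice term is now computed from a softmax over the raw logits (softmax_woT), which keeps the loss independent of the temperature used for predictor outputs. A generic masked soft-dice sketch under the same BCHW/BHW shape convention; this is an illustration of the technique, not rslearn's DiceLoss implementation:

    import torch


    def soft_dice_loss(probs, labels, mask, eps=1e-6):
        # probs: BCHW softmax probabilities; labels: BHW class ids;
        # mask: BHW validity (1 = valid pixel).
        num_classes = probs.shape[1]
        onehot = torch.nn.functional.one_hot(labels, num_classes)
        onehot = onehot.permute(0, 3, 1, 2).float()  # BHWC -> BCHW
        m = mask[:, None, :, :]  # broadcast over the class dimension
        inter = (probs * onehot * m).sum(dim=(2, 3))
        denom = ((probs + onehot) * m).sum(dim=(2, 3))
        return 1 - ((2 * inter + eps) / (denom + eps)).mean()


    probs = torch.softmax(torch.randn(2, 5, 8, 8), dim=1)
    labels = torch.randint(0, 5, (2, 8, 8))
    mask = torch.ones(2, 8, 8)
    print(soft_dice_loss(probs, labels, mask))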
@@ -401,12 +414,12 @@ class SegmentationMetric(Metric):
         """
         if not isinstance(preds, torch.Tensor):
             preds = torch.stack(preds)
-        labels = torch.stack([target["classes"] for target in targets])
+        labels = torch.stack([target["classes"].get_hw_tensor() for target in targets])

         # Sub-select the valid labels.
         # We flatten the prediction and label images at valid pixels.
         # Prediction is changed from BCHW to BHWC so we can select the valid BHW mask.
-        mask = torch.stack([target["valid"] > 0 for target in targets])
+        mask = torch.stack([target["valid"].get_hw_tensor() > 0 for target in targets])
         preds = preds.permute(0, 2, 3, 1)[mask]
         labels = labels[mask]
         if len(preds) == 0:
rslearn/train/transforms/concatenate.py
CHANGED

@@ -54,36 +54,26 @@ class Concatenate(Transform):
             target_dict: the target

         Returns:
-            ...
-            ...
-            Otherwise it will be a torch.Tensor.
+            (input_dicts, target_dicts) where the entry corresponding to
+            output_selector contains the concatenated RasterImage.
         """
-        images: list[torch.Tensor] = []
-        return_raster_image: bool = False
+        tensors: list[torch.Tensor] = []
         timestamps: list[tuple[datetime, datetime]] | None = None
+
         for selector, wanted_bands in self.selections.items():
             image = read_selector(input_dict, target_dict, selector)
-            if isinstance(image, RasterImage):
-                ...
-                # number of timestamps
-                timestamps = image.timestamps
-        if return_raster_image:
-            result = RasterImage(
-                torch.concatenate(images, dim=self.concatenate_dim),
-                timestamps=timestamps,
-            )
-        else:
-            result = torch.concatenate(images, dim=self.concatenate_dim)
+            if wanted_bands:
+                tensors.append(image.image[wanted_bands, :, :])
+            else:
+                tensors.append(image.image)
+            if timestamps is None and image.timestamps is not None:
+                # assume all concatenated modalities have the same
+                # number of timestamps
+                timestamps = image.timestamps
+
+        result = RasterImage(
+            torch.concatenate(tensors, dim=self.concatenate_dim),
+            timestamps=timestamps,
+        )
         write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
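Note: with every modality arriving as a CTHW RasterImage, Concatenate reduces to stacking channel slices and carrying over one set of timestamps. A rough sketch of the new control flow on bare tensors, with concatenation along C (concatenate_dim=0 is assumed here):

    import torch

    # Two CTHW modalities with the same T, H, W.
    s2 = torch.randn(10, 3, 32, 32)  # e.g. 10 Sentinel-2 bands, 3 timesteps
    s1 = torch.randn(2, 3, 32, 32)   # e.g. 2 Sentinel-1 bands, 3 timesteps

    tensors = []
    for image, wanted_bands in ((s2, [2, 1, 0]), (s1, None)):
        if wanted_bands:
            tensors.append(image[wanted_bands, :, :])  # select along C
        else:
            tensors.append(image)

    result = torch.cat(tensors, dim=0)  # concatenate along C
    assert result.shape == (5, 3, 32, 32)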
rslearn/train/transforms/crop.py
CHANGED

@@ -71,9 +71,7 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }

-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, Any]) -> RasterImage:
         """Apply the sampled state on the specified image.

         Args:
@@ -84,22 +82,13 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        if isinstance(image, RasterImage):
-            image.image = torchvision.transforms.functional.crop(
-                image.image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
-        else:
-            image = torchvision.transforms.functional.crop(
-                image,
-                top=remove_from_top,
-                left=remove_from_left,
-                height=crop_size,
-                width=crop_size,
-            )
+        image.image = torchvision.transforms.functional.crop(
+            image.image,
+            top=remove_from_top,
+            left=remove_from_left,
+            height=crop_size,
+            width=crop_size,
+        )
         return image

     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
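Note: apply_image no longer branches on tensor vs. RasterImage; it crops the CTHW tensor directly. This works because torchvision's functional crop treats everything before the last two dimensions as batch dimensions:

    import torch
    import torchvision

    cthw = torch.randn(4, 2, 32, 32)  # CTHW
    cropped = torchvision.transforms.functional.crop(
        cthw, top=8, left=8, height=16, width=16
    )
    assert cropped.shape == (4, 2, 16, 16)  # leading C and T untouched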
rslearn/train/transforms/flip.py
CHANGED

@@ -57,16 +57,10 @@ class Flip(Transform):
             image: the image to transform.
             state: the sampled state.
         """
-        if isinstance(image, RasterImage):
-            if state["horizontal"]:
-                image.image = torch.flip(image.image, dims=[-1])
-            if state["vertical"]:
-                image.image = torch.flip(image.image, dims=[-2])
-        elif isinstance(image, torch.Tensor):
-            if state["horizontal"]:
-                image = torch.flip(image, dims=[-1])
-            if state["vertical"]:
-                image = torch.flip(image, dims=[-2])
+        if state["horizontal"]:
+            image.image = torch.flip(image.image, dims=[-1])
+        if state["vertical"]:
+            image.image = torch.flip(image.image, dims=[-2])
         return image

     def apply_boxes(
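Note: the same dimension convention makes the flip branch-free: dims=[-1] flips width and dims=[-2] flips height no matter how many leading dimensions the tensor has:

    import torch

    cthw = torch.arange(16.0).reshape(1, 1, 4, 4)
    h_flipped = torch.flip(cthw, dims=[-1])  # mirror left-right
    v_flipped = torch.flip(cthw, dims=[-2])  # mirror top-bottom
    assert torch.equal(h_flipped[0, 0, 0], cthw[0, 0, 0].flip(0))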
rslearn/train/transforms/mask.py
CHANGED

@@ -1,7 +1,5 @@
 """Mask transform."""

-import torch
-
 from rslearn.train.model_context import RasterImage
 from rslearn.train.transforms.transform import Transform, read_selector

@@ -32,9 +30,7 @@ class Mask(Transform):
         self.mask_selector = mask_selector
         self.mask_value = mask_value

-    def apply_image(
-        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_image(self, image: RasterImage, mask: RasterImage) -> RasterImage:
         """Apply the mask on the image.

         Args:
@@ -44,21 +40,19 @@ class Mask(Transform):
         Returns:
             masked image
         """
-        # Extract the mask tensor if it is a RasterImage.
-        if isinstance(mask, RasterImage):
-            mask = mask.image
+        # Extract the mask tensor (CTHW format)
+        mask_tensor = mask.image

-        if image.shape[0] != mask.shape[0]:
-            if mask.shape[0] != 1:
+        # Tile the mask to have same number of bands (C dimension) as the image.
+        if image.shape[0] != mask_tensor.shape[0]:
+            if mask_tensor.shape[0] != 1:
                 raise ValueError(
                     "expected mask to either have same bands as image, or one band"
                 )
-            mask = mask.repeat(image.shape[0], 1, 1)
+            # Repeat along C dimension, keep T, H, W the same
+            mask_tensor = mask_tensor.repeat(image.shape[0], 1, 1, 1)

-        if isinstance(image, torch.Tensor):
-            image[mask == 0] = self.mask_value
-        else:
-            image.image[mask == 0] = self.mask_value
+        image.image[mask_tensor == 0] = self.mask_value
         return image

     def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
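Note: the mask is expected to have either one band or the same number of bands as the image; a one-band mask is tiled across C with repeat(C, 1, 1, 1) before zero positions are overwritten with mask_value. The same steps on bare tensors:

    import torch

    image = torch.randn(3, 2, 8, 8)                # CTHW image, C=3
    mask = (torch.rand(1, 2, 8, 8) > 0.5).float()  # one-band CTHW mask

    if image.shape[0] != mask.shape[0]:
        # Repeat along C; T, H, W stay the same.
        mask = mask.repeat(image.shape[0], 1, 1, 1)

    image[mask == 0] = 0.0  # mask_value = 0 here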
rslearn/train/transforms/normalize.py
CHANGED

@@ -1,5 +1,6 @@
 """Normalization transforms."""

+import warnings
 from typing import Any

 import torch
@@ -35,14 +36,17 @@ class Normalize(Transform):
         bands: optionally restrict the normalization to these band indices. If set,
             mean and std must either be one value, or have length equal to the
             number of band indices passed here.
-        num_bands: the number of bands per image, if the input contains multiple images
-            in a time series. If set, then the bands list is repeated for each
-            image, e.g. if bands=[2] then we apply normalization on images[2],
-            images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
-            is not set, then we apply the mean and std on each image in the time
-            series.
+        num_bands: deprecated, no longer used. Will be removed after 2026-04-01.
         """
         super().__init__()
+
+        if num_bands is not None:
+            warnings.warn(
+                "num_bands is deprecated and no longer used. "
+                "It will be removed after 2026-04-01.",
+                FutureWarning,
+            )
+
         self.mean = torch.tensor(mean)
         self.std = torch.tensor(std)

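Note: passing the deprecated argument now only emits a FutureWarning; behavior is otherwise unchanged. A generic sketch of the pattern used by both Normalize and SelectBands (the Example class is illustrative, not rslearn API):

    import warnings


    class Example:
        """Illustrates the deprecation pattern used in this release."""

        def __init__(self, value, num_bands=None):
            if num_bands is not None:
                warnings.warn(
                    "num_bands is deprecated and no longer used. "
                    "It will be removed after 2026-04-01.",
                    FutureWarning,
                )
            self.value = value


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Example(1, num_bands=4)
    assert caught[0].category is FutureWarning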
@@ -55,92 +59,37 @@ class Normalize(Transform):

         self.selectors = selectors
         self.bands = torch.tensor(bands) if bands is not None else None
-        self.num_bands = num_bands

-    def apply_image(
-        self, image: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_image(self, image: RasterImage) -> RasterImage:
         """Normalize the specified image.

         Args:
             image: the image to transform.
         """
-        def _repeat_mean_and_std(
-            image_channels: int,
-            num_bands: int | None,
-            is_raster_image: bool,
-        ) -> tuple[torch.Tensor, torch.Tensor]:
-            # ...
-            ...
-            if len(self.mean.shape) == 0:
-                return self.mean, self.std
-            if num_bands is None:
-                return self.mean, self.std
-            num_images = image_channels // num_bands
-            if is_raster_image:
-                # add an extra T dimension, CTHW
-                return self.mean.repeat(num_images)[
-                    :, None, None, None
-                ], self.std.repeat(num_images)[:, None, None, None]
-            else:
-                # add an extra T dimension, CTHW
-                return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
-                    num_images
-                )[:, None, None]
+        # Get mean/std with singleton dims for broadcasting over CTHW.
+        if len(self.mean.shape) == 0:
+            # Scalar - broadcasts naturally.
+            mean, std = self.mean, self.std
+        else:
+            # Vector of length C - add singleton dims for T, H, W.
+            mean = self.mean[:, None, None, None]
+            std = self.std[:, None, None, None]

         if self.bands is not None:
-            # ...
-            ...
-            band_indices = torch.concatenate(
-                [
-                    band_indices + image_idx * self.num_bands
-                    for image_idx in range(num_images)
-                ],
-                dim=0,
-            )
-
-            # We use len(self.bands) here because that is how many bands per timestep
-            # we are actually processing with the mean/std.
-            mean, std = _repeat_mean_and_std(
-                image_channels=len(band_indices),
-                num_bands=len(self.bands),
-                is_raster_image=isinstance(image, RasterImage),
-            )
-            if isinstance(image, torch.Tensor):
-                image[band_indices] = (image[band_indices] - mean) / std
-                if self.valid_min is not None:
-                    image[band_indices] = torch.clamp(
-                        image[band_indices], min=self.valid_min, max=self.valid_max
-                    )
-            else:
-                image.image[band_indices] = (image.image[band_indices] - mean) / std
-                if self.valid_min is not None:
-                    image.image[band_indices] = torch.clamp(
-                        image.image[band_indices],
-                        min=self.valid_min,
-                        max=self.valid_max,
-                    )
+            # Normalize only specific band indices.
+            image.image[self.bands] = (image.image[self.bands] - mean) / std
+            if self.valid_min is not None:
+                image.image[self.bands] = torch.clamp(
+                    image.image[self.bands],
+                    min=self.valid_min,
+                    max=self.valid_max,
+                )
         else:
-            mean, std = _repeat_mean_and_std(
-                image_channels=image.shape[0],
-                num_bands=self.num_bands,
-                is_raster_image=isinstance(image, RasterImage),
-            )
-            if isinstance(image, torch.Tensor):
-                image = (image - mean) / std
-                if self.valid_min is not None:
-                    image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
-            else:
-                image.image = (image.image - mean) / std
-                if self.valid_min is not None:
-                    image.image = torch.clamp(
-                        image.image, min=self.valid_min, max=self.valid_max
-                    )
+            image.image = (image.image - mean) / std
+            if self.valid_min is not None:
+                image.image = torch.clamp(
+                    image.image, min=self.valid_min, max=self.valid_max
+                )
         return image

     def forward(
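Note: the rewrite drops per-timestep repetition entirely; a per-channel mean/std of length C gets three trailing singleton dimensions and broadcasts over T, H, and W of the CTHW tensor:

    import torch

    image = torch.randn(4, 3, 16, 16)  # CTHW, C=4
    mean = torch.tensor([0.1, 0.2, 0.3, 0.4])
    std = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # [:, None, None, None] turns shape (4,) into (4, 1, 1, 1), which
    # broadcasts against (4, T, H, W).
    normalized = (image - mean[:, None, None, None]) / std[:, None, None, None]
    assert normalized.shape == image.shape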
rslearn/train/transforms/pad.py
CHANGED

@@ -50,9 +50,7 @@ class Pad(Transform):
         """
         return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}

-    def apply_image(
-        self, image: RasterImage | torch.Tensor, state: dict[str, bool]
-    ) -> RasterImage | torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
         """Apply the sampled state on the specified image.

         Args:
@@ -105,16 +103,12 @@ class Pad(Transform):
         horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
         vertical_pad = (vertical_half, vertical_extra - vertical_half)

-        if isinstance(image, RasterImage):
-            image.image = apply_padding(
-                image.image, True, horizontal_pad[0], horizontal_pad[1]
-            )
-            image.image = apply_padding(
-                image.image, False, vertical_pad[0], vertical_pad[1]
-            )
-        else:
-            image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
-            image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
+        image.image = apply_padding(
+            image.image, True, horizontal_pad[0], horizontal_pad[1]
+        )
+        image.image = apply_padding(
+            image.image, False, vertical_pad[0], vertical_pad[1]
+        )
         return image

     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
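Note: apply_padding is an rslearn helper; judging from the call sites above it pads one axis at a time with explicit before/after amounts. A hypothetical equivalent built on torch.nn.functional.pad (the helper name and behavior here are assumptions, not rslearn's implementation):

    import torch
    import torch.nn.functional as F


    def apply_padding_sketch(t: torch.Tensor, horizontal: bool, before: int, after: int):
        # Pad the width (last dim) when horizontal, else the height
        # (second-to-last dim); leading dims like C and T are untouched.
        if horizontal:
            return F.pad(t, (before, after))
        return F.pad(t, (0, 0, before, after))


    cthw = torch.randn(3, 2, 10, 10)
    out = apply_padding_sketch(apply_padding_sketch(cthw, True, 1, 2), False, 3, 4)
    assert out.shape == (3, 2, 17, 13)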
rslearn/train/transforms/resize.py
CHANGED

@@ -2,7 +2,6 @@

 from typing import Any

-import torch
 import torchvision
 from torchvision.transforms import InterpolationMode

@@ -40,32 +39,16 @@ class Resize(Transform):
         self.selectors = selectors
         self.interpolation = INTERPOLATION_MODES[interpolation]

-    def apply_resize(
-        self, image: torch.Tensor | RasterImage
-    ) -> torch.Tensor | RasterImage:
+    def apply_resize(self, image: RasterImage) -> RasterImage:
         """Apply resizing on the specified image.

-        If the image is 2D, it is unsqueezed to 3D and then squeezed
-        back after resizing.
-
         Args:
             image: the image to transform.
         """
-        if isinstance(image, torch.Tensor):
-            if len(image.shape) == 2:
-                image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
-                result = torchvision.transforms.functional.resize(
-                    image, self.target_size, self.interpolation
-                )
-                return result.squeeze(0)  # (1, H, W) -> (H, W)
-            return torchvision.transforms.functional.resize(
-                image, self.target_size, self.interpolation
-            )
-        else:
-            image.image = torchvision.transforms.functional.resize(
-                image.image, self.target_size, self.interpolation
-            )
-            return image
+        image.image = torchvision.transforms.functional.resize(
+            image.image, self.target_size, self.interpolation
+        )
+        return image

     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
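Note: the 2D special case disappears because every image is now a 4D CTHW tensor, and torchvision's resize accepts any number of leading dimensions:

    import torch
    import torchvision
    from torchvision.transforms import InterpolationMode

    cthw = torch.randn(4, 2, 32, 32)
    resized = torchvision.transforms.functional.resize(
        cthw, [64, 64], InterpolationMode.BILINEAR
    )
    assert resized.shape == (4, 2, 64, 64)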
rslearn/train/transforms/select_bands.py
CHANGED

@@ -1,9 +1,8 @@
 """The SelectBands transform."""

+import warnings
 from typing import Any

-from rslearn.train.model_context import RasterImage
-
 from .transform import Transform, read_selector, write_selector


@@ -17,60 +16,41 @@ class SelectBands(Transform):
         output_selector: str = "image",
         num_bands_per_timestep: int | None = None,
     ):
-        """Initialize a new ...
+        """Initialize a new SelectBands.

         Args:
-            band_indices: the bands to select.
+            band_indices: the bands to select from the channel dimension.
             input_selector: the selector to read the input image.
             output_selector: the output selector under which to save the output image.
-            num_bands_per_timestep: ...
-                ...
-                band_indices are selected for each image in the time series.
+            num_bands_per_timestep: deprecated, no longer used. Will be removed after
+                2026-04-01.
         """
         super().__init__()
+
+        if num_bands_per_timestep is not None:
+            warnings.warn(
+                "num_bands_per_timestep is deprecated and no longer used. "
+                "It will be removed after 2026-04-01.",
+                FutureWarning,
+            )
+
         self.input_selector = input_selector
         self.output_selector = output_selector
         self.band_indices = band_indices
-        self.num_bands_per_timestep = num_bands_per_timestep

     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
     ) -> tuple[dict[str, Any], dict[str, Any]]:
-        """Apply ...
+        """Apply band selection over the inputs and targets.

         Args:
             input_dict: the input
             target_dict: the target

         Returns:
-            ...
+            (input_dicts, target_dicts) tuple with selected bands
         """
         image = read_selector(input_dict, target_dict, self.input_selector)
-        num_bands_per_timestep = (
-            self.num_bands_per_timestep
-            if self.num_bands_per_timestep is not None
-            else image.shape[0]
-        )
-        if isinstance(image, RasterImage):
-            assert num_bands_per_timestep == image.shape[0], (
-                "Expect a seperate dimension for timesteps in RasterImages."
-            )
-
-        if image.shape[0] % num_bands_per_timestep != 0:
-            raise ValueError(
-                f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
-            )
-
-        # Copy the band indices for each timestep in the input.
-        wanted_bands: list[int] = []
-        for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
-            wanted_bands.extend(
-                [(start_channel_idx + band_idx) for band_idx in self.band_indices]
-            )
-
-        if isinstance(image, RasterImage):
-            image.image = image.image[wanted_bands]
-        else:
-            image = image[wanted_bands]
+        image.image = image.image[self.band_indices]
         write_selector(input_dict, target_dict, self.output_selector, image)
         return input_dict, target_dict
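Note: band selection likewise collapses to a single indexing operation on the channel dimension, with no per-timestep index arithmetic:

    import torch

    image = torch.randn(10, 5, 16, 16)  # CTHW: 10 bands, 5 timesteps
    band_indices = [3, 2, 1]
    selected = image[band_indices]      # selects along C, keeps T, H, W
    assert selected.shape == (3, 5, 16, 16)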