rslearn 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +186 -49
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/detection.py +1 -18
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +27 -32
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/colors.py +20 -0
- rslearn/vis/__init__.py +1 -0
- rslearn/vis/normalization.py +127 -0
- rslearn/vis/render_raster_label.py +96 -0
- rslearn/vis/render_sensor_image.py +27 -0
- rslearn/vis/render_vector_label.py +439 -0
- rslearn/vis/utils.py +99 -0
- rslearn/vis/vis_server.py +574 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/METADATA +14 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/RECORD +42 -33
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/WHEEL +1 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
|
rslearn/train/transforms/normalize.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Normalization transforms."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
6
|
import torch
|
|
@@ -35,14 +36,17 @@ class Normalize(Transform):
|
|
|
35
36
|
bands: optionally restrict the normalization to these band indices. If set,
|
|
36
37
|
mean and std must either be one value, or have length equal to the
|
|
37
38
|
number of band indices passed here.
|
|
38
|
-
num_bands:
|
|
39
|
-
in a time series. If set, then the bands list is repeated for each
|
|
40
|
-
image, e.g. if bands=[2] then we apply normalization on images[2],
|
|
41
|
-
images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
|
|
42
|
-
is not set, then we apply the mean and std on each image in the time
|
|
43
|
-
series.
|
|
39
|
+
num_bands: deprecated, no longer used. Will be removed after 2026-04-01.
|
|
44
40
|
"""
|
|
45
41
|
super().__init__()
|
|
42
|
+
|
|
43
|
+
if num_bands is not None:
|
|
44
|
+
warnings.warn(
|
|
45
|
+
"num_bands is deprecated and no longer used. "
|
|
46
|
+
"It will be removed after 2026-04-01.",
|
|
47
|
+
FutureWarning,
|
|
48
|
+
)
|
|
49
|
+
|
|
46
50
|
self.mean = torch.tensor(mean)
|
|
47
51
|
self.std = torch.tensor(std)
|
|
48
52
|
|
|
@@ -55,92 +59,37 @@ class Normalize(Transform):
|
|
|
55
59
|
|
|
56
60
|
self.selectors = selectors
|
|
57
61
|
self.bands = torch.tensor(bands) if bands is not None else None
|
|
58
|
-
self.num_bands = num_bands
|
|
59
62
|
|
|
60
|
-
def apply_image(
|
|
61
|
-
self, image: torch.Tensor | RasterImage
|
|
62
|
-
) -> torch.Tensor | RasterImage:
|
|
63
|
+
def apply_image(self, image: RasterImage) -> RasterImage:
|
|
63
64
|
"""Normalize the specified image.
|
|
64
65
|
|
|
65
66
|
Args:
|
|
66
67
|
image: the image to transform.
|
|
67
68
|
"""
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
#
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if len(self.mean.shape) == 0:
|
|
77
|
-
return self.mean, self.std
|
|
78
|
-
if num_bands is None:
|
|
79
|
-
return self.mean, self.std
|
|
80
|
-
num_images = image_channels // num_bands
|
|
81
|
-
if is_raster_image:
|
|
82
|
-
# add an extra T dimension, CTHW
|
|
83
|
-
return self.mean.repeat(num_images)[
|
|
84
|
-
:, None, None, None
|
|
85
|
-
], self.std.repeat(num_images)[:, None, None, None]
|
|
86
|
-
else:
|
|
87
|
-
# add an extra T dimension, CTHW
|
|
88
|
-
return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
|
|
89
|
-
num_images
|
|
90
|
-
)[:, None, None]
|
|
69
|
+
# Get mean/std with singleton dims for broadcasting over CTHW.
|
|
70
|
+
if len(self.mean.shape) == 0:
|
|
71
|
+
# Scalar - broadcasts naturally.
|
|
72
|
+
mean, std = self.mean, self.std
|
|
73
|
+
else:
|
|
74
|
+
# Vector of length C - add singleton dims for T, H, W.
|
|
75
|
+
mean = self.mean[:, None, None, None]
|
|
76
|
+
std = self.std[:, None, None, None]
|
|
91
77
|
|
|
92
78
|
if self.bands is not None:
|
|
93
|
-
#
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
[
|
|
101
|
-
band_indices + image_idx * self.num_bands
|
|
102
|
-
for image_idx in range(num_images)
|
|
103
|
-
],
|
|
104
|
-
dim=0,
|
|
79
|
+
# Normalize only specific band indices.
|
|
80
|
+
image.image[self.bands] = (image.image[self.bands] - mean) / std
|
|
81
|
+
if self.valid_min is not None:
|
|
82
|
+
image.image[self.bands] = torch.clamp(
|
|
83
|
+
image.image[self.bands],
|
|
84
|
+
min=self.valid_min,
|
|
85
|
+
max=self.valid_max,
|
|
105
86
|
)
|
|
106
|
-
|
|
107
|
-
# We use len(self.bands) here because that is how many bands per timestep
|
|
108
|
-
# we are actually processing with the mean/std.
|
|
109
|
-
mean, std = _repeat_mean_and_std(
|
|
110
|
-
image_channels=len(band_indices),
|
|
111
|
-
num_bands=len(self.bands),
|
|
112
|
-
is_raster_image=isinstance(image, RasterImage),
|
|
113
|
-
)
|
|
114
|
-
if isinstance(image, torch.Tensor):
|
|
115
|
-
image[band_indices] = (image[band_indices] - mean) / std
|
|
116
|
-
if self.valid_min is not None:
|
|
117
|
-
image[band_indices] = torch.clamp(
|
|
118
|
-
image[band_indices], min=self.valid_min, max=self.valid_max
|
|
119
|
-
)
|
|
120
|
-
else:
|
|
121
|
-
image.image[band_indices] = (image.image[band_indices] - mean) / std
|
|
122
|
-
if self.valid_min is not None:
|
|
123
|
-
image.image[band_indices] = torch.clamp(
|
|
124
|
-
image.image[band_indices],
|
|
125
|
-
min=self.valid_min,
|
|
126
|
-
max=self.valid_max,
|
|
127
|
-
)
|
|
128
87
|
else:
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if isinstance(image, torch.Tensor):
|
|
135
|
-
image = (image - mean) / std
|
|
136
|
-
if self.valid_min is not None:
|
|
137
|
-
image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
|
|
138
|
-
else:
|
|
139
|
-
image.image = (image.image - mean) / std
|
|
140
|
-
if self.valid_min is not None:
|
|
141
|
-
image.image = torch.clamp(
|
|
142
|
-
image.image, min=self.valid_min, max=self.valid_max
|
|
143
|
-
)
|
|
88
|
+
image.image = (image.image - mean) / std
|
|
89
|
+
if self.valid_min is not None:
|
|
90
|
+
image.image = torch.clamp(
|
|
91
|
+
image.image, min=self.valid_min, max=self.valid_max
|
|
92
|
+
)
|
|
144
93
|
return image
|
|
145
94
|
|
|
146
95
|
def forward(
|
rslearn/train/transforms/pad.py
CHANGED
|
@@ -50,9 +50,7 @@ class Pad(Transform):
|
|
|
50
50
|
"""
|
|
51
51
|
return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
|
|
52
52
|
|
|
53
|
-
def apply_image(
|
|
54
|
-
self, image: RasterImage | torch.Tensor, state: dict[str, bool]
|
|
55
|
-
) -> RasterImage | torch.Tensor:
|
|
53
|
+
def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
|
|
56
54
|
"""Apply the sampled state on the specified image.
|
|
57
55
|
|
|
58
56
|
Args:
|
|
@@ -105,16 +103,12 @@ class Pad(Transform):
|
|
|
105
103
|
horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
|
|
106
104
|
vertical_pad = (vertical_half, vertical_extra - vertical_half)
|
|
107
105
|
|
|
108
|
-
|
|
109
|
-
image.image
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
image.image
|
|
113
|
-
|
|
114
|
-
)
|
|
115
|
-
else:
|
|
116
|
-
image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
|
|
117
|
-
image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
|
|
106
|
+
image.image = apply_padding(
|
|
107
|
+
image.image, True, horizontal_pad[0], horizontal_pad[1]
|
|
108
|
+
)
|
|
109
|
+
image.image = apply_padding(
|
|
110
|
+
image.image, False, vertical_pad[0], vertical_pad[1]
|
|
111
|
+
)
|
|
118
112
|
return image
|
|
119
113
|
|
|
120
114
|
def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
|
|
rslearn/train/transforms/resize.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
-
import torch
|
|
6
5
|
import torchvision
|
|
7
6
|
from torchvision.transforms import InterpolationMode
|
|
8
7
|
|
|
@@ -40,32 +39,16 @@ class Resize(Transform):
|
|
|
40
39
|
self.selectors = selectors
|
|
41
40
|
self.interpolation = INTERPOLATION_MODES[interpolation]
|
|
42
41
|
|
|
43
|
-
def apply_resize(
|
|
44
|
-
self, image: torch.Tensor | RasterImage
|
|
45
|
-
) -> torch.Tensor | RasterImage:
|
|
42
|
+
def apply_resize(self, image: RasterImage) -> RasterImage:
|
|
46
43
|
"""Apply resizing on the specified image.
|
|
47
44
|
|
|
48
|
-
If the image is 2D, it is unsqueezed to 3D and then squeezed
|
|
49
|
-
back after resizing.
|
|
50
|
-
|
|
51
45
|
Args:
|
|
52
46
|
image: the image to transform.
|
|
53
47
|
"""
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
image, self.target_size, self.interpolation
|
|
59
|
-
)
|
|
60
|
-
return result.squeeze(0) # (1, H, W) -> (H, W)
|
|
61
|
-
return torchvision.transforms.functional.resize(
|
|
62
|
-
image, self.target_size, self.interpolation
|
|
63
|
-
)
|
|
64
|
-
else:
|
|
65
|
-
image.image = torchvision.transforms.functional.resize(
|
|
66
|
-
image.image, self.target_size, self.interpolation
|
|
67
|
-
)
|
|
68
|
-
return image
|
|
48
|
+
image.image = torchvision.transforms.functional.resize(
|
|
49
|
+
image.image, self.target_size, self.interpolation
|
|
50
|
+
)
|
|
51
|
+
return image
|
|
69
52
|
|
|
70
53
|
def forward(
|
|
71
54
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
rslearn/train/transforms/select_bands.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
"""The SelectBands transform."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
|
-
from rslearn.train.model_context import RasterImage
|
|
6
|
-
|
|
7
6
|
from .transform import Transform, read_selector, write_selector
|
|
8
7
|
|
|
9
8
|
|
|
@@ -17,60 +16,41 @@ class SelectBands(Transform):
|
|
|
17
16
|
output_selector: str = "image",
|
|
18
17
|
num_bands_per_timestep: int | None = None,
|
|
19
18
|
):
|
|
20
|
-
"""Initialize a new
|
|
19
|
+
"""Initialize a new SelectBands.
|
|
21
20
|
|
|
22
21
|
Args:
|
|
23
|
-
band_indices: the bands to select.
|
|
22
|
+
band_indices: the bands to select from the channel dimension.
|
|
24
23
|
input_selector: the selector to read the input image.
|
|
25
24
|
output_selector: the output selector under which to save the output image.
|
|
26
|
-
num_bands_per_timestep:
|
|
27
|
-
|
|
28
|
-
band_indices are selected for each image in the time series.
|
|
25
|
+
num_bands_per_timestep: deprecated, no longer used. Will be removed after
|
|
26
|
+
2026-04-01.
|
|
29
27
|
"""
|
|
30
28
|
super().__init__()
|
|
29
|
+
|
|
30
|
+
if num_bands_per_timestep is not None:
|
|
31
|
+
warnings.warn(
|
|
32
|
+
"num_bands_per_timestep is deprecated and no longer used. "
|
|
33
|
+
"It will be removed after 2026-04-01.",
|
|
34
|
+
FutureWarning,
|
|
35
|
+
)
|
|
36
|
+
|
|
31
37
|
self.input_selector = input_selector
|
|
32
38
|
self.output_selector = output_selector
|
|
33
39
|
self.band_indices = band_indices
|
|
34
|
-
self.num_bands_per_timestep = num_bands_per_timestep
|
|
35
40
|
|
|
36
41
|
def forward(
|
|
37
42
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
38
43
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
39
|
-
"""Apply
|
|
44
|
+
"""Apply band selection over the inputs and targets.
|
|
40
45
|
|
|
41
46
|
Args:
|
|
42
47
|
input_dict: the input
|
|
43
48
|
target_dict: the target
|
|
44
49
|
|
|
45
50
|
Returns:
|
|
46
|
-
|
|
51
|
+
(input_dicts, target_dicts) tuple with selected bands
|
|
47
52
|
"""
|
|
48
53
|
image = read_selector(input_dict, target_dict, self.input_selector)
|
|
49
|
-
|
|
50
|
-
self.num_bands_per_timestep
|
|
51
|
-
if self.num_bands_per_timestep is not None
|
|
52
|
-
else image.shape[0]
|
|
53
|
-
)
|
|
54
|
-
if isinstance(image, RasterImage):
|
|
55
|
-
assert num_bands_per_timestep == image.shape[0], (
|
|
56
|
-
"Expect a seperate dimension for timesteps in RasterImages."
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
if image.shape[0] % num_bands_per_timestep != 0:
|
|
60
|
-
raise ValueError(
|
|
61
|
-
f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
# Copy the band indices for each timestep in the input.
|
|
65
|
-
wanted_bands: list[int] = []
|
|
66
|
-
for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
|
|
67
|
-
wanted_bands.extend(
|
|
68
|
-
[(start_channel_idx + band_idx) for band_idx in self.band_indices]
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
if isinstance(image, RasterImage):
|
|
72
|
-
image.image = image.image[wanted_bands]
|
|
73
|
-
else:
|
|
74
|
-
image = image[wanted_bands]
|
|
54
|
+
image.image = image.image[self.band_indices]
|
|
75
55
|
write_selector(input_dict, target_dict, self.output_selector, image)
|
|
76
56
|
return input_dict, target_dict
|
|
rslearn/train/transforms/sentinel1.py
CHANGED
|
@@ -33,31 +33,19 @@ class Sentinel1ToDecibels(Transform):
|
|
|
33
33
|
self.from_decibels = from_decibels
|
|
34
34
|
self.epsilon = epsilon
|
|
35
35
|
|
|
36
|
-
def apply_image(
|
|
37
|
-
self, image: torch.Tensor | RasterImage
|
|
38
|
-
) -> torch.Tensor | RasterImage:
|
|
36
|
+
def apply_image(self, image: RasterImage) -> RasterImage:
|
|
39
37
|
"""Normalize the specified image.
|
|
40
38
|
|
|
41
39
|
Args:
|
|
42
40
|
image: the image to transform.
|
|
43
41
|
"""
|
|
44
|
-
if isinstance(image, torch.Tensor):
|
|
45
|
-
image_to_process = image
|
|
46
|
-
else:
|
|
47
|
-
image_to_process = image.image
|
|
48
42
|
if self.from_decibels:
|
|
49
43
|
# Decibels to linear scale.
|
|
50
|
-
|
|
44
|
+
image.image = torch.pow(10.0, image.image / 10.0)
|
|
51
45
|
else:
|
|
52
46
|
# Linear scale to decibels.
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
)
|
|
56
|
-
if isinstance(image, torch.Tensor):
|
|
57
|
-
return image_to_process
|
|
58
|
-
else:
|
|
59
|
-
image.image = image_to_process
|
|
60
|
-
return image
|
|
47
|
+
image.image = 10 * torch.log10(torch.clamp(image.image, min=self.epsilon))
|
|
48
|
+
return image
|
|
61
49
|
|
|
62
50
|
def forward(
|
|
63
51
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
rslearn/utils/colors.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Default color palette for visualizations."""
|
|
2
|
+
|
|
3
|
+
DEFAULT_COLORS = [
|
|
4
|
+
(0, 0, 0),
|
|
5
|
+
(255, 0, 0),
|
|
6
|
+
(0, 255, 0),
|
|
7
|
+
(0, 0, 255),
|
|
8
|
+
(255, 255, 0),
|
|
9
|
+
(0, 255, 255),
|
|
10
|
+
(255, 0, 255),
|
|
11
|
+
(0, 128, 0),
|
|
12
|
+
(255, 160, 122),
|
|
13
|
+
(139, 69, 19),
|
|
14
|
+
(128, 128, 128),
|
|
15
|
+
(255, 255, 255),
|
|
16
|
+
(143, 188, 143),
|
|
17
|
+
(95, 158, 160),
|
|
18
|
+
(255, 200, 0),
|
|
19
|
+
(128, 0, 0),
|
|
20
|
+
]
|
rslearn/vis/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Visualization module for rslearn datasets."""
|
|
rslearn/vis/normalization.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Normalization functions for raster data visualization."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from enum import StrEnum
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from rslearn.log_utils import get_logger
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NormalizationMethod(StrEnum):
|
|
14
|
+
"""Normalization methods for raster data visualization."""
|
|
15
|
+
|
|
16
|
+
SENTINEL2_RGB = "sentinel2_rgb"
|
|
17
|
+
"""Divide by 10 and clip (for Sentinel-2 B04/B03/B02 bands)."""
|
|
18
|
+
|
|
19
|
+
PERCENTILE = "percentile"
|
|
20
|
+
"""Use 2-98 percentile clipping."""
|
|
21
|
+
|
|
22
|
+
MINMAX = "minmax"
|
|
23
|
+
"""Use min-max stretch."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _normalize_sentinel2_rgb(band: np.ndarray) -> np.ndarray:
|
|
27
|
+
"""Normalize band using Sentinel-2 RGB method (divide by 10 and clip).
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
band: Input band data
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Normalized band as uint8 array
|
|
34
|
+
"""
|
|
35
|
+
band = band / 10.0
|
|
36
|
+
band = np.clip(band, 0, 255).astype(np.uint8)
|
|
37
|
+
return band
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _normalize_percentile(band: np.ndarray) -> np.ndarray:
|
|
41
|
+
"""Normalize band using 2-98 percentile clipping.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
band: Input band data
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
Normalized band as uint8 array
|
|
48
|
+
"""
|
|
49
|
+
valid_pixels = band[~np.isnan(band)]
|
|
50
|
+
if len(valid_pixels) == 0:
|
|
51
|
+
return np.zeros_like(band, dtype=np.uint8)
|
|
52
|
+
vmin, vmax = np.nanpercentile(valid_pixels, (2, 98))
|
|
53
|
+
if vmax == vmin:
|
|
54
|
+
return np.zeros_like(band, dtype=np.uint8)
|
|
55
|
+
band = np.clip(band, vmin, vmax)
|
|
56
|
+
band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
|
|
57
|
+
return band
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _normalize_minmax(band: np.ndarray) -> np.ndarray:
|
|
61
|
+
"""Normalize band using min-max stretch.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
band: Input band data
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
Normalized band as uint8 array
|
|
68
|
+
"""
|
|
69
|
+
vmin, vmax = np.nanmin(band), np.nanmax(band)
|
|
70
|
+
if vmax == vmin:
|
|
71
|
+
return np.zeros_like(band, dtype=np.uint8)
|
|
72
|
+
band = np.clip(band, vmin, vmax)
|
|
73
|
+
band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
|
|
74
|
+
return band
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
_NORMALIZATION_FUNCTIONS: dict[
|
|
78
|
+
NormalizationMethod, Callable[[np.ndarray], np.ndarray]
|
|
79
|
+
] = {
|
|
80
|
+
NormalizationMethod.SENTINEL2_RGB: _normalize_sentinel2_rgb,
|
|
81
|
+
NormalizationMethod.PERCENTILE: _normalize_percentile,
|
|
82
|
+
NormalizationMethod.MINMAX: _normalize_minmax,
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def normalize_band(
|
|
87
|
+
band: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
|
|
88
|
+
) -> np.ndarray:
|
|
89
|
+
"""Normalize band to 0-255 range.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
band: Input band data
|
|
93
|
+
method: Normalization method (string or NormalizationMethod enum)
|
|
94
|
+
- 'sentinel2_rgb': Divide by 10 and clip (for B04/B03/B02)
|
|
95
|
+
- 'percentile': Use 2-98 percentile clipping
|
|
96
|
+
- 'minmax': Use min-max stretch
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
Normalized band as uint8 array
|
|
100
|
+
"""
|
|
101
|
+
method_enum = NormalizationMethod(method) if isinstance(method, str) else method
|
|
102
|
+
normalize_func = _NORMALIZATION_FUNCTIONS.get(method_enum)
|
|
103
|
+
if normalize_func is None:
|
|
104
|
+
raise ValueError(f"Unknown normalization method: {method_enum}")
|
|
105
|
+
return normalize_func(band)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def normalize_array(
|
|
109
|
+
array: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
|
|
110
|
+
) -> np.ndarray:
|
|
111
|
+
"""Normalize a multi-band array to 0-255 range.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
array: Input array with shape (channels, height, width) from RasterFormat.decode_raster
|
|
115
|
+
method: Normalization method (applied per-band, string or NormalizationMethod enum)
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
Normalized array as uint8 with shape (height, width, channels)
|
|
119
|
+
"""
|
|
120
|
+
if array.ndim == 3:
|
|
121
|
+
array = np.moveaxis(array, 0, -1)
|
|
122
|
+
|
|
123
|
+
normalized = np.zeros_like(array, dtype=np.uint8)
|
|
124
|
+
for i in range(array.shape[-1]):
|
|
125
|
+
normalized[..., i] = normalize_band(array[..., i], method)
|
|
126
|
+
|
|
127
|
+
return normalized
|
|
rslearn/vis/render_raster_label.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Functions for rendering raster label masks (e.g., segmentation masks)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from rasterio.warp import Resampling
|
|
6
|
+
|
|
7
|
+
from rslearn.config import DType, LayerConfig
|
|
8
|
+
from rslearn.dataset import Window
|
|
9
|
+
from rslearn.log_utils import get_logger
|
|
10
|
+
from rslearn.train.dataset import DataInput, read_raster_layer_for_data_input
|
|
11
|
+
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def render_raster_label(
|
|
17
|
+
label_array: np.ndarray,
|
|
18
|
+
label_colors: dict[str, tuple[int, int, int]],
|
|
19
|
+
layer_config: LayerConfig,
|
|
20
|
+
) -> np.ndarray:
|
|
21
|
+
"""Render a raster label array as a colored mask numpy array.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
label_array: Raster label array with shape (bands, height, width) - typically single band
|
|
25
|
+
label_colors: Dictionary mapping label class names to RGB color tuples
|
|
26
|
+
layer_config: LayerConfig object (to access class_names if available)
|
|
27
|
+
|
|
28
|
+
Returns:
|
|
29
|
+
Array with shape (height, width, 3) as uint8
|
|
30
|
+
"""
|
|
31
|
+
if label_array.ndim == 3:
|
|
32
|
+
label_values = label_array[0, :, :]
|
|
33
|
+
else:
|
|
34
|
+
label_values = label_array
|
|
35
|
+
|
|
36
|
+
height, width = label_values.shape
|
|
37
|
+
mask_img = np.zeros((height, width, 3), dtype=np.uint8)
|
|
38
|
+
valid_mask = ~np.isnan(label_values)
|
|
39
|
+
|
|
40
|
+
if not layer_config.class_names:
|
|
41
|
+
raise ValueError(
|
|
42
|
+
"class_names must be specified in config for raster label layer"
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
label_int = label_values.astype(np.int32)
|
|
46
|
+
for idx in range(len(layer_config.class_names)):
|
|
47
|
+
class_name = layer_config.class_names[idx]
|
|
48
|
+
color = label_colors.get(str(class_name), (0, 0, 0))
|
|
49
|
+
mask = (label_int == idx) & valid_mask
|
|
50
|
+
mask_img[mask] = color
|
|
51
|
+
|
|
52
|
+
img = Image.fromarray(mask_img, mode="RGB")
|
|
53
|
+
return np.array(img)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def read_raster_layer(
|
|
57
|
+
window: Window,
|
|
58
|
+
layer_name: str,
|
|
59
|
+
layer_config: LayerConfig,
|
|
60
|
+
band_names: list[str],
|
|
61
|
+
group_idx: int = 0,
|
|
62
|
+
bounds: PixelBounds | None = None,
|
|
63
|
+
) -> np.ndarray:
|
|
64
|
+
"""Read a raster layer for visualization.
|
|
65
|
+
|
|
66
|
+
This reads bands from potentially multiple band sets to get the requested bands.
|
|
67
|
+
Uses read_raster_layer_for_data_input from rslearn.train.dataset.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
window: The window to read from
|
|
71
|
+
layer_name: The layer name
|
|
72
|
+
layer_config: The layer configuration
|
|
73
|
+
band_names: List of band names to read (e.g., ["B04", "B03", "B02"])
|
|
74
|
+
group_idx: The item group index (default 0)
|
|
75
|
+
bounds: Optional bounds to read. If None, uses window.bounds
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
Array with shape (bands, height, width) as float32
|
|
79
|
+
"""
|
|
80
|
+
if bounds is None:
|
|
81
|
+
bounds = window.bounds
|
|
82
|
+
|
|
83
|
+
data_input = DataInput(
|
|
84
|
+
data_type="raster",
|
|
85
|
+
layers=[layer_name],
|
|
86
|
+
bands=band_names,
|
|
87
|
+
dtype=DType.FLOAT32,
|
|
88
|
+
resolution_factor=ResolutionFactor(), # Default 1/1, no scaling
|
|
89
|
+
resampling=Resampling.nearest,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
image_tensor = read_raster_layer_for_data_input(
|
|
93
|
+
window, bounds, layer_name, group_idx, layer_config, data_input
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
return image_tensor.numpy().astype(np.float32)
|
|
rslearn/vis/render_sensor_image.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Functions for rendering raster sensor images (e.g., Sentinel-2, Landsat)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .normalization import normalize_array
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def render_sensor_image(
|
|
9
|
+
array: np.ndarray,
|
|
10
|
+
normalization_method: str,
|
|
11
|
+
) -> np.ndarray:
|
|
12
|
+
"""Render a raster sensor image array as a numpy array.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
array: Array with shape (channels, height, width) from RasterFormat.decode_raster
|
|
16
|
+
normalization_method: Normalization method to apply
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Array with shape (height, width, channels) as uint8
|
|
20
|
+
"""
|
|
21
|
+
normalized = normalize_array(array, normalization_method)
|
|
22
|
+
|
|
23
|
+
# If more than 3 channels, take only the first 3 for RGB
|
|
24
|
+
if normalized.shape[-1] > 3:
|
|
25
|
+
normalized = normalized[:, :, :3]
|
|
26
|
+
|
|
27
|
+
return normalized
|