rslearn 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/planetary_computer.py +99 -1
- rslearn/data_sources/stac.py +3 -2
- rslearn/models/simple_time_series.py +1 -1
- rslearn/train/dataset.py +44 -3
- rslearn/train/tasks/detection.py +1 -18
- rslearn/train/tasks/segmentation.py +21 -20
- rslearn/utils/colors.py +20 -0
- rslearn/utils/raster_format.py +17 -0
- rslearn/utils/stac.py +4 -0
- rslearn/vis/__init__.py +1 -0
- rslearn/vis/normalization.py +127 -0
- rslearn/vis/render_raster_label.py +96 -0
- rslearn/vis/render_sensor_image.py +27 -0
- rslearn/vis/render_vector_label.py +439 -0
- rslearn/vis/utils.py +99 -0
- rslearn/vis/vis_server.py +574 -0
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/METADATA +14 -1
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/RECORD +23 -15
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/WHEEL +1 -1
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.23.dist-info → rslearn-0.0.25.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Functions for rendering raster label masks (e.g., segmentation masks)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from rasterio.warp import Resampling
|
|
6
|
+
|
|
7
|
+
from rslearn.config import DType, LayerConfig
|
|
8
|
+
from rslearn.dataset import Window
|
|
9
|
+
from rslearn.log_utils import get_logger
|
|
10
|
+
from rslearn.train.dataset import DataInput, read_raster_layer_for_data_input
|
|
11
|
+
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def render_raster_label(
    label_array: np.ndarray,
    label_colors: dict[str, tuple[int, int, int]],
    layer_config: LayerConfig,
) -> np.ndarray:
    """Render a raster label array as a colored mask numpy array.

    Pixel value ``i`` is painted with the color of ``layer_config.class_names[i]``
    (looked up by its string name in ``label_colors``; black if absent).
    Non-finite pixels, negative values and values ``>= len(class_names)`` stay
    black. Fractional values are truncated toward zero before lookup.

    Args:
        label_array: Raster label array with shape (bands, height, width) or
            (height, width). For 3D input only the first band is used.
        label_colors: Dictionary mapping label class names to RGB color tuples.
        layer_config: LayerConfig object; its ``class_names`` list maps integer
            pixel values to class names.

    Returns:
        Array with shape (height, width, 3) as uint8.

    Raises:
        ValueError: If ``layer_config.class_names`` is empty or unset.
    """
    class_names = layer_config.class_names
    if not class_names:
        raise ValueError(
            "class_names must be specified in config for raster label layer"
        )

    label_values = label_array[0, :, :] if label_array.ndim == 3 else label_array
    height, width = label_values.shape

    # Cast only finite pixels: converting NaN/inf to an integer is undefined and
    # makes NumPy emit "invalid value encountered in cast". Invalid pixels get
    # -1, which never matches a class index.
    valid_mask = np.isfinite(label_values)
    label_int = np.full((height, width), -1, dtype=np.int64)
    label_int[valid_mask] = label_values[valid_mask].astype(np.int64)

    # Vectorized palette lookup: one pass over the pixels instead of a full
    # comparison of the whole image per class.
    palette = np.array(
        [label_colors.get(str(name), (0, 0, 0)) for name in class_names],
        dtype=np.uint8,
    )
    in_range = (label_int >= 0) & (label_int < len(class_names))

    mask_img = np.zeros((height, width, 3), dtype=np.uint8)
    mask_img[in_range] = palette[label_int[in_range]]
    return mask_img
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def read_raster_layer(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    band_names: list[str],
    group_idx: int = 0,
    bounds: PixelBounds | None = None,
) -> np.ndarray:
    """Read the requested bands of a raster layer so they can be visualized.

    The requested bands may come from several band sets of the layer. The
    actual reading is delegated to ``read_raster_layer_for_data_input`` from
    ``rslearn.train.dataset``, using a float32 ``DataInput`` with the default
    (1/1) resolution factor and nearest-neighbor resampling.

    Args:
        window: The window to read from.
        layer_name: The layer name.
        layer_config: The layer configuration.
        band_names: List of band names to read (e.g., ["B04", "B03", "B02"]).
        group_idx: The item group index (default 0).
        bounds: Optional bounds to read. If None, uses window.bounds.

    Returns:
        Array with shape (bands, height, width) as float32.
    """
    read_bounds = window.bounds if bounds is None else bounds

    spec = DataInput(
        data_type="raster",
        layers=[layer_name],
        bands=band_names,
        dtype=DType.FLOAT32,
        # Default ResolutionFactor() is 1/1, i.e. no rescaling.
        resolution_factor=ResolutionFactor(),
        resampling=Resampling.nearest,
    )

    tensor = read_raster_layer_for_data_input(
        window, read_bounds, layer_name, group_idx, layer_config, spec
    )
    return tensor.numpy().astype(np.float32)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Functions for rendering raster sensor images (e.g., Sentinel-2, Landsat)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .normalization import normalize_array
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def render_sensor_image(
    array: np.ndarray,
    normalization_method: str,
) -> np.ndarray:
    """Turn a raw sensor raster into a displayable image array.

    Args:
        array: Array with shape (channels, height, width) from RasterFormat.decode_raster
        normalization_method: Normalization method to apply

    Returns:
        Array with shape (height, width, channels) as uint8. When the normalized
        image has more than three channels, only the first three are kept.
    """
    image = normalize_array(array, normalization_method)
    # Keep at most three channels so the result can be shown as RGB.
    return image[:, :, :3] if image.shape[-1] > 3 else image
|
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
"""Functions for rendering vector label masks (detection and segmentation)."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import shapely
|
|
7
|
+
from PIL import Image, ImageDraw
|
|
8
|
+
|
|
9
|
+
from rslearn.config import LayerConfig, LayerType
|
|
10
|
+
from rslearn.dataset import Dataset, Window
|
|
11
|
+
from rslearn.log_utils import get_logger
|
|
12
|
+
from rslearn.utils.feature import Feature
|
|
13
|
+
from rslearn.utils.geometry import PixelBounds, Projection, flatten_shape
|
|
14
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
15
|
+
|
|
16
|
+
from .normalization import normalize_array
|
|
17
|
+
from .render_raster_label import read_raster_layer
|
|
18
|
+
|
|
19
|
+
logger = get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def point_to_pixel_coords(
    point: shapely.Point,
    bounds: PixelBounds,
    image_width: int,
    image_height: int,
    actual_width: int,
    actual_height: int,
) -> tuple[int, int]:
    """Convert a point's coordinates to pixel coordinates in the image.

    The point must be in the same coordinate system as ``bounds``. Its offset
    from the window origin is rescaled from the data size
    (``actual_width`` x ``actual_height``) to the image size.

    Args:
        point: Shapely Point object
        bounds: Pixel bounds of the window
        image_width: Width of the image in pixels
        image_height: Height of the image in pixels
        actual_width: Actual width of the data (bounds[2] - bounds[0])
        actual_height: Actual height of the data (bounds[3] - bounds[1])

    Returns:
        Tuple of (pixel_x, pixel_y) coordinates. Values are negative or
        beyond the image size when the point lies outside the window.
    """
    # Floor rather than int(), which truncates toward zero: a point slightly
    # left of / above the window (e.g. offset -0.5) must map to -1 and count as
    # out of bounds, not be snapped onto the first column/row.
    px = int(np.floor((point.x - bounds[0]) * image_width / actual_width))
    py = int(np.floor((point.y - bounds[1]) * image_height / actual_height))
    return px, py
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def draw_bounding_box_around_point(
    draw: ImageDraw.ImageDraw,
    px: int,
    py: int,
    width: int,
    height: int,
    color: tuple[int, int, int],
    box_size: int = 20,
) -> bool:
    """Outline a square centered on (px, py), clipped to the image.

    Nothing is drawn when the center lies outside the image.

    Args:
        draw: PIL ImageDraw object
        px: Pixel x coordinate
        py: Pixel y coordinate
        width: Image width
        height: Image height
        color: RGB color tuple for the bounding box
        box_size: Size of the bounding box in pixels (default 20)

    Returns:
        True if the point was drawn (within bounds), False otherwise
    """
    if not (0 <= px < width and 0 <= py < height):
        return False

    half = box_size // 2
    box = [
        max(0, px - half),
        max(0, py - half),
        min(width, px + half),
        min(height, py + half),
    ]
    draw.rectangle(box, outline=color, width=2)
    return True
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def overlay_points_on_image(
    draw: ImageDraw.ImageDraw,
    features: list[Feature],
    bounds: PixelBounds,
    image_width: int,
    image_height: int,
    actual_width: int,
    actual_height: int,
    label_colors: dict[str, tuple[int, int, int]],
    class_property_name: str | None = None,
) -> tuple[int, int]:
    """Overlay point features on an image as colored bounding boxes.

    Each feature's geometry is flattened into its component shapes. Every
    shape must be a point, and each point is outlined in the color of its
    class label.

    Args:
        draw: PIL ImageDraw object
        features: List of Feature objects (points)
        bounds: Pixel bounds of the window
        image_width: Width of the image in pixels
        image_height: Height of the image in pixels
        actual_width: Actual width of the data (bounds[2] - bounds[0])
        actual_height: Actual height of the data (bounds[3] - bounds[1])
        label_colors: Dictionary mapping label class names to RGB colors
        class_property_name: Property name to use for label extraction (from config)

    Returns:
        Tuple of (points_drawn, points_out_of_bounds)

    Raises:
        ValueError: If ``class_property_name`` is empty, or a feature's label
            has no entry in ``label_colors``.
    """
    if not class_property_name:
        raise ValueError(
            "class_property_name must be specified in config for vector label layers"
        )

    points_drawn = 0
    points_out_of_bounds = 0

    for feature in features:
        label = str(feature.properties.get(class_property_name))
        color = label_colors.get(label)
        if color is None:
            raise ValueError(
                f"Label '{label}' not found in label_colors. Available labels: {list(label_colors.keys())}"
            )

        for point in flatten_shape(feature.geometry.shp):
            assert isinstance(point, shapely.Point), (
                f"Expected Point, got {type(point)}"
            )
            px, py = point_to_pixel_coords(
                point, bounds, image_width, image_height, actual_width, actual_height
            )
            # Per-point trace at DEBUG: at INFO it emitted one line per rendered
            # point. Lazy %-arguments also skip the string formatting when DEBUG
            # is disabled.
            logger.debug(
                "Point at (%.2f, %.2f) -> pixel (%d, %d), bounds: %s, "
                "image size: %dx%d, actual_size: %dx%d",
                point.x,
                point.y,
                px,
                py,
                bounds,
                image_width,
                image_height,
                actual_width,
                actual_height,
            )
            if draw_bounding_box_around_point(
                draw, px, py, image_width, image_height, color
            ):
                points_drawn += 1
            else:
                points_out_of_bounds += 1

    return points_drawn, points_out_of_bounds
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def render_vector_label_detection(
    features: list[Feature],
    bounds: PixelBounds,
    label_colors: dict[str, tuple[int, int, int]],
    class_property_name: str | None = None,
    reference_array: np.ndarray | None = None,
    normalization_method: str | None = None,
) -> np.ndarray:
    """Draw detection points as boxes over a reference image or a black canvas.

    A reference image is used only when both ``reference_array`` and
    ``normalization_method`` are given. Otherwise the canvas is black and has
    the size of ``bounds``.

    Args:
        features: List of Feature objects (points)
        bounds: Pixel bounds of the window
        label_colors: Dictionary mapping label class names to RGB colors
        class_property_name: Property name to use for label extraction (from config)
        reference_array: Optional reference raster array to overlay points on
        normalization_method: Optional normalization method for the reference array

    Returns:
        Array with shape (height, width, 3) as uint8
    """
    data_width = bounds[2] - bounds[0]
    data_height = bounds[3] - bounds[1]

    if reference_array is None or normalization_method is None:
        canvas = Image.new("RGB", (data_width, data_height), color=(0, 0, 0))
    else:
        rgb = normalize_array(reference_array, normalization_method)
        if rgb.shape[-1] < 3:
            # Fewer than three channels: show the first one as grayscale.
            canvas = Image.fromarray(rgb[:, :, 0], mode="L").convert("RGB")
        else:
            canvas = Image.fromarray(rgb[:, :, :3], mode="RGB")

    overlay_points_on_image(
        ImageDraw.Draw(canvas),
        features,
        bounds,
        canvas.width,
        canvas.height,
        data_width,
        data_height,
        label_colors,
        class_property_name,
    )
    return np.array(canvas)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _ring_to_image_coords(coords: Any, bounds: PixelBounds) -> list[tuple[int, int]]:
    """Shift ring coordinates so the window's top-left corner becomes (0, 0)."""
    # Index instead of tuple-unpacking so rings that carry a Z value also work.
    return [(int(c[0] - bounds[0]), int(c[1] - bounds[1])) for c in coords]


def _draw_polygon_with_holes(
    draw: ImageDraw.ImageDraw,
    polygon: Any,
    bounds: PixelBounds,
    color: tuple[int, int, int],
) -> None:
    """Fill one polygon with ``color``, then paint its holes back to black."""
    exterior = _ring_to_image_coords(polygon.exterior.coords, bounds)
    if len(exterior) >= 3:
        draw.polygon(exterior, fill=color, outline=color)
    for interior in polygon.interiors:
        hole = _ring_to_image_coords(interior.coords, bounds)
        if len(hole) >= 3:
            draw.polygon(hole, fill=(0, 0, 0), outline=color)


def render_vector_label_segmentation(
    features: list[Feature],
    bounds: PixelBounds,
    projection: Projection,
    label_colors: dict[str, tuple[int, int, int]],
    class_property_name: str | None = None,
) -> np.ndarray:
    """Render vector labels for segmentation tasks (draw polygons on mask).

    Each feature is reprojected to ``projection`` and its Polygon or
    MultiPolygon geometry is filled with its label color. The color is white
    when the label has no entry in ``label_colors``. Holes are painted black
    for every polygon, including each part of a MultiPolygon. Other geometry
    types are skipped.

    Args:
        features: List of Feature objects (polygons)
        bounds: Pixel bounds of the window
        projection: Projection of the window
        label_colors: Dictionary mapping label class names to RGB colors
        class_property_name: Property name to use for label extraction (from config)

    Returns:
        Array with shape (height, width, 3) as uint8

    Raises:
        ValueError: If ``class_property_name`` is empty.
    """
    if not class_property_name:
        raise ValueError(
            "class_property_name must be specified in config for vector label layers"
        )

    width = bounds[2] - bounds[0]
    height = bounds[3] - bounds[1]
    mask_img = Image.new("RGB", (width, height), color=(0, 0, 0))
    draw = ImageDraw.Draw(mask_img)

    def get_label(feat: Any) -> str:
        # Missing properties or a missing/None value give ""; any other value,
        # including falsy ones such as 0, is stringified so it still matches its
        # key in label_colors.
        if not feat.properties:
            return ""
        label = feat.properties.get(class_property_name)
        return "" if label is None else str(label)

    # Sorting by label (descending) makes the draw order deterministic, so it is
    # stable which class wins where features overlap.
    for feature in sorted(features, key=get_label, reverse=True):
        color = label_colors.get(get_label(feature), (255, 255, 255))
        shp = feature.geometry.to_projection(projection).shp
        if shp.geom_type == "Polygon":
            _draw_polygon_with_holes(draw, shp, bounds, color)
        elif shp.geom_type == "MultiPolygon":
            # Each part may have its own holes.
            for poly in shp.geoms:
                _draw_polygon_with_holes(draw, poly, bounds, color)

    return np.array(mask_img)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def read_vector_layer(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    group_idx: int = 0,
) -> list[Any]:
    """Load the features of a vector layer for one window.

    The features are decoded with the window's own projection and bounds.

    Args:
        window: The window to read from
        layer_name: The layer name
        layer_config: The layer configuration
        group_idx: The item group index (default 0)

    Returns:
        List of Feature objects

    Raises:
        ValueError: If ``layer_config`` is not a vector layer.
    """
    if layer_config.type != LayerType.VECTOR:
        raise ValueError(f"Layer {layer_name} is not a vector layer")

    vector_format: VectorFormat = layer_config.instantiate_vector_format()
    layer_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
    projection, bounds = window.projection, window.bounds

    logger.info(
        f"Reading vector layer {layer_name} from {layer_dir}, bounds: {bounds}, projection: {projection}"
    )
    decoded = vector_format.decode_vector(layer_dir, projection, bounds)
    logger.info(f"Decoded {len(decoded)} features from vector layer {layer_name}")
    return decoded
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def get_vector_label_by_property(
    window: Window,
    layer_config: LayerConfig,
    layer_name: str,
    group_idx: int = 0,
) -> str | None:
    """Get a label value from a vector layer's first feature property.

    Extracts the label value from the first feature's properties using the property
    name specified in layer_config.class_property_name. This works for both
    classification and segmentation tasks that use vector labels.

    Args:
        window: The window to read from
        layer_config: The label layer configuration (must be vector type)
        layer_name: The name of the label layer
        group_idx: The item group index (default 0)

    Returns:
        The label (converted to str). Returns None if the layer has no features,
        the first feature has no properties, or the property is missing.

    Raises:
        ValueError: If the first feature has properties but
            ``layer_config.class_property_name`` is not set.
    """
    features = read_vector_layer(window, layer_name, layer_config, group_idx=group_idx)
    if not features:
        logger.warning(
            f"No features in vector label layer {layer_name} for {window.name}"
        )
        return None

    properties = features[0].properties
    if not properties:
        return None

    if not layer_config.class_property_name:
        raise ValueError(
            f"class_property_name must be specified in the config for vector label layer '{layer_name}'. "
        )

    label = properties.get(layer_config.class_property_name)
    logger.info(f"Label for {window.name}: {label}")
    # A missing property must come back as None (per the contract above), not
    # as the literal string "None" produced by str(None).
    if label is None:
        return None
    return str(label)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _get_reference_raster_for_detection(
    window: Window,
    dataset: Dataset,
    bands: dict[str, list[str]],
    normalization: dict[str, str],
    label_layers: list[str],
    group_idx: int,
) -> tuple[np.ndarray, str] | None:
    """Get reference raster array for detection tasks.

    The first layer in ``bands`` that is not a label layer is used as the
    reference.

    Args:
        window: Window object
        dataset: Dataset object
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        label_layers: List of label layer names
        group_idx: Item group index

    Returns:
        Tuple of (reference_array, normalization_method). Returns None if no
        raster layers are available, or if the chosen layer has no
        normalization method configured.
    """
    raster_layers = [name for name in bands if name not in label_layers]
    if not raster_layers:
        # Previously this indexed an empty list and raised IndexError.
        return None

    ref_layer_name = raster_layers[0]
    ref_normalization_method = normalization.get(ref_layer_name)
    if ref_normalization_method is None:
        # The reference can only be shown once it is normalized; skip the read
        # instead of failing with a KeyError.
        logger.warning(
            f"No normalization method configured for reference layer {ref_layer_name}; "
            "rendering detection labels without a reference image"
        )
        return None

    reference_array = read_raster_layer(
        window,
        ref_layer_name,
        dataset.layers[ref_layer_name],
        bands[ref_layer_name],
        group_idx=group_idx,
    )
    return reference_array, ref_normalization_method
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def render_vector_label_image(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    task_type: str,
    label_colors: dict[str, tuple[int, int, int]],
    dataset: Dataset,
    label_layers: list[str],
    group_idx: int,
    bands: dict[str, list[str]] | None = None,
    normalization: dict[str, str] | None = None,
) -> np.ndarray:
    """Render the vector label layer of a window as an RGB image.

    "detection" draws point boxes, optionally over a reference raster.
    "classification" is rejected. Any other task type is rendered as
    segmentation polygons.

    Args:
        window: Window object
        layer_name: Layer name
        layer_config: LayerConfig object
        task_type: Task type ("detection" or "segmentation")
        label_colors: Dictionary mapping label class names to RGB colors
        dataset: Dataset object
        label_layers: List of label layer names
        group_idx: Item group index
        bands: Optional dictionary mapping layer_name -> list of band names (for detection reference raster)
        normalization: Optional dictionary mapping layer_name -> normalization method (for detection reference raster)

    Returns:
        Array with shape (height, width, 3) as uint8

    Raises:
        ValueError: If ``task_type`` is "classification".
    """
    if task_type == "classification":
        raise ValueError("Classification labels are text, not images")

    features = read_vector_layer(window, layer_name, layer_config, group_idx=group_idx)

    if task_type != "detection":
        return render_vector_label_segmentation(
            features,
            window.bounds,
            window.projection,
            label_colors,
            layer_config.class_property_name,
        )

    ref_data = _get_reference_raster_for_detection(
        window, dataset, bands or {}, normalization or {}, label_layers, group_idx
    )
    reference_array, ref_normalization_method = (
        (None, None) if ref_data is None else ref_data
    )
    return render_vector_label_detection(
        features,
        window.bounds,
        label_colors,
        layer_config.class_property_name,
        reference_array,
        ref_normalization_method,
    )
|
rslearn/vis/utils.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Utility functions and constants for visualization."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from PIL import Image
|
|
8
|
+
|
|
9
|
+
from rslearn.dataset import Window
|
|
10
|
+
from rslearn.log_utils import get_logger
|
|
11
|
+
from rslearn.utils.colors import DEFAULT_COLORS
|
|
12
|
+
from rslearn.utils.geometry import WGS84_PROJECTION
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
# Fixed size for all visualized images (width, height in pixels)
|
|
17
|
+
VISUALIZATION_IMAGE_SIZE = (512, 512)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def generate_label_colors(label_classes: set[str]) -> dict[str, tuple[int, int, int]]:
    """Assign a color from DEFAULT_COLORS to every label class.

    Classes are sorted before assignment, so the mapping is deterministic. The
    palette is reused cyclically when there are more classes than colors.

    Args:
        label_classes: Set or list of label class names

    Returns:
        Dictionary mapping label class names to RGB color tuples
    """
    palette_size = len(DEFAULT_COLORS)
    return {
        label: DEFAULT_COLORS[position % palette_size]
        for position, label in enumerate(sorted(label_classes))
    }
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def format_window_info(
    window: Window,
) -> tuple[tuple[datetime, datetime] | None, float | None, float | None]:
    """Collect a window's time range and WGS84 center point for display.

    Args:
        window: Window object

    Returns:
        Tuple of (time_range, lat, lon). ``time_range`` is ``window.time_range``
        (a (start, end) tuple of datetime objects, or None). ``lat`` and ``lon``
        are the coordinates of the centroid of the window geometry after
        reprojection to WGS84.
    """
    center = window.get_geometry().to_projection(WGS84_PROJECTION).shp.centroid
    return window.time_range, float(center.y), float(center.x)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def array_to_bytes(
    array: np.ndarray, resampling: Image.Resampling = Image.Resampling.LANCZOS
) -> bytes:
    """Convert a numpy array to PNG bytes, resized to VISUALIZATION_IMAGE_SIZE.

    Args:
        array: Array with shape (height, width, channels) or (height, width) as
            uint8. A 2D array, or one with one or two channels, is rendered as
            grayscale (from the first channel). Three or more channels are
            rendered as RGB from the first three.
        resampling: PIL Image resampling method (default LANCZOS, use NEAREST for labels)

    Returns:
        PNG image bytes

    Raises:
        ValueError: If the array is not 2D/3D or has no channels.
    """
    if array.ndim == 2:
        img = Image.fromarray(array, mode="L")
    elif array.ndim == 3 and array.shape[-1] >= 3:
        img = Image.fromarray(array[:, :, :3], mode="RGB")
    elif array.ndim == 3 and array.shape[-1] >= 1:
        # One or two channels (e.g. a two-band SAR image passed through
        # render_sensor_image). A 2-channel slice cannot be read as RGB, so the
        # first channel is shown as grayscale, like the detection renderer does
        # for narrow reference images.
        img = Image.fromarray(array[:, :, 0], mode="L")
    else:
        raise ValueError(f"Unsupported array shape: {array.shape}")

    img = img.resize(VISUALIZATION_IMAGE_SIZE, resampling)

    buf = BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Character -> HTML entity table, applied in one pass by _escape_html. A single
# translate() pass cannot double-escape "&" introduced by an earlier entity.
_HTML_ESCAPE_TABLE = str.maketrans(
    {"&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;", "'": "&#39;"}
)


def _escape_html(text: str) -> str:
    """Escape HTML special characters so ``text`` can be embedded in markup.

    Replaces ``&``, ``<``, ``>``, ``"`` and ``'`` with their character
    entities. The result is safe in element content and in quoted attribute
    values.

    Args:
        text: Raw text to escape.

    Returns:
        The escaped text.
    """
    return text.translate(_HTML_ESCAPE_TABLE)
|