rslearn 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,439 @@
1
+ """Functions for rendering vector label masks (detection and segmentation)."""
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import shapely
7
+ from PIL import Image, ImageDraw
8
+
9
+ from rslearn.config import LayerConfig, LayerType
10
+ from rslearn.dataset import Dataset, Window
11
+ from rslearn.log_utils import get_logger
12
+ from rslearn.utils.feature import Feature
13
+ from rslearn.utils.geometry import PixelBounds, Projection, flatten_shape
14
+ from rslearn.utils.vector_format import VectorFormat
15
+
16
+ from .normalization import normalize_array
17
+ from .render_raster_label import read_raster_layer
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
def point_to_pixel_coords(
    point: shapely.Point,
    bounds: PixelBounds,
    image_width: int,
    image_height: int,
    actual_width: int,
    actual_height: int,
) -> tuple[int, int]:
    """Convert a point's coordinates to pixel coordinates in the image.

    The point's offset from the top-left corner of ``bounds`` is rescaled from the
    data size (``actual_width`` x ``actual_height``) to the image size
    (``image_width`` x ``image_height``). This lets the image be, for example, a
    reference raster whose size differs from the window bounds. The point is
    expected to be in the same coordinate system as ``bounds``.

    Args:
        point: Shapely Point object
        bounds: Pixel bounds of the window
        image_width: Width of the image in pixels
        image_height: Height of the image in pixels
        actual_width: Actual width of the data (bounds[2] - bounds[0])
        actual_height: Actual height of the data (bounds[3] - bounds[1])

    Returns:
        Tuple of (pixel_x, pixel_y) coordinates. Values are floored, so a point
        just left of / above the window maps to a negative index rather than
        being folded onto pixel 0.
    """
    import math

    # Floor instead of int(): int() truncates toward zero, which would map an
    # offset such as -0.4 to pixel 0 and make an out-of-window point look like
    # it sits on the image edge.
    px = math.floor((point.x - bounds[0]) * image_width / actual_width)
    py = math.floor((point.y - bounds[1]) * image_height / actual_height)
    return px, py
47
+
48
+
49
def draw_bounding_box_around_point(
    draw: ImageDraw.ImageDraw,
    px: int,
    py: int,
    width: int,
    height: int,
    color: tuple[int, int, int],
    box_size: int = 20,
) -> bool:
    """Draw a box outline centered on a point, clipped to the image.

    Args:
        draw: PIL ImageDraw object
        px: Pixel x coordinate
        py: Pixel y coordinate
        width: Image width
        height: Image height
        color: RGB color tuple for the bounding box
        box_size: Nominal size of the bounding box in pixels (default 20); the
            box is clipped where it would extend past the image border.

    Returns:
        True if the point was drawn (within bounds), False otherwise
    """
    if not (0 <= px < width and 0 <= py < height):
        return False

    half = box_size // 2
    # PIL rectangle coordinates are inclusive, so the last valid pixel is
    # width - 1 / height - 1. Clamping to width / height would place the right
    # or bottom side of the outline partly off-canvas for points near that edge.
    x1 = max(0, px - half)
    y1 = max(0, py - half)
    x2 = min(width - 1, px + half)
    y2 = min(height - 1, py + half)
    draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
    return True
84
+
85
+
86
def overlay_points_on_image(
    draw: ImageDraw.ImageDraw,
    features: list[Feature],
    bounds: PixelBounds,
    image_width: int,
    image_height: int,
    actual_width: int,
    actual_height: int,
    label_colors: dict[str, tuple[int, int, int]],
    class_property_name: str | None = None,
) -> tuple[int, int]:
    """Overlay point features on an image as colored bounding boxes.

    Each feature's class is read from ``feature.properties[class_property_name]``
    (converted with ``str``) and looked up in ``label_colors``. Every part of the
    feature's (flattened) geometry is converted to image pixel coordinates and
    outlined with a box of that color.

    Args:
        draw: PIL ImageDraw object
        features: List of Feature objects (points)
        bounds: Pixel bounds of the window
        image_width: Width of the image in pixels
        image_height: Height of the image in pixels
        actual_width: Actual width of the data (bounds[2] - bounds[0])
        actual_height: Actual height of the data (bounds[3] - bounds[1])
        label_colors: Dictionary mapping label class names to RGB colors
        class_property_name: Property name to use for label extraction (from config)

    Returns:
        Tuple of (points_drawn, points_out_of_bounds)

    Raises:
        ValueError: if class_property_name is not set, if a feature's label has
            no entry in label_colors, or if a geometry part is not a Point.
    """
    if not class_property_name:
        raise ValueError(
            "class_property_name must be specified in config for vector label layers"
        )

    points_drawn = 0
    points_out_of_bounds = 0

    for feature in features:
        label = str(feature.properties.get(class_property_name))
        color = label_colors.get(label)
        if color is None:
            raise ValueError(
                f"Label '{label}' not found in label_colors. Available labels: {list(label_colors.keys())}"
            )

        for point in flatten_shape(feature.geometry.shp):
            # Explicit check rather than `assert`: assertions are stripped when
            # Python runs with -O, which would let non-point geometry through.
            if not isinstance(point, shapely.Point):
                raise ValueError(f"Expected Point, got {type(point)}")

            px, py = point_to_pixel_coords(
                point, bounds, image_width, image_height, actual_width, actual_height
            )
            # Per-point detail is debugging output: at INFO level it floods the
            # log for layers with many points. Lazy %-args also skip formatting
            # entirely when DEBUG is disabled.
            logger.debug(
                "Point at (%.2f, %.2f) -> pixel (%s, %s), bounds: %s, "
                "image size: %sx%s, actual_size: %sx%s",
                point.x,
                point.y,
                px,
                py,
                bounds,
                image_width,
                image_height,
                actual_width,
                actual_height,
            )
            if draw_bounding_box_around_point(
                draw, px, py, image_width, image_height, color
            ):
                points_drawn += 1
            else:
                points_out_of_bounds += 1

    return points_drawn, points_out_of_bounds
152
+
153
+
154
def render_vector_label_detection(
    features: list[Feature],
    bounds: PixelBounds,
    label_colors: dict[str, tuple[int, int, int]],
    class_property_name: str | None = None,
    reference_array: np.ndarray | None = None,
    normalization_method: str | None = None,
) -> np.ndarray:
    """Render detection labels as colored boxes around each point.

    The background is the normalized reference raster when both
    ``reference_array`` and ``normalization_method`` are given. Its first three
    channels are used as RGB, or channel 0 as grayscale when there are fewer than
    three. Otherwise the background is a black canvas the size of ``bounds``.
    Point positions are scaled from the bounds size to the canvas size.

    Args:
        features: List of Feature objects (points)
        bounds: Pixel bounds of the window
        label_colors: Dictionary mapping label class names to RGB colors
        class_property_name: Property name to use for label extraction (from config)
        reference_array: Optional reference raster array to overlay points on
        normalization_method: Optional normalization method for the reference array

    Returns:
        Array with shape (height, width, 3) as uint8
    """
    data_w = bounds[2] - bounds[0]
    data_h = bounds[3] - bounds[1]

    if reference_array is None or normalization_method is None:
        canvas = Image.new("RGB", (data_w, data_h), color=(0, 0, 0))
    else:
        scaled = normalize_array(reference_array, normalization_method)
        if scaled.shape[-1] < 3:
            canvas = Image.fromarray(scaled[:, :, 0], mode="L").convert("RGB")
        else:
            canvas = Image.fromarray(scaled[:, :, :3], mode="RGB")

    canvas_w, canvas_h = canvas.size
    overlay_points_on_image(
        ImageDraw.Draw(canvas),
        features,
        bounds,
        canvas_w,
        canvas_h,
        data_w,
        data_h,
        label_colors,
        class_property_name,
    )
    return np.array(canvas)
201
+
202
+
203
def _draw_polygon_with_holes(
    draw: ImageDraw.ImageDraw,
    polygon: shapely.Polygon,
    bounds: PixelBounds,
    color: tuple[int, int, int],
) -> None:
    """Fill one polygon's exterior with `color`, then re-blank its interior rings."""

    def to_pixels(ring: Any) -> list[tuple[int, int]]:
        # Offset by the window origin so coordinates index into the mask image.
        return [(int(x - bounds[0]), int(y - bounds[1])) for x, y in ring.coords]

    exterior = to_pixels(polygon.exterior)
    if len(exterior) < 3:
        # Degenerate ring: nothing PIL can fill.
        return
    draw.polygon(exterior, fill=color, outline=color)
    for interior in polygon.interiors:
        hole = to_pixels(interior)
        if len(hole) >= 3:
            # Holes are filled black, which also clears anything drawn
            # beneath them by earlier features.
            draw.polygon(hole, fill=(0, 0, 0), outline=color)


def render_vector_label_segmentation(
    features: list[Feature],
    bounds: PixelBounds,
    projection: Projection,
    label_colors: dict[str, tuple[int, int, int]],
    class_property_name: str | None = None,
) -> np.ndarray:
    """Render vector labels for segmentation tasks (draw polygons on mask).

    Each feature's geometry is converted to ``projection`` and drawn relative to
    the origin of ``bounds``. Polygon and MultiPolygon geometries are filled with
    the color of their label; labels missing from ``label_colors`` are drawn in
    white. Interior rings (holes) are cleared back to black for every polygon,
    including each part of a MultiPolygon. Other geometry types are not drawn.

    Features are drawn in descending order of their label string, so labels that
    sort first end up on top where features overlap.

    Args:
        features: List of Feature objects (polygons)
        bounds: Pixel bounds of the window
        projection: Projection of the window
        label_colors: Dictionary mapping label class names to RGB colors
        class_property_name: Property name to use for label extraction (from config)

    Returns:
        Array with shape (height, width, 3) as uint8

    Raises:
        ValueError: if class_property_name is not set.
    """
    if not class_property_name:
        raise ValueError(
            "class_property_name must be specified in config for vector label layers"
        )

    width = bounds[2] - bounds[0]
    height = bounds[3] - bounds[1]

    mask_img = Image.new("RGB", (width, height), color=(0, 0, 0))
    draw = ImageDraw.Draw(mask_img)

    def get_label(feat: Any) -> str:
        if not feat.properties:
            return ""
        label = feat.properties.get(class_property_name)
        # `is not None` so falsy labels such as 0 sort under the same string
        # ("0") that is used for their color lookup below.
        return str(label) if label is not None else ""

    sorted_features = sorted(features, key=get_label, reverse=True)

    for feature in sorted_features:
        # Guard against features without properties, as get_label does.
        label = str((feature.properties or {}).get(class_property_name))
        color = label_colors.get(label, (255, 255, 255))
        shp = feature.geometry.to_projection(projection).shp
        if shp.geom_type == "Polygon":
            _draw_polygon_with_holes(draw, shp, bounds, color)
        elif shp.geom_type == "MultiPolygon":
            for poly in shp.geoms:
                _draw_polygon_with_holes(draw, poly, bounds, color)

    return np.array(mask_img)
269
+
270
+
271
def read_vector_layer(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    group_idx: int = 0,
) -> list[Any]:
    """Decode the features stored for one item group of a window's vector layer.

    Features are decoded with the layer's configured vector format, using the
    window's own projection and bounds.

    Args:
        window: The window to read from
        layer_name: The layer name
        layer_config: The layer configuration
        group_idx: The item group index (default 0)

    Returns:
        List of Feature objects

    Raises:
        ValueError: if the layer is not configured as a vector layer.
    """
    if layer_config.type != LayerType.VECTOR:
        raise ValueError(f"Layer {layer_name} is not a vector layer")

    fmt: VectorFormat = layer_config.instantiate_vector_format()
    src_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
    win_bounds, win_projection = window.bounds, window.projection
    logger.info(
        f"Reading vector layer {layer_name} from {src_dir}, bounds: {win_bounds}, projection: {win_projection}"
    )

    decoded = fmt.decode_vector(src_dir, win_projection, win_bounds)
    logger.info(f"Decoded {len(decoded)} features from vector layer {layer_name}")
    return decoded
300
+
301
+
302
def get_vector_label_by_property(
    window: Window,
    layer_config: LayerConfig,
    layer_name: str,
    group_idx: int = 0,
) -> str | None:
    """Get a label value from a vector layer's first feature property.

    Extracts the label value from the first feature's properties using the property
    name specified in layer_config.class_property_name. This works for both
    classification and segmentation tasks that use vector labels.

    Args:
        window: The window to read from
        layer_config: The label layer configuration (must be vector type)
        layer_name: The name of the label layer
        group_idx: The item group index (default 0)

    Returns:
        The label string, or None if the layer has no features, the first
        feature has no properties, or the property is missing.

    Raises:
        ValueError: if class_property_name is not set in layer_config.
    """
    property_name = layer_config.class_property_name
    # Validate the config before any I/O so a misconfigured layer always fails
    # loudly, rather than only for windows that happen to contain features.
    if not property_name:
        raise ValueError(
            f"class_property_name must be specified in the config for vector label layer '{layer_name}'. "
        )

    features = read_vector_layer(window, layer_name, layer_config, group_idx=group_idx)
    if not features:
        logger.warning(
            f"No features in vector label layer {layer_name} for {window.name}"
        )
        return None

    properties = features[0].properties
    if not properties:
        return None

    label = properties.get(property_name)
    logger.info(f"Label for {window.name}: {label}")
    # A missing property must yield None (per the contract above), not the
    # string "None" that str(None) would produce.
    if label is None:
        return None
    return str(label)
342
+
343
+
344
def _get_reference_raster_for_detection(
    window: Window,
    dataset: Dataset,
    bands: dict[str, list[str]],
    normalization: dict[str, str],
    label_layers: list[str],
    group_idx: int,
) -> tuple[np.ndarray, str] | None:
    """Get reference raster array for detection tasks.

    The reference is the first layer (in ``bands`` iteration order) that is not a
    label layer.

    Args:
        window: Window object
        dataset: Dataset object
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        label_layers: List of label layer names
        group_idx: Item group index

    Returns:
        Tuple of (reference_array, normalization_method) or None if no raster layers available

    Raises:
        KeyError: if the chosen reference layer has no entry in ``normalization``.
    """
    raster_layers = [name for name in bands if name not in label_layers]
    if not raster_layers:
        # Previously raster_layers[0] raised IndexError here. Returning None lets
        # the caller fall back to a blank background, as it already expects.
        return None

    ref_layer_name = raster_layers[0]
    ref_layer_config = dataset.layers[ref_layer_name]
    reference_array = read_raster_layer(
        window,
        ref_layer_name,
        ref_layer_config,
        bands[ref_layer_name],
        group_idx=group_idx,
    )
    ref_normalization_method = normalization[ref_layer_name]
    return reference_array, ref_normalization_method
378
+
379
+
380
def render_vector_label_image(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    task_type: str,
    label_colors: dict[str, tuple[int, int, int]],
    dataset: Dataset,
    label_layers: list[str],
    group_idx: int,
    bands: dict[str, list[str]] | None = None,
    normalization: dict[str, str] | None = None,
) -> np.ndarray:
    """Render a vector label layer of a window as an RGB image.

    ``"detection"`` draws boxes around points over a reference raster, or over a
    blank canvas when no reference is available. Any other task type except
    ``"classification"`` is rendered as a segmentation mask.

    Args:
        window: Window object
        layer_name: Layer name
        layer_config: LayerConfig object
        task_type: Task type ("detection" or "segmentation")
        label_colors: Dictionary mapping label class names to RGB colors
        dataset: Dataset object
        label_layers: List of label layer names
        group_idx: Item group index
        bands: Optional dictionary mapping layer_name -> list of band names (for detection reference raster)
        normalization: Optional dictionary mapping layer_name -> normalization method (for detection reference raster)

    Returns:
        Array with shape (height, width, 3) as uint8

    Raises:
        ValueError: if task_type is "classification" (those labels are text).
    """
    if task_type == "classification":
        raise ValueError("Classification labels are text, not images")

    features = read_vector_layer(window, layer_name, layer_config, group_idx=group_idx)

    if task_type != "detection":
        return render_vector_label_segmentation(
            features,
            window.bounds,
            window.projection,
            label_colors,
            layer_config.class_property_name,
        )

    ref_data = _get_reference_raster_for_detection(
        window, dataset, bands or {}, normalization or {}, label_layers, group_idx
    )
    reference_array, ref_normalization_method = (
        (None, None) if ref_data is None else ref_data
    )
    return render_vector_label_detection(
        features,
        window.bounds,
        label_colors,
        layer_config.class_property_name,
        reference_array,
        ref_normalization_method,
    )
rslearn/vis/utils.py ADDED
@@ -0,0 +1,99 @@
1
+ """Utility functions and constants for visualization."""
2
+
3
+ from datetime import datetime
4
+ from io import BytesIO
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from rslearn.dataset import Window
10
+ from rslearn.log_utils import get_logger
11
+ from rslearn.utils.colors import DEFAULT_COLORS
12
+ from rslearn.utils.geometry import WGS84_PROJECTION
13
+
14
+ logger = get_logger(__name__)
15
+
16
+ # Fixed size for all visualized images (width, height in pixels)
17
+ VISUALIZATION_IMAGE_SIZE = (512, 512)
18
+
19
+
20
def generate_label_colors(label_classes: set[str]) -> dict[str, tuple[int, int, int]]:
    """Assign a palette color to each label class.

    Classes are sorted before assignment so the mapping does not depend on
    iteration order. Colors are taken from DEFAULT_COLORS in order and wrap
    around when there are more classes than colors.

    Args:
        label_classes: Set or list of label class names

    Returns:
        Dictionary mapping label class names to RGB color tuples
    """
    palette_size = len(DEFAULT_COLORS)
    return {
        name: DEFAULT_COLORS[position % palette_size]
        for position, name in enumerate(sorted(label_classes))
    }
36
+
37
+
38
def format_window_info(
    window: Window,
) -> tuple[tuple[datetime, datetime] | None, float | None, float | None]:
    """Summarize a window's time range and geographic center for display.

    The window geometry is reprojected to WGS84 and its centroid is used as the
    window's location.

    Args:
        window: Window object

    Returns:
        Tuple of (time_range, lat, lon): ``window.time_range`` as-is, plus the
        centroid latitude and longitude as floats.
    """
    center = window.get_geometry().to_projection(WGS84_PROJECTION).shp.centroid
    lat, lon = float(center.y), float(center.x)
    return window.time_range, lat, lon
58
+
59
+
60
def array_to_bytes(
    array: np.ndarray, resampling: Image.Resampling = Image.Resampling.LANCZOS
) -> bytes:
    """Convert a numpy array to PNG bytes, resized to VISUALIZATION_IMAGE_SIZE.

    Args:
        array: Array with shape (height, width) or (height, width, channels) as
            uint8. With three or more channels, the first three are used as RGB.
            With one or two channels, channel 0 is rendered as grayscale. This
            matches how render_vector_label_detection treats reference rasters.
        resampling: PIL Image resampling method (default LANCZOS, use NEAREST for labels)

    Returns:
        PNG image bytes

    Raises:
        ValueError: if the array is not 2-D or 3-D, or has zero channels.
    """
    if array.ndim == 2:
        img = Image.fromarray(array, mode="L")
    elif array.ndim == 3 and array.shape[-1] >= 3:
        # Channels beyond the first three are dropped.
        img = Image.fromarray(array[:, :, :3], mode="RGB")
    elif array.ndim == 3 and array.shape[-1] >= 1:
        # Previously a 2-channel array reached the RGB branch and failed, since
        # array[:, :, :3] still has only two channels.
        img = Image.fromarray(array[:, :, 0], mode="L")
    else:
        raise ValueError(f"Unsupported array shape: {array.shape}")

    img = img.resize(VISUALIZATION_IMAGE_SIZE, resampling)

    buf = BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
89
+
90
+
91
# Single-pass replacement table for the HTML-special characters.
_HTML_ESCAPE_TABLE = str.maketrans(
    {
        "&": "&amp;",
        "<": "&lt;",
        ">": "&gt;",
        '"': "&quot;",
        "'": "&#x27;",
    }
)


def _escape_html(text: str) -> str:
    """Escape HTML special characters.

    ``&``, ``<``, ``>``, double quotes and single quotes are replaced by their
    entity forms in one pass. Entities produced by the replacement are therefore
    never escaped a second time.
    """
    return text.translate(_HTML_ESCAPE_TABLE)