rslearn 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,574 @@
1
+ """Visualization server for rslearn datasets.
2
+
3
+ This module provides a web server to visualize rslearn datasets using the Dataset/Window APIs.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import random
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ from flask import Flask, Response
13
+ from flask import render_template as flask_render_template
14
+ from PIL import Image
15
+ from upath import UPath
16
+
17
+ from rslearn.config import LayerType
18
+ from rslearn.dataset import Dataset, Window
19
+ from rslearn.log_utils import get_logger
20
+
21
+ from .render_raster_label import read_raster_layer, render_raster_label
22
+ from .render_sensor_image import render_sensor_image
23
+ from .render_vector_label import get_vector_label_by_property, render_vector_label_image
24
+ from .utils import array_to_bytes, format_window_info, generate_label_colors
25
+
26
# Module-level logger for the visualization server, obtained from rslearn's
# logging helper.
logger = get_logger(__name__)
27
+
28
+
29
def generate_image_as_bytes(
    window: Window,
    layer_name: str,
    dataset: Dataset,
    bands: dict[str, list[str]],
    normalization: dict[str, str],
    task_type: str,
    label_colors: dict[str, tuple[int, int, int]] | None,
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]] | None = None,
    group_idx: int = 0,
    label_layers: list[str] | None = None,
) -> bytes:
    """Generate an image for a window/layer combination as PNG bytes.

    Raster layers that are not listed in ``label_layers`` are rendered as sensor
    images using the per-layer ``bands`` and ``normalization``. Layers listed in
    ``label_layers`` are rendered as raster or vector labels using a
    class-name -> RGB color mapping.

    Args:
        window: Window object
        layer_name: Layer name to visualize
        dataset: Dataset object
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        task_type: Task type; for vector labels, "segmentation" selects
            nearest-neighbour resampling, anything else uses LANCZOS.
        label_colors: Fallback mapping from label class name to RGB color, used
            for a label layer that has no (non-empty) entry in
            ``label_colors_dict``.
        label_colors_dict: Dictionary mapping layer_name -> label_colors. A
            non-empty entry for ``layer_name`` takes precedence over
            ``label_colors``.
        group_idx: Item group index
        label_layers: List of layer names that are labels

    Returns:
        PNG image bytes

    Raises:
        ValueError: if bands/normalization are missing for a sensor layer, no
            label colors are available for a label layer, or the layer is
            neither a raster sensor layer nor a raster/vector label layer.
    """
    label_layers = label_layers or []
    layer_config = dataset.layers[layer_name]
    is_label_layer = layer_name in label_layers

    # Render raster sensor image
    if layer_config.type == LayerType.RASTER and not is_label_layer:
        if layer_name not in bands or layer_name not in normalization:
            raise ValueError(
                f"Bands or normalization not specified for layer {layer_name}"
            )
        array = read_raster_layer(
            window, layer_name, layer_config, bands[layer_name], group_idx=group_idx
        )
        image_array = render_sensor_image(array, normalization[layer_name])
        return array_to_bytes(image_array)

    if not is_label_layer:
        raise ValueError(
            f"Layer {layer_name} is not a raster sensor image or label layer"
        )

    # Prefer the colors registered for this specific layer, but fall back to the
    # explicitly passed ``label_colors`` so that a label_colors_dict lacking
    # this layer does not discard the fallback.
    layer_label_colors = (label_colors_dict or {}).get(layer_name) or label_colors
    if not layer_label_colors:
        raise ValueError(f"No label colors available for layer {layer_name}")

    # Render raster label: only the first band of the first band set is read.
    if layer_config.type == LayerType.RASTER:
        band_set = layer_config.band_sets[0]
        label_array = read_raster_layer(
            window,
            layer_name,
            layer_config,
            [band_set.bands[0]],
            group_idx=group_idx,
        )
        image_array = render_raster_label(
            label_array, layer_label_colors, layer_config
        )
        # NEAREST keeps class colors crisp instead of blending them.
        return array_to_bytes(image_array, resampling=Image.Resampling.NEAREST)

    # Render vector label
    if layer_config.type == LayerType.VECTOR:
        image_array = render_vector_label_image(
            window,
            layer_name,
            layer_config,
            task_type,
            layer_label_colors,
            dataset,
            label_layers,
            group_idx,
            bands,
            normalization,
        )
        # Use NEAREST for segmentation labels, LANCZOS for detection (which overlays on reference image)
        resampling = (
            Image.Resampling.NEAREST
            if task_type == "segmentation"
            else Image.Resampling.LANCZOS
        )
        return array_to_bytes(image_array, resampling=resampling)

    raise ValueError(f"Layer {layer_name} is not a raster sensor image or label layer")
120
+
121
+
122
# Global state (set during initialization)
# NOTE(review): nothing in this module reads or writes _app_state; create_app
# captures all of its state in route closures instead. It appears unused —
# confirm no other module imports it before removing.
_app_state: dict[str, Any] = {}
124
+
125
+
126
def prepare_visualization_data(
    sampled_windows: list[Window],
    dataset: Dataset,
    layers: list[str],
    label_layers: list[str],
    task_type: str,
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]],
    group_idx: int,
) -> dict[str, Any]:
    """Build the context passed to the visualization template.

    For every window this gathers display metadata (name, time range, lat/lon,
    Google Maps link) and sorts the requested layers into:

    * ``available_layers``: non-label raster layers completed for ``group_idx``;
    * ``mask_layers``: completed label layers shown as images (raster labels,
      and vector labels when the task is not classification);
    * ``label_texts``: for classification, the text label read from each vector
      label layer (attempted even when the layer is not marked completed).

    Layers missing from the dataset config are ignored; errors while inspecting
    a layer are logged at debug level and that layer is skipped.

    Args:
        sampled_windows: Windows to display, in display order.
        dataset: Dataset the windows belong to.
        layers: Names of all layers to consider.
        label_layers: Names of the layers that hold labels.
        task_type: Task type (e.g. "classification").
        label_colors_dict: Mapping layer_name -> (class name -> RGB color).
        group_idx: Item group index used for completion checks and reads.

    Returns:
        Dict with keys ``windows`` (one entry per window; ``idx`` is its
        position in ``sampled_windows``), ``layers``, ``label_colors`` and
        ``sorted_label_keys`` (from the first entry of ``label_colors_dict``,
        or ``None`` when it is empty), and ``task_type``.
    """
    entries: list[dict[str, Any]] = []

    for position, window in enumerate(sampled_windows):
        time_range, lat, lon = format_window_info(window)
        if lat is None or lon is None:
            maps_link = None
        else:
            maps_link = f"https://www.google.com/maps?q={lat},{lon}"

        available_layers = set()
        mask_layers = []
        label_texts = {}

        for layer_name in layers:
            if layer_name not in dataset.layers:
                continue
            layer_config = dataset.layers[layer_name]
            is_label = layer_name in label_layers
            try:
                if not is_label:
                    # Imagery: only raster layers can be displayed.
                    if layer_config.type == LayerType.RASTER and window.is_layer_completed(
                        layer_name, group_idx=group_idx
                    ):
                        available_layers.add(layer_name)
                elif (
                    layer_config.type == LayerType.VECTOR
                    and task_type == "classification"
                ):
                    # Classification labels are shown as text, not as an image;
                    # failures here only skip the text label.
                    try:
                        completed = window.is_layer_completed(
                            layer_name, group_idx=group_idx
                        )
                        if not completed:
                            logger.debug(
                                f"Layer {layer_name} not marked as completed for window {window.name}, attempting to read anyway"
                            )
                        text = get_vector_label_by_property(
                            window,
                            layer_config,
                            layer_name,
                            group_idx=group_idx,
                        )
                        if text is not None:
                            label_texts[layer_name] = text
                    except Exception as err:
                        logger.debug(
                            f"Failed to get label text for {layer_name} in window {window.name}: {err}"
                        )
                elif layer_config.type in (LayerType.VECTOR, LayerType.RASTER):
                    if window.is_layer_completed(layer_name, group_idx=group_idx):
                        mask_layers.append(layer_name)
            except Exception as err:
                logger.debug(
                    f"Error processing layer {layer_name} for window {window.name}: {err}"
                )

        # Date-only (YYYY-MM-DD) strings for the template.
        if time_range:
            start, end = time_range[0], time_range[1]
            time_range_formatted = (start.isoformat()[:10], end.isoformat()[:10])
        else:
            time_range_formatted = None

        entries.append(
            {
                "idx": position,
                "name": window.name,
                "time_range": time_range,
                "time_range_formatted": time_range_formatted,
                "lat": lat,
                "lon": lon,
                "maps_link": maps_link,
                "available_layers": available_layers,
                "mask_layers": mask_layers,
                "label_texts": label_texts,
            }
        )

    # The legend shows the colors of the first label layer only.
    if label_colors_dict:
        legend_colors = next(iter(label_colors_dict.values()))
        legend_keys = sorted(legend_colors)
    else:
        legend_colors = None
        legend_keys = None

    return {
        "windows": entries,
        "layers": layers,
        "label_colors": legend_colors,
        "sorted_label_keys": legend_keys,
        "task_type": task_type,
    }
246
+
247
+
248
def create_app(
    dataset: Dataset,
    windows: list[Window],
    layers: list[str],
    bands: dict[str, list[str]],
    normalization: dict[str, str],
    task_type: str,
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]],
    group_idx: int,
    label_layers: list[str],
    max_samples: int,
) -> Flask:
    """Create and configure Flask app.

    The app serves two routes:

    * ``/`` renders ``visualization.html`` for a fresh random sample of at most
      ``max_samples`` windows on every request.
    * ``/images/<window_idx>/<layer_name>`` renders one layer of one window as
      a PNG, where ``window_idx`` is the window's position in ``windows``.

    Args:
        dataset: Dataset object
        windows: List of all windows
        layers: List of all layers
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        task_type: Task type
        label_colors_dict: Dictionary mapping layer_name -> label_colors
        group_idx: Item group index
        label_layers: List of label layer names
        max_samples: Maximum number of windows to sample

    Returns:
        Configured Flask app
    """
    # Set template folder explicitly to ensure Flask can find templates
    template_folder = Path(__file__).parent / "templates"
    app = Flask(__name__, template_folder=str(template_folder))

    @app.route("/")
    def index() -> str:
        """Render the main visualization page."""
        # Sample positions in ``windows`` rather than the windows themselves so
        # the page can link to images by the same index get_image() uses.
        all_indices = range(len(windows))
        if len(windows) > max_samples:
            sampled_indices = random.sample(all_indices, max(max_samples, 0))
        else:
            sampled_indices = list(all_indices)
        sampled_windows = [windows[i] for i in sampled_indices]

        template_data = prepare_visualization_data(
            sampled_windows,
            dataset,
            layers,
            label_layers,
            task_type,
            label_colors_dict,
            group_idx,
        )
        # prepare_visualization_data numbers its entries by position in
        # ``sampled_windows`` (one entry per window, in order), but get_image()
        # indexes the full ``windows`` list. Remap each entry's idx; otherwise,
        # once sampling kicks in, images would be served for the wrong windows.
        for entry, window_idx in zip(template_data["windows"], sampled_indices):
            entry["idx"] = window_idx
        return flask_render_template("visualization.html", **template_data)

    @app.route("/images/<int:window_idx>/<layer_name>")
    def get_image(window_idx: int, layer_name: str) -> Response:
        """Generate and serve an image for a specific window/layer.

        Args:
            window_idx: Index of the window in the windows list
            layer_name: Name of the layer to visualize

        Returns:
            PNG image response or error response
        """
        if not 0 <= window_idx < len(windows):
            return Response(
                "Window index out of range", status=404, mimetype="text/plain"
            )
        # Reject layer names that are not in the dataset config with a 404
        # instead of letting the layer lookup in generate_image_as_bytes raise.
        if layer_name not in dataset.layers:
            return Response(
                f"Unknown layer {layer_name}", status=404, mimetype="text/plain"
            )

        window = windows[window_idx]

        # Use this layer's label colors if it has any, otherwise the first label
        # layer's colors; None when no label colors were configured.
        layer_label_colors = None
        if label_colors_dict:
            layer_label_colors = label_colors_dict.get(
                layer_name, next(iter(label_colors_dict.values()))
            )

        image_bytes = generate_image_as_bytes(
            window,
            layer_name,
            dataset,
            bands,
            normalization,
            task_type,
            layer_label_colors,
            label_colors_dict,
            group_idx,
            label_layers,
        )

        return Response(
            image_bytes,
            mimetype="image/png",
            headers={"Content-Length": str(len(image_bytes))},
        )

    return app
345
+
346
+
347
def run(
    dataset_path: str | Path | UPath,
    layers: list[str] | None = None,
    bands: dict[str, list[str]] | None = None,
    normalization: dict[str, str] | None = None,
    task_type: str | None = None,
    max_samples: int = 100,
    port: int = 8000,
    host: str = "0.0.0.0",
    group_idx: int = 0,
    label_layers: list[str] | None = None,
) -> None:
    """Run the visualization server.

    Args:
        dataset_path: Path to dataset directory (containing config.json)
        layers: List of layer names to visualize
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        task_type: Task type - "classification", "regression", "detection", or "segmentation"
        max_samples: Maximum number of windows to sample
        port: Port to serve on
        host: Host to bind to
        group_idx: Item group index (default 0)
        label_layers: List of layer names that are labels

    Raises:
        ValueError: if ``layers`` or ``task_type`` is missing, if bands or
            normalization are missing for an image layer, or (for segmentation
            and detection) if a label layer is not in the dataset config or has
            no ``class_names``.
    """
    # Validate required arguments before touching the dataset, so mistakes are
    # reported before the potentially slow loading of every window.
    if layers is None:
        raise ValueError("--layers is required")
    if task_type is None:
        raise ValueError("--task_type is required")

    dataset_path = UPath(dataset_path)
    dataset = Dataset(dataset_path)

    label_layers = label_layers or []
    # dict.fromkeys de-duplicates while keeping the order the user gave; a set
    # would make the displayed layer order depend on string-hash randomization.
    all_layers = list(dict.fromkeys(layers + label_layers))
    raster_image_layers = [name for name in all_layers if name not in label_layers]
    label_layers_in_list = [name for name in all_layers if name in label_layers]

    missing_layers = [name for name in all_layers if name not in dataset.layers]
    if missing_layers:
        logger.warning(
            f"Layers not found in the dataset config will be skipped: {missing_layers}"
        )

    bands = bands or {}
    normalization = normalization or {}
    for layer_name in raster_image_layers:
        if layer_name not in bands:
            example = json.dumps({layer_name: ["band1", "band2", "band3"]})
            raise ValueError(
                f"Bands not specified for layer {layer_name}. "
                f"Please provide --bands '{example}'"
            )
        if layer_name not in normalization:
            example = json.dumps({layer_name: "method"})
            raise ValueError(
                f"Normalization not specified for layer {layer_name}. "
                f"Please provide --normalization '{example}'"
            )

    # Only segmentation and detection render label layers with per-class colors.
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]] = {}
    if task_type in ("segmentation", "detection"):
        for label_layer_name in label_layers_in_list:
            if label_layer_name not in dataset.layers:
                raise ValueError(
                    f"Label layer '{label_layer_name}' is not defined in the dataset config"
                )
            label_config = dataset.layers[label_layer_name]
            if not label_config.class_names:
                raise ValueError(
                    f"class_names must be specified in the config for label layer '{label_layer_name}'. "
                    "Auto-detection of class names is not supported."
                )
            label_classes = set(label_config.class_names)
            label_colors_dict[label_layer_name] = generate_label_colors(label_classes)
            logger.info(
                f"Found {len(label_classes)} label classes for {label_layer_name}: {sorted(label_classes)}"
            )

    logger.info(f"Loading all windows from dataset {dataset_path}")
    windows = dataset.load_windows()
    logger.info(f"Loaded {len(windows)} windows from dataset")
    logger.info(f"Layers: {all_layers}")
    logger.info(f"Bands: {bands}")
    logger.info(f"Normalization: {normalization}")
    logger.info(f"Task type: {task_type}")

    app = create_app(
        dataset,
        windows,
        all_layers,
        bands,
        normalization,
        task_type,
        label_colors_dict,
        group_idx,
        label_layers,
        max_samples,
    )

    logger.info(f"Serving on http://{host}:{port}")
    logger.info(f"Open http://localhost:{port} in your browser")
    logger.info(
        f"Loaded {len(windows)} windows - refreshing the page will show a different random sample"
    )

    app.run(host=host, port=port, debug=False)
444
+
445
+
446
+ def parse_bands_arg(bands_str: str | None) -> dict[str, list[str]]:
447
+ """Parse --bands argument as JSON.
448
+
449
+ Args:
450
+ bands_str: JSON string mapping layer_name -> list of band names, e.g. '{"sentinel2": ["B04", "B03", "B02"]}'
451
+
452
+ Returns:
453
+ Dictionary mapping layer_name -> list of band names
454
+ """
455
+ if not bands_str:
456
+ return {}
457
+ try:
458
+ bands_dict = json.loads(bands_str)
459
+ if not isinstance(bands_dict, dict):
460
+ raise ValueError("Bands must be a JSON object/dictionary")
461
+ for layer_name, band_list in bands_dict.items():
462
+ if not isinstance(band_list, list):
463
+ raise ValueError(f"Bands for layer '{layer_name}' must be a list")
464
+ if not all(isinstance(band, str) for band in band_list):
465
+ raise ValueError(f"All bands for layer '{layer_name}' must be strings")
466
+ return bands_dict
467
+ except json.JSONDecodeError as e:
468
+ raise ValueError(f"Invalid JSON for bands: {e}") from e
469
+
470
+
471
+ def parse_normalization_arg(norm_str: str | None) -> dict[str, str]:
472
+ """Parse --normalization argument as JSON.
473
+
474
+ Args:
475
+ norm_str: JSON string mapping layer_name -> normalization method, e.g. '{"sentinel2": "sentinel2_rgb"}'
476
+
477
+ Returns:
478
+ Dictionary mapping layer_name -> normalization method
479
+ """
480
+ if not norm_str:
481
+ return {}
482
+ try:
483
+ norm_dict = json.loads(norm_str)
484
+ if not isinstance(norm_dict, dict):
485
+ raise ValueError("Normalization must be a JSON object/dictionary")
486
+ for layer_name, method in norm_dict.items():
487
+ if not isinstance(method, str):
488
+ raise ValueError(
489
+ f"Normalization method for layer '{layer_name}' must be a string"
490
+ )
491
+ return norm_dict
492
+ except json.JSONDecodeError as e:
493
+ raise ValueError(f"Invalid JSON for normalization: {e}") from e
494
+
495
+
496
def main() -> None:
    """Main entry point for the visualization server CLI.

    Parses the command-line arguments, decodes the JSON-valued ``--bands`` and
    ``--normalization`` options, and starts the server via ``run``. Malformed
    JSON is reported through argparse (usage message, exit status 2) rather
    than as an unhandled traceback.
    """
    parser = argparse.ArgumentParser(
        description="Visualize rslearn dataset in a web browser"
    )
    parser.add_argument(
        "dataset_path",
        type=str,
        help="Path to dataset directory (containing config.json)",
    )
    parser.add_argument(
        "--layers",
        type=str,
        required=True,
        nargs="+",
        help="List of image layer names to visualize",
    )
    parser.add_argument(
        "--label_layers",
        type=str,
        nargs="+",
        help="List of layer names that are labels (same format as --layers)",
    )
    parser.add_argument(
        "--bands",
        type=str,
        required=True,
        help='Bands to visualize per layer as JSON. Example: \'{"sentinel2": ["B04", "B03", "B02"]}\'',
    )
    parser.add_argument(
        "--normalization",
        type=str,
        required=True,
        help='Normalization method per layer as JSON. Example: \'{"sentinel2": "sentinel2_rgb"}\'',
    )
    parser.add_argument(
        "--task_type",
        type=str,
        required=True,
        choices=["classification", "regression", "detection", "segmentation"],
        help="Task type, used to decide how label layers are displayed",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=100,
        help="Maximum number of windows to sample (default: 100)",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to serve on (default: 8000)"
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--group_idx", type=int, default=0, help="Item group index (default: 0)"
    )

    args = parser.parse_args()

    # Surface malformed JSON as a normal CLI usage error (parser.error exits).
    try:
        bands_dict = parse_bands_arg(args.bands)
        normalization_dict = parse_normalization_arg(args.normalization)
    except ValueError as e:
        parser.error(str(e))

    run(
        dataset_path=args.dataset_path,
        layers=args.layers,
        bands=bands_dict if bands_dict else None,
        normalization=normalization_dict if normalization_dict else None,
        task_type=args.task_type,
        max_samples=args.max_samples,
        port=args.port,
        host=args.host,
        group_idx=args.group_idx,
        label_layers=args.label_layers,
    )
571
+
572
+
573
# Allow running the server as a script, e.g.
# ``python -m rslearn.vis.vis_server ...``.
if __name__ == "__main__":
    main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.23
3
+ Version: 0.0.25
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -214,6 +214,7 @@ License-File: LICENSE
214
214
  License-File: NOTICE
215
215
  Requires-Dist: boto3>=1.39
216
216
  Requires-Dist: fiona>=1.10
217
+ Requires-Dist: flask>=3.0.0
217
218
  Requires-Dist: fsspec>=2025.10.0
218
219
  Requires-Dist: jsonargparse>=4.35.0
219
220
  Requires-Dist: lightning>=2.5.1.post0
@@ -482,6 +483,18 @@ We can visualize both the GeoTIFFs together in qgis:
482
483
  qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
483
484
  ```
484
485
 
486
+ We can also visualize samples using the visualization module. Here `--layers` lists the image modality layers and `--label_layers` the label layers, `--bands` and `--normalization` give the bands and normalization method for each image modality, `--task_type` is one of `segmentation`, `detection`, `classification`, or `regression`, and `--max_samples` is the number of windows to randomly sample and visualize:
487
+ ```
488
+ python -m rslearn.vis.vis_server \
489
+ $DATASET_PATH \
490
+ --layers sentinel2 \
491
+ --label_layers label_raster \
492
+ --bands '{"sentinel2": ["B04", "B03", "B02"]}' \
493
+ --normalization '{"sentinel2": "sentinel2_rgb"}' \
494
+ --task_type segmentation \
495
+ --max_samples 100 \
496
+ --port 8000
497
+ ```
485
498
 
486
499
  ### Training a Model
487
500