rslearn 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/local_files.py +2 -2
- rslearn/data_sources/utils.py +204 -64
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/data_module.py +5 -0
- rslearn/train/dataset.py +186 -49
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/model_context.py +16 -0
- rslearn/train/tasks/detection.py +1 -18
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/segmentation.py +27 -32
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/colors.py +20 -0
- rslearn/vis/__init__.py +1 -0
- rslearn/vis/normalization.py +127 -0
- rslearn/vis/render_raster_label.py +96 -0
- rslearn/vis/render_sensor_image.py +27 -0
- rslearn/vis/render_vector_label.py +439 -0
- rslearn/vis/utils.py +99 -0
- rslearn/vis/vis_server.py +574 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/METADATA +14 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/RECORD +42 -33
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/WHEEL +1 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,574 @@
|
|
|
1
|
+
"""Visualization server for rslearn datasets.
|
|
2
|
+
|
|
3
|
+
This module provides a web server to visualize rslearn datasets using the Dataset/Window APIs.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import random
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from flask import Flask, Response
|
|
13
|
+
from flask import render_template as flask_render_template
|
|
14
|
+
from PIL import Image
|
|
15
|
+
from upath import UPath
|
|
16
|
+
|
|
17
|
+
from rslearn.config import LayerType
|
|
18
|
+
from rslearn.dataset import Dataset, Window
|
|
19
|
+
from rslearn.log_utils import get_logger
|
|
20
|
+
|
|
21
|
+
from .render_raster_label import read_raster_layer, render_raster_label
|
|
22
|
+
from .render_sensor_image import render_sensor_image
|
|
23
|
+
from .render_vector_label import get_vector_label_by_property, render_vector_label_image
|
|
24
|
+
from .utils import array_to_bytes, format_window_info, generate_label_colors
|
|
25
|
+
|
|
26
|
+
logger = get_logger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def generate_image_as_bytes(
    window: Window,
    layer_name: str,
    dataset: Dataset,
    bands: dict[str, list[str]],
    normalization: dict[str, str],
    task_type: str,
    label_colors: dict[str, tuple[int, int, int]] | None,
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]] | None = None,
    group_idx: int = 0,
    label_layers: list[str] | None = None,
) -> bytes:
    """Generate an image for a window/layer combination as PNG bytes.

    Raster layers that are not listed in ``label_layers`` are rendered as sensor
    images using the configured bands and normalization. Layers listed in
    ``label_layers`` are rendered as raster or vector label images using the
    label colors for that layer.

    Args:
        window: Window object
        layer_name: Layer name to visualize
        dataset: Dataset object
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        task_type: Task type
        label_colors: Dictionary mapping label class names to RGB colors; used as
            the fallback when ``label_colors_dict`` has no (non-empty) entry for
            ``layer_name``
        label_colors_dict: Dictionary mapping layer_name -> label_colors
        group_idx: Item group index
        label_layers: List of layer names that are labels

    Returns:
        PNG image bytes

    Raises:
        ValueError: if a sensor layer has no bands/normalization configured, a
            label layer has no label colors, or the layer is neither a raster
            sensor layer nor a raster/vector label layer.
    """
    label_layers = label_layers or []
    layer_config = dataset.layers[layer_name]
    is_label_layer = layer_name in label_layers

    # Render raster sensor image.
    if layer_config.type == LayerType.RASTER and not is_label_layer:
        if layer_name not in bands or layer_name not in normalization:
            raise ValueError(
                f"Bands or normalization not specified for layer {layer_name}"
            )
        array = read_raster_layer(
            window, layer_name, layer_config, bands[layer_name], group_idx=group_idx
        )
        image_array = render_sensor_image(array, normalization[layer_name])
        return array_to_bytes(image_array)

    # Render labels.
    if is_label_layer:
        # Prefer the per-layer colors, but fall back to the explicitly passed
        # label_colors when the dict has no entry for this layer. Previously the
        # fallback was discarded whenever label_colors_dict was non-empty.
        layer_label_colors = (label_colors_dict or {}).get(layer_name) or label_colors
        if not layer_label_colors:
            raise ValueError(f"No label colors available for layer {layer_name}")

        # Render raster label: only the first band of the first band set is read.
        if layer_config.type == LayerType.RASTER:
            band_set = layer_config.band_sets[0]
            label_array = read_raster_layer(
                window,
                layer_name,
                layer_config,
                [band_set.bands[0]],
                group_idx=group_idx,
            )
            image_array = render_raster_label(
                label_array, layer_label_colors, layer_config
            )
            # NEAREST keeps class colors crisp (no blending between classes).
            return array_to_bytes(image_array, resampling=Image.Resampling.NEAREST)

        # Render vector label.
        if layer_config.type == LayerType.VECTOR:
            image_array = render_vector_label_image(
                window,
                layer_name,
                layer_config,
                task_type,
                layer_label_colors,
                dataset,
                label_layers,
                group_idx,
                bands,
                normalization,
            )
            # Use NEAREST for segmentation labels, LANCZOS for detection (which
            # overlays on reference image).
            resampling = (
                Image.Resampling.NEAREST
                if task_type == "segmentation"
                else Image.Resampling.LANCZOS
            )
            return array_to_bytes(image_array, resampling=resampling)

    raise ValueError(f"Layer {layer_name} is not a raster sensor image or label layer")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Global state (set during initialization)
|
|
123
|
+
_app_state: dict[str, Any] = {}
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _collect_window_layer_info(
    window: Window,
    dataset: Dataset,
    layers: list[str],
    label_layers: list[str],
    task_type: str,
    group_idx: int,
) -> tuple[set[str], list[str], dict[str, Any]]:
    """Scan the requested layers of one window.

    Returns:
        A tuple ``(available_layers, mask_layers, label_texts)``:
        - available_layers: completed non-label raster layers
        - mask_layers: completed label layers rendered as images
        - label_texts: layer_name -> label text for classification vector labels
    """
    available: set[str] = set()
    masks: list[str] = []
    texts: dict[str, Any] = {}

    for name in layers:
        # Layers not present in the dataset config are silently skipped.
        if name not in dataset.layers:
            continue
        config = dataset.layers[name]
        is_label = name in label_layers

        # Errors for one layer are logged and the scan moves on to the next layer.
        try:
            if not is_label:
                if config.type == LayerType.RASTER and window.is_layer_completed(
                    name, group_idx=group_idx
                ):
                    available.add(name)
                continue

            if config.type == LayerType.VECTOR:
                if task_type != "classification":
                    if window.is_layer_completed(name, group_idx=group_idx):
                        masks.append(name)
                    continue
                # Classification labels are shown as text; read them even when
                # the layer is not marked completed.
                try:
                    if not window.is_layer_completed(name, group_idx=group_idx):
                        logger.debug(
                            f"Layer {name} not marked as completed for window {window.name}, attempting to read anyway"
                        )
                    text = get_vector_label_by_property(
                        window,
                        config,
                        name,
                        group_idx=group_idx,
                    )
                    if text is not None:
                        texts[name] = text
                except Exception as e:
                    logger.debug(
                        f"Failed to get label text for {name} in window {window.name}: {e}"
                    )
            elif config.type == LayerType.RASTER and window.is_layer_completed(
                name, group_idx=group_idx
            ):
                masks.append(name)
        except Exception as e:
            logger.debug(
                f"Error processing layer {name} for window {window.name}: {e}"
            )

    return available, masks, texts


def prepare_visualization_data(
    sampled_windows: list[Window],
    dataset: Dataset,
    layers: list[str],
    label_layers: list[str],
    task_type: str,
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]],
    group_idx: int,
) -> dict[str, Any]:
    """Build the context dictionary for the visualization template.

    Each window produces one entry in ``windows`` whose ``idx`` is the window's
    position in ``sampled_windows``. The legend (``label_colors`` and
    ``sorted_label_keys``) is taken from the first entry of ``label_colors_dict``.

    Args:
        sampled_windows: List of windows to display
        dataset: Dataset object
        layers: List of all layers
        label_layers: List of label layer names
        task_type: Task type
        label_colors_dict: Dictionary mapping layer_name -> label_colors
        group_idx: Item group index

    Returns:
        Dictionary with template context data
    """
    entries: list[dict[str, Any]] = []

    for position, window in enumerate(sampled_windows):
        time_range, lat, lon = format_window_info(window)
        has_coordinates = lat is not None and lon is not None
        maps_link = f"https://www.google.com/maps?q={lat},{lon}" if has_coordinates else None

        available_layers, mask_layers, label_texts = _collect_window_layer_info(
            window, dataset, layers, label_layers, task_type, group_idx
        )

        # Keep only the YYYY-MM-DD part of each ISO timestamp.
        time_range_formatted = (
            (time_range[0].isoformat()[:10], time_range[1].isoformat()[:10])
            if time_range
            else None
        )

        entries.append(
            {
                "idx": position,
                "name": window.name,
                "time_range": time_range,
                "time_range_formatted": time_range_formatted,
                "lat": lat,
                "lon": lon,
                "maps_link": maps_link,
                "available_layers": available_layers,
                "mask_layers": mask_layers,
                "label_texts": label_texts,
            }
        )

    legend_colors = None
    legend_keys = None
    if label_colors_dict:
        legend_colors = next(iter(label_colors_dict.values()))
        legend_keys = sorted(legend_colors)

    return {
        "windows": entries,
        "layers": layers,
        "label_colors": legend_colors,
        "sorted_label_keys": legend_keys,
        "task_type": task_type,
    }
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def create_app(
    dataset: Dataset,
    windows: list[Window],
    layers: list[str],
    bands: dict[str, list[str]],
    normalization: dict[str, str],
    task_type: str,
    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]],
    group_idx: int,
    label_layers: list[str],
    max_samples: int,
) -> Flask:
    """Create and configure Flask app.

    Routes:
        ``/``: renders ``visualization.html`` for a random sample of at most
            ``max_samples`` windows (a fresh sample on every page load).
        ``/images/<window_idx>/<layer_name>``: serves a PNG for one window/layer,
            where ``window_idx`` indexes the full ``windows`` list.

    Args:
        dataset: Dataset object
        windows: List of all windows
        layers: List of all layers
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        task_type: Task type
        label_colors_dict: Dictionary mapping layer_name -> label_colors
        group_idx: Item group index
        label_layers: List of label layer names
        max_samples: Maximum number of windows to sample

    Returns:
        Configured Flask app
    """
    # Set template folder explicitly to ensure Flask can find templates
    # regardless of the current working directory.
    template_folder = Path(__file__).parent / "templates"
    app = Flask(__name__, template_folder=str(template_folder))

    @app.route("/")
    def index() -> str:
        """Render the main visualization page."""
        # Sample indices rather than Window objects so every entry can carry its
        # position in the full `windows` list; `get_image` looks windows up by
        # that global index.
        if len(windows) > max_samples:
            sampled_indices = random.sample(range(len(windows)), max_samples)
        else:
            sampled_indices = list(range(len(windows)))
        sampled_windows = [windows[i] for i in sampled_indices]

        template_data = prepare_visualization_data(
            sampled_windows,
            dataset,
            layers,
            label_layers,
            task_type,
            label_colors_dict,
            group_idx,
        )
        # prepare_visualization_data numbers entries by their position in
        # sampled_windows (one entry per window, in order). Replace that with the
        # global index; otherwise image requests would fetch a different window
        # whenever sampling is active.
        for entry, global_idx in zip(template_data["windows"], sampled_indices):
            entry["idx"] = global_idx
        return flask_render_template("visualization.html", **template_data)

    @app.route("/images/<int:window_idx>/<layer_name>")
    def get_image(window_idx: int, layer_name: str) -> Response:
        """Generate and serve an image for a specific window/layer.

        Args:
            window_idx: Index of the window in the windows list
            layer_name: Name of the layer to visualize

        Returns:
            PNG image response, 404 for an unknown window/layer, or 500 if
            rendering fails
        """
        if window_idx < 0 or window_idx >= len(windows):
            return Response(
                "Window index out of range", status=404, mimetype="text/plain"
            )
        if layer_name not in dataset.layers:
            # Previously this surfaced as an unhandled KeyError (HTTP 500).
            return Response("Layer not found", status=404, mimetype="text/plain")

        window = windows[window_idx]

        # Use this layer's colors, or fall back to the first label layer's.
        layer_label_colors = None
        if label_colors_dict and layer_name in label_colors_dict:
            layer_label_colors = label_colors_dict[layer_name]
        elif label_colors_dict:
            layer_label_colors = next(iter(label_colors_dict.values()))

        try:
            image_bytes = generate_image_as_bytes(
                window,
                layer_name,
                dataset,
                bands,
                normalization,
                task_type,
                layer_label_colors,
                label_colors_dict,
                group_idx,
                label_layers,
            )
        except Exception:
            # Request boundary: log the full traceback, return a plain-text 500.
            logger.exception(
                "Failed to render layer %s for window %s", layer_name, window.name
            )
            return Response(
                "Failed to render image", status=500, mimetype="text/plain"
            )

        return Response(
            image_bytes,
            mimetype="image/png",
            headers={"Content-Length": str(len(image_bytes))},
        )

    return app
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def run(
    dataset_path: str | Path | UPath,
    layers: list[str] | None = None,
    bands: dict[str, list[str]] | None = None,
    normalization: dict[str, str] | None = None,
    task_type: str | None = None,
    max_samples: int = 100,
    port: int = 8000,
    host: str = "0.0.0.0",
    group_idx: int = 0,
    label_layers: list[str] | None = None,
) -> None:
    """Run the visualization server.

    Args:
        dataset_path: Path to dataset directory (containing config.json)
        layers: List of layer names to visualize
        bands: Dictionary mapping layer_name -> list of band names
        normalization: Dictionary mapping layer_name -> normalization method
        task_type: Task type - "classification", "regression", "detection", or "segmentation"
        max_samples: Maximum number of windows to sample (must be >= 1)
        port: Port to serve on
        host: Host to bind to
        group_idx: Item group index (default 0)
        label_layers: List of layer names that are labels

    Raises:
        ValueError: if layers or task_type is missing, max_samples < 1, a
            non-label layer has no bands/normalization, or (for segmentation and
            detection) a label layer is missing from the dataset config or has
            no class_names.
    """
    # Validate cheap arguments before touching the dataset, so mistakes are
    # reported without first waiting on dataset and window loading.
    if layers is None:
        raise ValueError("--layers is required")
    if task_type is None:
        raise ValueError("--task_type is required")
    if max_samples < 1:
        raise ValueError(f"max_samples must be at least 1, got {max_samples}")

    dataset_path = UPath(dataset_path)
    dataset = Dataset(dataset_path)

    label_layers = label_layers or []

    # De-duplicate while preserving the user's order. The previous set-based
    # dedupe made the on-page layer order vary between runs (str hashing is
    # randomized per process).
    all_layers = list(dict.fromkeys(layers + label_layers))
    label_layer_set = set(label_layers)
    raster_image_layers = [name for name in all_layers if name not in label_layer_set]
    label_layers_in_list = [name for name in all_layers if name in label_layer_set]

    bands = bands or {}
    normalization = normalization or {}
    for layer_name in raster_image_layers:
        if layer_name not in bands:
            raise ValueError(
                f"Bands not specified for layer {layer_name}. Please provide --bands {layer_name}:band1,band2,band3"
            )
        if layer_name not in normalization:
            raise ValueError(
                f"Normalization not specified for layer {layer_name}. Please provide --normalization {layer_name}:method"
            )

    label_colors_dict: dict[str, dict[str, tuple[int, int, int]]] = {}
    if task_type in ("segmentation", "detection"):
        for label_layer_name in label_layers_in_list:
            if label_layer_name not in dataset.layers:
                raise ValueError(
                    f"Label layer '{label_layer_name}' not found in dataset config"
                )
            label_config = dataset.layers[label_layer_name]
            if not label_config.class_names:
                raise ValueError(
                    f"class_names must be specified in the config for label layer '{label_layer_name}'. "
                    "Auto-detection of class names is not supported."
                )
            label_classes = set(label_config.class_names)
            label_colors_dict[label_layer_name] = generate_label_colors(label_classes)
            logger.info(
                f"Found {len(label_classes)} label classes for {label_layer_name}: {sorted(label_classes)}"
            )

    logger.info(f"Loading all windows from dataset {dataset_path}")
    windows = dataset.load_windows()
    logger.info(f"Loaded {len(windows)} windows from dataset")
    logger.info(f"Layers: {all_layers}")
    logger.info(f"Bands: {bands}")
    logger.info(f"Normalization: {normalization}")
    logger.info(f"Task type: {task_type}")

    app = create_app(
        dataset,
        windows,
        all_layers,
        bands,
        normalization,
        task_type,
        label_colors_dict,
        group_idx,
        label_layers,
        max_samples,
    )

    logger.info(f"Serving on http://{host}:{port}")
    logger.info(f"Open http://localhost:{port} in your browser")
    logger.info(
        f"Loaded {len(windows)} windows - refreshing the page will show a different random sample"
    )

    app.run(host=host, port=port, debug=False)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def parse_bands_arg(bands_str: str | None) -> dict[str, list[str]]:
|
|
447
|
+
"""Parse --bands argument as JSON.
|
|
448
|
+
|
|
449
|
+
Args:
|
|
450
|
+
bands_str: JSON string mapping layer_name -> list of band names, e.g. '{"sentinel2": ["B04", "B03", "B02"]}'
|
|
451
|
+
|
|
452
|
+
Returns:
|
|
453
|
+
Dictionary mapping layer_name -> list of band names
|
|
454
|
+
"""
|
|
455
|
+
if not bands_str:
|
|
456
|
+
return {}
|
|
457
|
+
try:
|
|
458
|
+
bands_dict = json.loads(bands_str)
|
|
459
|
+
if not isinstance(bands_dict, dict):
|
|
460
|
+
raise ValueError("Bands must be a JSON object/dictionary")
|
|
461
|
+
for layer_name, band_list in bands_dict.items():
|
|
462
|
+
if not isinstance(band_list, list):
|
|
463
|
+
raise ValueError(f"Bands for layer '{layer_name}' must be a list")
|
|
464
|
+
if not all(isinstance(band, str) for band in band_list):
|
|
465
|
+
raise ValueError(f"All bands for layer '{layer_name}' must be strings")
|
|
466
|
+
return bands_dict
|
|
467
|
+
except json.JSONDecodeError as e:
|
|
468
|
+
raise ValueError(f"Invalid JSON for bands: {e}") from e
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def parse_normalization_arg(norm_str: str | None) -> dict[str, str]:
|
|
472
|
+
"""Parse --normalization argument as JSON.
|
|
473
|
+
|
|
474
|
+
Args:
|
|
475
|
+
norm_str: JSON string mapping layer_name -> normalization method, e.g. '{"sentinel2": "sentinel2_rgb"}'
|
|
476
|
+
|
|
477
|
+
Returns:
|
|
478
|
+
Dictionary mapping layer_name -> normalization method
|
|
479
|
+
"""
|
|
480
|
+
if not norm_str:
|
|
481
|
+
return {}
|
|
482
|
+
try:
|
|
483
|
+
norm_dict = json.loads(norm_str)
|
|
484
|
+
if not isinstance(norm_dict, dict):
|
|
485
|
+
raise ValueError("Normalization must be a JSON object/dictionary")
|
|
486
|
+
for layer_name, method in norm_dict.items():
|
|
487
|
+
if not isinstance(method, str):
|
|
488
|
+
raise ValueError(
|
|
489
|
+
f"Normalization method for layer '{layer_name}' must be a string"
|
|
490
|
+
)
|
|
491
|
+
return norm_dict
|
|
492
|
+
except json.JSONDecodeError as e:
|
|
493
|
+
raise ValueError(f"Invalid JSON for normalization: {e}") from e
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def main() -> None:
    """Main entry point for the visualization server CLI.

    Parses command-line arguments, decodes the JSON ``--bands`` and
    ``--normalization`` values, and starts the server via ``run``. Malformed
    JSON is reported through ``argparse`` (usage message and exit status 2)
    instead of an uncaught traceback.
    """
    parser = argparse.ArgumentParser(
        description="Visualize rslearn dataset in a web browser"
    )
    parser.add_argument(
        "dataset_path",
        type=str,
        help="Path to dataset directory (containing config.json)",
    )
    parser.add_argument(
        "--layers",
        type=str,
        required=True,
        nargs="+",
        help="List of image layer names to visualize",
    )
    parser.add_argument(
        "--label_layers",
        type=str,
        nargs="+",
        help="List of layer names that are labels (same format as --layers)",
    )
    parser.add_argument(
        "--bands",
        type=str,
        required=True,
        help='Bands to visualize per layer as JSON. Example: \'{"sentinel2": ["B04", "B03", "B02"]}\'',
    )
    parser.add_argument(
        "--normalization",
        type=str,
        required=True,
        help='Normalization method per layer as JSON. Example: \'{"sentinel2": "sentinel2_rgb"}\'',
    )
    parser.add_argument(
        "--task_type",
        type=str,
        required=True,
        choices=["classification", "regression", "detection", "segmentation"],
        help="Task type of the label layers",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=100,
        help="Maximum number of windows to sample",
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to serve on (default: 8000)"
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--group_idx", type=int, default=0, help="Item group index (default: 0)"
    )

    args = parser.parse_args()

    try:
        bands_dict = parse_bands_arg(args.bands)
        normalization_dict = parse_normalization_arg(args.normalization)
    except ValueError as err:
        # parser.error prints usage plus the message and exits with status 2.
        parser.error(str(err))

    run(
        dataset_path=args.dataset_path,
        layers=args.layers,
        bands=bands_dict if bands_dict else None,
        normalization=normalization_dict if normalization_dict else None,
        task_type=args.task_type,
        max_samples=args.max_samples,
        port=args.port,
        host=args.host,
        group_idx=args.group_idx,
        label_layers=args.label_layers,
    )


if __name__ == "__main__":
    main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.24
|
|
3
|
+
Version: 0.0.26
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -214,6 +214,7 @@ License-File: LICENSE
|
|
|
214
214
|
License-File: NOTICE
|
|
215
215
|
Requires-Dist: boto3>=1.39
|
|
216
216
|
Requires-Dist: fiona>=1.10
|
|
217
|
+
Requires-Dist: flask>=3.0.0
|
|
217
218
|
Requires-Dist: fsspec>=2025.10.0
|
|
218
219
|
Requires-Dist: jsonargparse>=4.35.0
|
|
219
220
|
Requires-Dist: lightning>=2.5.1.post0
|
|
@@ -482,6 +483,18 @@ We can visualize both the GeoTIFFs together in qgis:
|
|
|
482
483
|
qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
|
|
483
484
|
```
|
|
484
485
|
|
|
486
|
+
We can also visualize samples using the visualization module:
|
|
487
|
+
```
|
|
488
|
+
python -m rslearn.vis.vis_server \
|
|
489
|
+
$DATASET_PATH \
|
|
490
|
+
--layers sentinel2 \ # image modality layers
|
|
491
|
+
--label_layers label_raster \ # label layers
|
|
492
|
+
--bands '{"sentinel2": ["B04", "B03", "B02"]}' \ # specify bands wanted for each image modality
|
|
493
|
+
--normalization '{"sentinel2": "sentinel2_rgb"}' \ # specify normalization wanted for each image modality
|
|
494
|
+
--task_type segmentation \ # classification, regression, detection, or segmentation
|
|
495
|
+
--max_samples 100 \ # number of datapoints to randomly sample and visualize
|
|
496
|
+
--port 8000
|
|
497
|
+
```
|
|
485
498
|
|
|
486
499
|
### Training a Model
|
|
487
500
|
|