rslearn 0.0.24__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/train/dataset.py +44 -3
- rslearn/train/tasks/detection.py +1 -18
- rslearn/train/tasks/segmentation.py +1 -19
- rslearn/utils/colors.py +20 -0
- rslearn/vis/__init__.py +1 -0
- rslearn/vis/normalization.py +127 -0
- rslearn/vis/render_raster_label.py +96 -0
- rslearn/vis/render_sensor_image.py +27 -0
- rslearn/vis/render_vector_label.py +439 -0
- rslearn/vis/utils.py +99 -0
- rslearn/vis/vis_server.py +574 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/METADATA +14 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/RECORD +18 -10
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/WHEEL +1 -1
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.24.dist-info → rslearn-0.0.25.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
|
@@ -445,6 +445,7 @@ class SplitConfig:
|
|
|
445
445
|
overlap_ratio: float | None = None,
|
|
446
446
|
load_all_patches: bool | None = None,
|
|
447
447
|
skip_targets: bool | None = None,
|
|
448
|
+
output_layer_name_skip_inference_if_exists: str | None = None,
|
|
448
449
|
) -> None:
|
|
449
450
|
"""Initialize a new SplitConfig.
|
|
450
451
|
|
|
@@ -467,6 +468,10 @@ class SplitConfig:
|
|
|
467
468
|
for each window, read all patches as separate sequential items in the
|
|
468
469
|
dataset.
|
|
469
470
|
skip_targets: whether to skip targets when loading inputs
|
|
471
|
+
output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
|
|
472
|
+
If set, windows that already
|
|
473
|
+
have this layer completed will be skipped (useful for resuming
|
|
474
|
+
partial inference runs).
|
|
470
475
|
"""
|
|
471
476
|
self.groups = groups
|
|
472
477
|
self.names = names
|
|
@@ -477,6 +482,9 @@ class SplitConfig:
|
|
|
477
482
|
self.sampler = sampler
|
|
478
483
|
self.patch_size = patch_size
|
|
479
484
|
self.skip_targets = skip_targets
|
|
485
|
+
self.output_layer_name_skip_inference_if_exists = (
|
|
486
|
+
output_layer_name_skip_inference_if_exists
|
|
487
|
+
)
|
|
480
488
|
|
|
481
489
|
# Note that load_all_patches are handled by the RslearnDataModule rather than
|
|
482
490
|
# the ModelDataset.
|
|
@@ -504,6 +512,7 @@ class SplitConfig:
|
|
|
504
512
|
overlap_ratio=self.overlap_ratio,
|
|
505
513
|
load_all_patches=self.load_all_patches,
|
|
506
514
|
skip_targets=self.skip_targets,
|
|
515
|
+
output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
|
|
507
516
|
)
|
|
508
517
|
if other.groups:
|
|
509
518
|
result.groups = other.groups
|
|
@@ -527,6 +536,10 @@ class SplitConfig:
|
|
|
527
536
|
result.load_all_patches = other.load_all_patches
|
|
528
537
|
if other.skip_targets is not None:
|
|
529
538
|
result.skip_targets = other.skip_targets
|
|
539
|
+
if other.output_layer_name_skip_inference_if_exists is not None:
|
|
540
|
+
result.output_layer_name_skip_inference_if_exists = (
|
|
541
|
+
other.output_layer_name_skip_inference_if_exists
|
|
542
|
+
)
|
|
530
543
|
return result
|
|
531
544
|
|
|
532
545
|
def get_patch_size(self) -> tuple[int, int] | None:
|
|
@@ -549,16 +562,26 @@ class SplitConfig:
|
|
|
549
562
|
"""Returns whether skip_targets is enabled (default False)."""
|
|
550
563
|
return True if self.skip_targets is True else False
|
|
551
564
|
|
|
565
|
+
def get_output_layer_name_skip_inference_if_exists(self) -> str | None:
|
|
566
|
+
"""Returns output layer to use for resume checks (default None)."""
|
|
567
|
+
return self.output_layer_name_skip_inference_if_exists
|
|
568
|
+
|
|
552
569
|
|
|
553
|
-
def check_window(
|
|
570
|
+
def check_window(
|
|
571
|
+
inputs: dict[str, DataInput],
|
|
572
|
+
window: Window,
|
|
573
|
+
output_layer_name_skip_inference_if_exists: str | None = None,
|
|
574
|
+
) -> Window | None:
|
|
554
575
|
"""Verify that the window has the required layers based on the specified inputs.
|
|
555
576
|
|
|
556
577
|
Args:
|
|
557
578
|
inputs: the inputs to the dataset.
|
|
558
579
|
window: the window to check.
|
|
580
|
+
output_layer_name_skip_inference_if_exists: optional name of the output layer to check for existence.
|
|
559
581
|
|
|
560
582
|
Returns:
|
|
561
|
-
the window if it has all the required inputs
|
|
583
|
+
the window if it has all the required inputs and does not need to be skipped
|
|
584
|
+
due to an existing output layer; or None otherwise
|
|
562
585
|
"""
|
|
563
586
|
|
|
564
587
|
# Make sure window has all the needed layers.
|
|
@@ -588,6 +611,16 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
|
|
|
588
611
|
)
|
|
589
612
|
return None
|
|
590
613
|
|
|
614
|
+
# Optionally skip windows that already have the specified output layer completed.
|
|
615
|
+
if output_layer_name_skip_inference_if_exists is not None:
|
|
616
|
+
if window.is_layer_completed(output_layer_name_skip_inference_if_exists):
|
|
617
|
+
logger.debug(
|
|
618
|
+
"Skipping window %s since output layer '%s' already exists",
|
|
619
|
+
window.name,
|
|
620
|
+
output_layer_name_skip_inference_if_exists,
|
|
621
|
+
)
|
|
622
|
+
return None
|
|
623
|
+
|
|
591
624
|
return window
|
|
592
625
|
|
|
593
626
|
|
|
@@ -648,7 +681,14 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
648
681
|
new_windows = []
|
|
649
682
|
if workers == 0:
|
|
650
683
|
for window in windows:
|
|
651
|
-
if
|
|
684
|
+
if (
|
|
685
|
+
check_window(
|
|
686
|
+
self.inputs,
|
|
687
|
+
window,
|
|
688
|
+
output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
|
|
689
|
+
)
|
|
690
|
+
is None
|
|
691
|
+
):
|
|
652
692
|
continue
|
|
653
693
|
new_windows.append(window)
|
|
654
694
|
else:
|
|
@@ -660,6 +700,7 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
660
700
|
dict(
|
|
661
701
|
inputs=self.inputs,
|
|
662
702
|
window=window,
|
|
703
|
+
output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
|
|
663
704
|
)
|
|
664
705
|
for window in windows
|
|
665
706
|
],
|
rslearn/train/tasks/detection.py
CHANGED
|
@@ -14,27 +14,10 @@ from torchmetrics import Metric, MetricCollection
|
|
|
14
14
|
|
|
15
15
|
from rslearn.train.model_context import RasterImage, SampleMetadata
|
|
16
16
|
from rslearn.utils import Feature, STGeometry
|
|
17
|
+
from rslearn.utils.colors import DEFAULT_COLORS
|
|
17
18
|
|
|
18
19
|
from .task import BasicTask
|
|
19
20
|
|
|
20
|
-
DEFAULT_COLORS = [
|
|
21
|
-
(255, 0, 0),
|
|
22
|
-
(0, 255, 0),
|
|
23
|
-
(0, 0, 255),
|
|
24
|
-
(255, 255, 0),
|
|
25
|
-
(0, 255, 255),
|
|
26
|
-
(255, 0, 255),
|
|
27
|
-
(0, 128, 0),
|
|
28
|
-
(255, 160, 122),
|
|
29
|
-
(139, 69, 19),
|
|
30
|
-
(128, 128, 128),
|
|
31
|
-
(255, 255, 255),
|
|
32
|
-
(143, 188, 143),
|
|
33
|
-
(95, 158, 160),
|
|
34
|
-
(255, 200, 0),
|
|
35
|
-
(128, 0, 0),
|
|
36
|
-
]
|
|
37
|
-
|
|
38
21
|
|
|
39
22
|
class DetectionTask(BasicTask):
|
|
40
23
|
"""A point or bounding box detection task."""
|
|
rslearn/train/tasks/segmentation.py
CHANGED
|
@@ -17,28 +17,10 @@ from rslearn.train.model_context import (
|
|
|
17
17
|
SampleMetadata,
|
|
18
18
|
)
|
|
19
19
|
from rslearn.utils import Feature
|
|
20
|
+
from rslearn.utils.colors import DEFAULT_COLORS
|
|
20
21
|
|
|
21
22
|
from .task import BasicTask
|
|
22
23
|
|
|
23
|
-
# TODO: This is duplicated code fix it
|
|
24
|
-
DEFAULT_COLORS = [
|
|
25
|
-
(255, 0, 0),
|
|
26
|
-
(0, 255, 0),
|
|
27
|
-
(0, 0, 255),
|
|
28
|
-
(255, 255, 0),
|
|
29
|
-
(0, 255, 255),
|
|
30
|
-
(255, 0, 255),
|
|
31
|
-
(0, 128, 0),
|
|
32
|
-
(255, 160, 122),
|
|
33
|
-
(139, 69, 19),
|
|
34
|
-
(128, 128, 128),
|
|
35
|
-
(255, 255, 255),
|
|
36
|
-
(143, 188, 143),
|
|
37
|
-
(95, 158, 160),
|
|
38
|
-
(255, 200, 0),
|
|
39
|
-
(128, 0, 0),
|
|
40
|
-
]
|
|
41
|
-
|
|
42
24
|
|
|
43
25
|
class SegmentationTask(BasicTask):
|
|
44
26
|
"""A segmentation (per-pixel classification) task."""
|
rslearn/utils/colors.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Default color palette for visualizations."""
|
|
2
|
+
|
|
3
|
+
# Shared palette of visually distinct RGB colors, looked up by index
# (DEFAULT_COLORS[i] is the color for index i). Index 0 is black.
# NOTE(review): the copies of this palette that previously lived in
# detection.py / segmentation.py started at red; prepending black shifts
# every index's color by one position -- confirm this is intended.
DEFAULT_COLORS: list[tuple[int, int, int]] = [
    # black, red, green, blue
    (0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    # yellow, cyan, magenta, dark green
    (255, 255, 0), (0, 255, 255), (255, 0, 255), (0, 128, 0),
    # light salmon, saddle brown, gray, white
    (255, 160, 122), (139, 69, 19), (128, 128, 128), (255, 255, 255),
    # dark sea green, cadet blue, amber, maroon
    (143, 188, 143), (95, 158, 160), (255, 200, 0), (128, 0, 0),
]
|
rslearn/vis/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Visualization module for rslearn datasets."""
|
|
rslearn/vis/normalization.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Normalization functions for raster data visualization."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from enum import StrEnum
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from rslearn.log_utils import get_logger
|
|
9
|
+
|
|
10
|
+
# Module-level logger; no function in this module currently logs through it.
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NormalizationMethod(StrEnum):
|
|
14
|
+
"""Normalization methods for raster data visualization."""
|
|
15
|
+
|
|
16
|
+
SENTINEL2_RGB = "sentinel2_rgb"
|
|
17
|
+
"""Divide by 10 and clip (for Sentinel-2 B04/B03/B02 bands)."""
|
|
18
|
+
|
|
19
|
+
PERCENTILE = "percentile"
|
|
20
|
+
"""Use 2-98 percentile clipping."""
|
|
21
|
+
|
|
22
|
+
MINMAX = "minmax"
|
|
23
|
+
"""Use min-max stretch."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _normalize_sentinel2_rgb(band: np.ndarray) -> np.ndarray:
|
|
27
|
+
"""Normalize band using Sentinel-2 RGB method (divide by 10 and clip).
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
band: Input band data
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Normalized band as uint8 array
|
|
34
|
+
"""
|
|
35
|
+
band = band / 10.0
|
|
36
|
+
band = np.clip(band, 0, 255).astype(np.uint8)
|
|
37
|
+
return band
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _normalize_percentile(band: np.ndarray) -> np.ndarray:
|
|
41
|
+
"""Normalize band using 2-98 percentile clipping.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
band: Input band data
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
Normalized band as uint8 array
|
|
48
|
+
"""
|
|
49
|
+
valid_pixels = band[~np.isnan(band)]
|
|
50
|
+
if len(valid_pixels) == 0:
|
|
51
|
+
return np.zeros_like(band, dtype=np.uint8)
|
|
52
|
+
vmin, vmax = np.nanpercentile(valid_pixels, (2, 98))
|
|
53
|
+
if vmax == vmin:
|
|
54
|
+
return np.zeros_like(band, dtype=np.uint8)
|
|
55
|
+
band = np.clip(band, vmin, vmax)
|
|
56
|
+
band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
|
|
57
|
+
return band
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _normalize_minmax(band: np.ndarray) -> np.ndarray:
|
|
61
|
+
"""Normalize band using min-max stretch.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
band: Input band data
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
Normalized band as uint8 array
|
|
68
|
+
"""
|
|
69
|
+
vmin, vmax = np.nanmin(band), np.nanmax(band)
|
|
70
|
+
if vmax == vmin:
|
|
71
|
+
return np.zeros_like(band, dtype=np.uint8)
|
|
72
|
+
band = np.clip(band, vmin, vmax)
|
|
73
|
+
band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
|
|
74
|
+
return band
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Signature shared by every per-band normalizer: band in, uint8 band out.
_BandNormalizer = Callable[[np.ndarray], np.ndarray]

# Dispatch table used by normalize_band; a NormalizationMethod member with no
# entry here is rejected there with a ValueError.
_NORMALIZATION_FUNCTIONS: dict[NormalizationMethod, _BandNormalizer] = {
    NormalizationMethod.SENTINEL2_RGB: _normalize_sentinel2_rgb,
    NormalizationMethod.PERCENTILE: _normalize_percentile,
    NormalizationMethod.MINMAX: _normalize_minmax,
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def normalize_band(
    band: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
) -> np.ndarray:
    """Normalize a single band to the 0-255 range.

    Args:
        band: Input band data.
        method: Normalization method, as a NormalizationMethod member or its
            string value:
            - 'sentinel2_rgb': Divide by 10 and clip (for B04/B03/B02)
            - 'percentile': Use 2-98 percentile clipping
            - 'minmax': Use min-max stretch

    Returns:
        Normalized band as uint8 array.

    Raises:
        ValueError: if ``method`` is not a known normalization method.
    """
    try:
        # NormalizationMethod is a StrEnum: plain strings are looked up by value
        # and existing members map to themselves.
        method_enum = NormalizationMethod(method)
    except ValueError as err:
        # Previously the enum constructor's terse error surfaced here and the
        # branch below was unreachable for strings; list the valid options.
        valid = ", ".join(m.value for m in NormalizationMethod)
        raise ValueError(
            f"Unknown normalization method: {method!r} (expected one of: {valid})"
        ) from err

    normalize_func = _NORMALIZATION_FUNCTIONS.get(method_enum)
    if normalize_func is None:
        # Defensive: an enum member exists without a registered function.
        raise ValueError(f"Unknown normalization method: {method_enum}")
    return normalize_func(band)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def normalize_array(
    array: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
) -> np.ndarray:
    """Normalize a multi-band array to 0-255 range, band by band.

    Args:
        array: Input array with shape (channels, height, width), e.g. from
            RasterFormat.decode_raster. A 2-D (height, width) array is treated
            as a single band.
        method: Normalization method (applied per-band, string or
            NormalizationMethod enum).

    Returns:
        Normalized uint8 array with shape (height, width, channels); for a 2-D
        input, shape (height, width).

    Raises:
        ValueError: if ``array`` is neither 2-D nor 3-D; ``normalize_band`` may
            also raise ValueError for an unknown ``method``.
    """
    if array.ndim == 2:
        # A bare band must be normalized as a whole. Iterating over its last
        # axis would stretch every pixel column independently under the
        # percentile/minmax methods.
        return normalize_band(array, method)
    if array.ndim != 3:
        raise ValueError(
            "expected a (channels, height, width) or (height, width) array, "
            f"got shape {array.shape}"
        )

    hwc = np.moveaxis(array, 0, -1)
    normalized = np.zeros_like(hwc, dtype=np.uint8)
    for i in range(hwc.shape[-1]):
        normalized[..., i] = normalize_band(hwc[..., i], method)
    return normalized
|
|
rslearn/vis/render_raster_label.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Functions for rendering raster label masks (e.g., segmentation masks)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from rasterio.warp import Resampling
|
|
6
|
+
|
|
7
|
+
from rslearn.config import DType, LayerConfig
|
|
8
|
+
from rslearn.dataset import Window
|
|
9
|
+
from rslearn.log_utils import get_logger
|
|
10
|
+
from rslearn.train.dataset import DataInput, read_raster_layer_for_data_input
|
|
11
|
+
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
12
|
+
|
|
13
|
+
# Module-level logger; no function in this module currently logs through it.
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def render_raster_label(
    label_array: np.ndarray,
    label_colors: dict[str, tuple[int, int, int]],
    layer_config: LayerConfig,
) -> np.ndarray:
    """Render a raster label array as a colored RGB mask.

    A pixel with value ``i`` is drawn with
    ``label_colors[str(layer_config.class_names[i])]``, or black when that class
    has no color. Pixels that are NaN/non-finite, negative, or >=
    ``len(class_names)`` stay black. Float labels are truncated toward zero
    before the lookup (e.g. 1.7 -> class 1).

    Args:
        label_array: Raster label array with shape (bands, height, width) or
            (height, width); only the first band of a 3-D array is used.
        label_colors: Dictionary mapping label class names to RGB color tuples.
        layer_config: LayerConfig providing ``class_names``.

    Returns:
        Array with shape (height, width, 3) as uint8.

    Raises:
        ValueError: if ``layer_config.class_names`` is empty or unset, or the
            selected label band is not 2-D.
    """
    label_values = label_array[0] if label_array.ndim == 3 else label_array
    height, width = label_values.shape

    class_names = layer_config.class_names
    if not class_names:
        raise ValueError(
            "class_names must be specified in config for raster label layer"
        )

    # Mask out non-finite pixels and replace them with a -1 sentinel so that the
    # int cast never sees NaN/inf; that conversion is undefined and warns.
    valid = np.isfinite(label_values)
    label_int = np.where(valid, label_values, -1).astype(np.int32)

    # Color lookup table indexed by class id. One vectorized gather replaces
    # a full-image comparison per class.
    palette = np.array(
        [label_colors.get(str(name), (0, 0, 0)) for name in class_names],
        dtype=np.uint8,
    )
    in_range = (label_int >= 0) & (label_int < len(class_names))

    mask_img = np.zeros((height, width, 3), dtype=np.uint8)
    mask_img[in_range] = palette[label_int[in_range]]
    return mask_img
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def read_raster_layer(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    band_names: list[str],
    group_idx: int = 0,
    bounds: PixelBounds | None = None,
) -> np.ndarray:
    """Load the requested bands of one raster layer item group for display.

    The bands may come from several band sets of the layer. The actual read is
    delegated to ``read_raster_layer_for_data_input`` (rslearn.train.dataset),
    driven by an ad-hoc float32 ``DataInput`` with nearest-neighbour resampling.

    Args:
        window: The window to read from.
        layer_name: The raster layer name.
        layer_config: The layer configuration.
        band_names: Band names to read (e.g. ["B04", "B03", "B02"]).
        group_idx: The item group index (default 0).
        bounds: Pixel bounds to read; ``window.bounds`` when None.

    Returns:
        Array with shape (bands, height, width) as float32.
    """
    read_bounds = window.bounds if bounds is None else bounds

    spec = DataInput(
        data_type="raster",
        layers=[layer_name],
        bands=band_names,
        dtype=DType.FLOAT32,
        resolution_factor=ResolutionFactor(),  # Default 1/1, no scaling
        resampling=Resampling.nearest,
    )
    tensor = read_raster_layer_for_data_input(
        window, read_bounds, layer_name, group_idx, layer_config, spec
    )
    return tensor.numpy().astype(np.float32)
|
|
rslearn/vis/render_sensor_image.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Functions for rendering raster sensor images (e.g., Sentinel-2, Landsat)."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .normalization import normalize_array
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def render_sensor_image(
    array: np.ndarray,
    normalization_method: str,
) -> np.ndarray:
    """Render a raster sensor image array as a displayable uint8 image.

    Only the first three bands are kept (used as RGB). Each one is normalized
    independently with ``normalization_method``.

    Args:
        array: Array with shape (channels, height, width) from
            RasterFormat.decode_raster. A 2-D (height, width) array is treated
            as a single band.
        normalization_method: Normalization method to apply.

    Returns:
        Array with shape (height, width, min(channels, 3)) as uint8.
    """
    if array.ndim == 2:
        # Promote a bare band to (1, height, width) so the channel slice below
        # selects bands. Previously a 2-D input's width was mistaken for the
        # channel count, which raised IndexError once width > 3.
        array = array[np.newaxis, ...]
    # Drop surplus bands before normalizing. They were discarded afterwards
    # anyway, and per-band normalization keeps the first three unaffected.
    return normalize_array(array[:3], normalization_method)
|