rslearn 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@
 import os
 import tempfile
 import xml.etree.ElementTree as ET
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Any
 
 import affine
@@ -12,6 +12,7 @@ import planetary_computer
 import rasterio
 import requests
 from rasterio.enums import Resampling
+from typing_extensions import override
 from upath import UPath
 
 from rslearn.config import LayerConfig
@@ -24,11 +25,104 @@ from rslearn.tile_stores import TileStore, TileStoreWithLayer
 from rslearn.utils.fsspec import join_upath
 from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
 from rslearn.utils.raster_format import get_raster_projection_and_bounds
+from rslearn.utils.stac import StacClient, StacItem
 
 from .copernicus import get_harmonize_callback
 
 logger = get_logger(__name__)
 
+# Max limit accepted by Planetary Computer API.
+PLANETARY_COMPUTER_LIMIT = 1000
+
+
+class PlanetaryComputerStacClient(StacClient):
+    """A StacClient subclass that handles Planetary Computer's pagination limits.
+
+    Planetary Computer STAC API does not support standard pagination and has a max
+    limit of 1000. If the initial query returns 1000 items, this client paginates
+    by sorting by ID and using gt (greater than) queries to fetch subsequent pages.
+    """
+
+    @override
+    def search(
+        self,
+        collections: list[str] | None = None,
+        bbox: tuple[float, float, float, float] | None = None,
+        intersects: dict[str, Any] | None = None,
+        date_time: datetime | tuple[datetime, datetime] | None = None,
+        ids: list[str] | None = None,
+        limit: int | None = None,
+        query: dict[str, Any] | None = None,
+        sortby: list[dict[str, str]] | None = None,
+    ) -> list[StacItem]:
+        # We will use sortby for pagination, so the caller must not set it.
+        if sortby is not None:
+            raise ValueError("sortby must not be set for PlanetaryComputerStacClient")
+
+        # First, try a simple query with the PC limit to detect if pagination is needed.
+        # We always use PLANETARY_COMPUTER_LIMIT for the request because PC doesn't
+        # support standard pagination, and we need to detect when we hit the limit
+        # to switch to ID-based pagination.
+        # We could just start sorting by ID here and do pagination, but we treat it as
+        # a special case to avoid sorting since that seems to speed up the query.
+        stac_items = super().search(
+            collections=collections,
+            bbox=bbox,
+            intersects=intersects,
+            date_time=date_time,
+            ids=ids,
+            limit=PLANETARY_COMPUTER_LIMIT,
+            query=query,
+        )
+
+        # If we got fewer than the PC limit, we have all the results.
+        if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+            return stac_items
+
+        # We hit the limit, so we need to paginate by ID.
+        # Re-fetch with sorting by ID to ensure consistent ordering for pagination.
+        logger.debug(
+            "Initial request returned %d items (at limit), switching to ID pagination",
+            len(stac_items),
+        )
+
+        all_items: list[StacItem] = []
+        last_id: str | None = None
+
+        while True:
+            # Build query with id > last_id if we're paginating.
+            combined_query: dict[str, Any] = dict(query) if query else {}
+            if last_id is not None:
+                combined_query["id"] = {"gt": last_id}
+
+            stac_items = super().search(
+                collections=collections,
+                bbox=bbox,
+                intersects=intersects,
+                date_time=date_time,
+                ids=ids,
+                limit=PLANETARY_COMPUTER_LIMIT,
+                query=combined_query if combined_query else None,
+                sortby=[{"field": "id", "direction": "asc"}],
+            )
+
+            all_items.extend(stac_items)
+
+            # If we got fewer than the limit, we've fetched everything.
+            if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+                break
+
+            # Otherwise, paginate using the last item's ID.
+            last_id = stac_items[-1].id
+            logger.debug(
+                "Got %d items, paginating with id > %s",
+                len(stac_items),
+                last_id,
+            )
+
+        logger.debug("Total items fetched: %d", len(all_items))
+        return all_items
+
 
 class PlanetaryComputer(StacDataSource, TileStore):
     """Modality-agnostic data source for data on Microsoft Planetary Computer.
@@ -100,6 +194,10 @@ class PlanetaryComputer(StacDataSource, TileStore):
             required_assets=required_assets,
             cache_dir=cache_upath,
         )
+
+        # Replace the client with PlanetaryComputerStacClient to handle PC's pagination limits.
+        self.client = PlanetaryComputerStacClient(self.STAC_ENDPOINT)
+
         self.asset_bands = asset_bands
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
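Illustrative usage (a sketch, not part of the package diff): the module path, endpoint URL, and collection name below are assumptions, but the call mirrors how the data source now drives the paginating client, and any query matching more than 1000 items is transparently re-issued as ID-sorted pages.

    # Module path inferred; this hunk does not show the file name.
    from rslearn.data_sources.planetary_computer import PlanetaryComputerStacClient

    # PlanetaryComputer passes self.STAC_ENDPOINT; the public endpoint is assumed here.
    client = PlanetaryComputerStacClient(
        "https://planetarycomputer.microsoft.com/api/stac/v1"
    )
    items = client.search(
        collections=["sentinel-2-l2a"],  # example collection
        bbox=(-122.5, 47.5, -122.0, 48.0),
    )
    print(f"{len(items)} items (paginated past the 1000-item cap if needed)")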
@@ -12,6 +12,7 @@ from rslearn.const import WGS84_PROJECTION
 from rslearn.data_sources.data_source import Item, ItemLookupDataSource
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
 from rslearn.utils.geometry import STGeometry
 from rslearn.utils.stac import StacClient, StacItem
 
@@ -187,7 +188,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
 
         # Finally we cache it if cache_dir is set.
         if cache_fname is not None:
-            with cache_fname.open("w") as f:
+            with open_atomic(cache_fname, "w") as f:
                 json.dump(item.serialize(), f)
 
         return item
@@ -259,7 +260,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
                 cache_fname = self.cache_dir / f"{item.name}.json"
                 if cache_fname.exists():
                     continue
-                with cache_fname.open("w") as f:
+                with open_atomic(cache_fname, "w") as f:
                     json.dump(item.serialize(), f)
 
             cur_groups = match_candidate_items_to_window(
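A small sketch of the pattern the change adopts (illustrative, not from the diff): open_atomic is the rslearn.utils.fsspec helper imported above, and the cache path here is hypothetical. The point of the switch is that a concurrently read cache file is either absent or complete, never truncated.

    import json

    from upath import UPath

    from rslearn.utils.fsspec import open_atomic

    cache_fname = UPath("/tmp/stac_cache/item-123.json")  # hypothetical cache entry
    cache_fname.parent.mkdir(parents=True, exist_ok=True)
    # Unlike cache_fname.open("w"), open_atomic only exposes the file at its final
    # path once the write has finished.
    with open_atomic(cache_fname, "w") as f:
        json.dump({"name": "item-123"}, f)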
@@ -180,7 +180,7 @@ class SimpleTimeSeries(FeatureExtractor):
         # want to pass 2 timesteps to the model.
         # TODO is probably to make this behaviour clearer but lets leave it like
         # this for now to not break things.
-        num_timesteps = images.shape[1] // image_channels
+        num_timesteps = image_channels // images.shape[1]
         batched_timesteps = images.shape[2] // num_timesteps
         images = rearrange(
             images,
rslearn/train/dataset.py CHANGED
@@ -445,6 +445,7 @@ class SplitConfig:
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
         skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.
 
@@ -467,6 +468,10 @@
                 for each window, read all patches as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
+            output_layer_name_skip_inference_if_exists: optional name of the output
+                layer used during prediction. If set, windows that already have this
+                layer completed will be skipped (useful for resuming partial
+                inference runs).
         """
         self.groups = groups
         self.names = names
@@ -477,6 +482,9 @@
         self.sampler = sampler
         self.patch_size = patch_size
         self.skip_targets = skip_targets
+        self.output_layer_name_skip_inference_if_exists = (
+            output_layer_name_skip_inference_if_exists
+        )
 
         # Note that load_all_patches are handled by the RslearnDataModule rather than
         # the ModelDataset.
@@ -504,6 +512,7 @@
             overlap_ratio=self.overlap_ratio,
             load_all_patches=self.load_all_patches,
             skip_targets=self.skip_targets,
+            output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
         if other.groups:
             result.groups = other.groups
@@ -527,6 +536,10 @@
             result.load_all_patches = other.load_all_patches
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
+        if other.output_layer_name_skip_inference_if_exists is not None:
+            result.output_layer_name_skip_inference_if_exists = (
+                other.output_layer_name_skip_inference_if_exists
+            )
         return result
 
     def get_patch_size(self) -> tuple[int, int] | None:
@@ -549,16 +562,26 @@
         """Returns whether skip_targets is enabled (default False)."""
         return True if self.skip_targets is True else False
 
+    def get_output_layer_name_skip_inference_if_exists(self) -> str | None:
+        """Returns output layer to use for resume checks (default None)."""
+        return self.output_layer_name_skip_inference_if_exists
+
 
-def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
+def check_window(
+    inputs: dict[str, DataInput],
+    window: Window,
+    output_layer_name_skip_inference_if_exists: str | None = None,
+) -> Window | None:
     """Verify that the window has the required layers based on the specified inputs.
 
     Args:
         inputs: the inputs to the dataset.
         window: the window to check.
+        output_layer_name_skip_inference_if_exists: optional name of the output layer to check for existence.
 
     Returns:
-        the window if it has all the required inputs or None otherwise
+        the window if it has all the required inputs and does not need to be skipped
+        due to an existing output layer; or None otherwise
     """
 
     # Make sure window has all the needed layers.
@@ -588,6 +611,16 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
         )
         return None
 
+    # Optionally skip windows that already have the specified output layer completed.
+    if output_layer_name_skip_inference_if_exists is not None:
+        if window.is_layer_completed(output_layer_name_skip_inference_if_exists):
+            logger.debug(
+                "Skipping window %s since output layer '%s' already exists",
+                window.name,
+                output_layer_name_skip_inference_if_exists,
+            )
+            return None
+
     return window
 
 
@@ -648,7 +681,14 @@ class ModelDataset(torch.utils.data.Dataset):
         new_windows = []
         if workers == 0:
             for window in windows:
-                if check_window(self.inputs, window) is None:
+                if (
+                    check_window(
+                        self.inputs,
+                        window,
+                        output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
+                    )
+                    is None
+                ):
                     continue
                 new_windows.append(window)
         else:
@@ -660,6 +700,7 @@
                 dict(
                     inputs=self.inputs,
                     window=window,
+                    output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
                 )
                 for window in windows
             ],
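For context, an illustrative predict-split configuration (not from the diff; the group and layer names are made up) showing how the new option feeds check_window:

    from rslearn.train.dataset import SplitConfig

    # Windows whose "output" layer is already completed are filtered out of the
    # predict split, so a partially finished inference run can be resumed.
    predict_config = SplitConfig(
        groups=["predict"],  # hypothetical group name
        skip_targets=True,
        output_layer_name_skip_inference_if_exists="output",  # hypothetical layer name
    )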
@@ -14,27 +14,10 @@ from torchmetrics import Metric, MetricCollection
 
 from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature, STGeometry
+from rslearn.utils.colors import DEFAULT_COLORS
 
 from .task import BasicTask
 
-DEFAULT_COLORS = [
-    (255, 0, 0),
-    (0, 255, 0),
-    (0, 0, 255),
-    (255, 255, 0),
-    (0, 255, 255),
-    (255, 0, 255),
-    (0, 128, 0),
-    (255, 160, 122),
-    (139, 69, 19),
-    (128, 128, 128),
-    (255, 255, 255),
-    (143, 188, 143),
-    (95, 158, 160),
-    (255, 200, 0),
-    (128, 0, 0),
-]
-
 
 class DetectionTask(BasicTask):
     """A point or bounding box detection task."""
@@ -17,28 +17,10 @@ from rslearn.train.model_context import (
     SampleMetadata,
 )
 from rslearn.utils import Feature
+from rslearn.utils.colors import DEFAULT_COLORS
 
 from .task import BasicTask
 
-# TODO: This is duplicated code fix it
-DEFAULT_COLORS = [
-    (255, 0, 0),
-    (0, 255, 0),
-    (0, 0, 255),
-    (255, 255, 0),
-    (0, 255, 255),
-    (255, 0, 255),
-    (0, 128, 0),
-    (255, 160, 122),
-    (139, 69, 19),
-    (128, 128, 128),
-    (255, 255, 255),
-    (143, 188, 143),
-    (95, 158, 160),
-    (255, 200, 0),
-    (128, 0, 0),
-]
-
 
 class SegmentationTask(BasicTask):
     """A segmentation (per-pixel classification) task."""
@@ -59,6 +41,8 @@ class SegmentationTask(BasicTask):
         miou_metric_kwargs: dict[str, Any] = {},
         prob_scales: list[float] | None = None,
         other_metrics: dict[str, Metric] = {},
+        output_probs: bool = False,
+        output_class_idx: int | None = None,
         **kwargs: Any,
     ) -> None:
         """Initialize a new SegmentationTask.
@@ -92,6 +76,10 @@ class SegmentationTask(BasicTask):
                 this is only applied during prediction, not when computing val or test
                 metrics.
             other_metrics: additional metrics to configure on this task.
+            output_probs: if True, output raw softmax probabilities instead of class IDs
+                during prediction.
+            output_class_idx: if set along with output_probs, only output the probability
+                for this specific class index (single-channel output).
             kwargs: additional arguments to pass to BasicTask
         """
         super().__init__(**kwargs)
@@ -116,6 +104,8 @@ class SegmentationTask(BasicTask):
         self.miou_metric_kwargs = miou_metric_kwargs
         self.prob_scales = prob_scales
        self.other_metrics = other_metrics
+        self.output_probs = output_probs
+        self.output_class_idx = output_class_idx
 
     def process_inputs(
         self,
@@ -171,7 +161,9 @@ class SegmentationTask(BasicTask):
             metadata: metadata about the patch being read
 
         Returns:
-            CHW numpy array with one channel, containing the predicted class IDs.
+            CHW numpy array. If output_probs is False, returns one channel with
+            predicted class IDs. If output_probs is True, returns softmax probabilities
+            (num_classes channels, or 1 channel if output_class_idx is set).
         """
         if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
             raise ValueError("the output for SegmentationTask must be a CHW tensor")
@@ -183,6 +175,15 @@ class SegmentationTask(BasicTask):
                     self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
                 )[:, None, None]
             )
+
+        if self.output_probs:
+            # Return raw softmax probabilities
+            probs = raw_output.cpu().numpy()
+            if self.output_class_idx is not None:
+                # Return only the specified class probability
+                return probs[self.output_class_idx : self.output_class_idx + 1, :, :]
+            return probs
+
         classes = raw_output.argmax(dim=0).cpu().numpy()
         return classes[None, :, :]
 
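One detail worth calling out (illustrative snippet, not from the diff): the output_class_idx slice uses idx:idx+1 rather than plain indexing so the result keeps a channel axis and stays CHW, matching what process_output documents above.

    import torch

    probs = torch.rand(3, 4, 4).cpu().numpy()  # fake (num_classes, H, W) probabilities
    class_idx = 1
    single = probs[class_idx : class_idx + 1, :, :]
    assert single.shape == (1, 4, 4)  # channel axis preserved
    assert probs[class_idx].shape == (4, 4)  # plain indexing would drop it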
@@ -0,0 +1,20 @@
+"""Default color palette for visualizations."""
+
+DEFAULT_COLORS = [
+    (0, 0, 0),
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (0, 255, 255),
+    (255, 0, 255),
+    (0, 128, 0),
+    (255, 160, 122),
+    (139, 69, 19),
+    (128, 128, 128),
+    (255, 255, 255),
+    (143, 188, 143),
+    (95, 158, 160),
+    (255, 200, 0),
+    (128, 0, 0),
+]
@@ -476,6 +476,7 @@ class GeotiffRasterFormat(RasterFormat):
         bounds: PixelBounds,
         array: npt.NDArray[Any],
         fname: str | None = None,
+        nodata_val: int | float | None = None,
     ) -> None:
         """Encodes raster data.
 
@@ -485,6 +486,7 @@ class GeotiffRasterFormat(RasterFormat):
             bounds: the bounds of the raster data in the projection
             array: the raster data
             fname: override the filename to save as
+            nodata_val: set the nodata value when writing the raster.
         """
         if fname is None:
             fname = self.fname
@@ -520,6 +522,9 @@ class GeotiffRasterFormat(RasterFormat):
             profile["tiled"] = True
             profile["blockxsize"] = self.block_size
             profile["blockysize"] = self.block_size
+        # Set nodata_val if provided.
+        if nodata_val is not None:
+            profile["nodata"] = nodata_val
 
         profile.update(self.geotiff_options)
 
@@ -535,6 +540,7 @@ class GeotiffRasterFormat(RasterFormat):
         bounds: PixelBounds,
         resampling: Resampling = Resampling.bilinear,
         fname: str | None = None,
+        nodata_val: int | float | None = None,
     ) -> npt.NDArray[Any]:
         """Decodes raster data.
 
@@ -544,6 +550,16 @@ class GeotiffRasterFormat(RasterFormat):
             bounds: the bounds to read in the given projection.
             resampling: resampling method to use in case resampling is needed.
             fname: override the filename to read from
+            nodata_val: override the nodata value in the raster when reading. Pixels in
+                bounds that are not present in the source raster will be initialized to
+                this value. Note that, if the raster specifies a nodata value, and
+                some source pixels have that value, they will still be read under their
+                original value; overriding the nodata value is primarily useful if the
+                user wants out of bounds pixels to have a different value from the
+                source pixels, e.g. if the source data has background and foreground
+                classes (with background being nodata) but we want to read it in a
+                different projection and have out of bounds pixels be a third "invalid"
+                value.
 
         Returns:
             the raster data
@@ -561,6 +577,7 @@ class GeotiffRasterFormat(RasterFormat):
                 width=bounds[2] - bounds[0],
                 height=bounds[3] - bounds[1],
                 resampling=resampling,
+                src_nodata=nodata_val,
             ) as vrt:
                 return vrt.read()
 
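To make the write-side change concrete, a standalone sketch using rasterio directly (not the GeotiffRasterFormat API; the file path and georeferencing values are made up): setting profile["nodata"] is what records the nodata value in the written GeoTIFF.

    import numpy as np
    import rasterio
    from rasterio.transform import from_origin

    profile = {
        "driver": "GTiff",
        "dtype": "uint8",
        "count": 1,
        "width": 64,
        "height": 64,
        "crs": "EPSG:4326",
        "transform": from_origin(0, 0, 1, 1),
    }
    nodata_val = 255
    if nodata_val is not None:
        profile["nodata"] = nodata_val  # same branch encode_raster now takes

    with rasterio.open("/tmp/example.tif", "w", **profile) as dst:
        dst.write(np.zeros((1, 64, 64), dtype=np.uint8))

    with rasterio.open("/tmp/example.tif") as src:
        print(src.nodata)  # 255.0

On the read side, forwarding nodata_val to WarpedVRT's src_nodata is what lets pixels of the requested bounds that fall outside the source raster come back as that value, per the docstring above.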
rslearn/utils/stac.py CHANGED
@@ -101,6 +101,7 @@ class StacClient:
         ids: list[str] | None = None,
         limit: int | None = None,
         query: dict[str, Any] | None = None,
+        sortby: list[dict[str, str]] | None = None,
     ) -> list[StacItem]:
         """Execute a STAC item search.
 
@@ -117,6 +118,7 @@ class StacClient:
             limit: number of items per page. We will read all the pages.
             query: query dict, if STAC query extension is supported by this API. See
                 https://github.com/stac-api-extensions/query.
+            sortby: list of sort specifications, e.g. [{"field": "id", "direction": "asc"}].
 
         Returns:
             list of matching STAC items.
@@ -142,6 +144,8 @@ class StacClient:
             request_data["limit"] = limit
         if query is not None:
             request_data["query"] = query
+        if sortby is not None:
+            request_data["sortby"] = sortby
 
         # Handle pagination.
         cur_url = self.endpoint + "/search"
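Illustrative only (the endpoint and collection are examples, not taken from the diff): sortby is forwarded verbatim in the POST /search body, using the field/direction objects of the STAC API sort extension, which is also what the Planetary Computer client above relies on for ID pagination.

    from rslearn.utils.stac import StacClient

    client = StacClient("https://earth-search.aws.element84.com/v1")  # example endpoint
    items = client.search(
        collections=["sentinel-2-l2a"],
        limit=100,
        sortby=[{"field": "properties.datetime", "direction": "desc"}],
    )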
@@ -0,0 +1 @@
+"""Visualization module for rslearn datasets."""
@@ -0,0 +1,127 @@
+"""Normalization functions for raster data visualization."""
+
+from collections.abc import Callable
+from enum import StrEnum
+
+import numpy as np
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+class NormalizationMethod(StrEnum):
+    """Normalization methods for raster data visualization."""
+
+    SENTINEL2_RGB = "sentinel2_rgb"
+    """Divide by 10 and clip (for Sentinel-2 B04/B03/B02 bands)."""
+
+    PERCENTILE = "percentile"
+    """Use 2-98 percentile clipping."""
+
+    MINMAX = "minmax"
+    """Use min-max stretch."""
+
+
+def _normalize_sentinel2_rgb(band: np.ndarray) -> np.ndarray:
+    """Normalize band using Sentinel-2 RGB method (divide by 10 and clip).
+
+    Args:
+        band: Input band data
+
+    Returns:
+        Normalized band as uint8 array
+    """
+    band = band / 10.0
+    band = np.clip(band, 0, 255).astype(np.uint8)
+    return band
+
+
+def _normalize_percentile(band: np.ndarray) -> np.ndarray:
+    """Normalize band using 2-98 percentile clipping.
+
+    Args:
+        band: Input band data
+
+    Returns:
+        Normalized band as uint8 array
+    """
+    valid_pixels = band[~np.isnan(band)]
+    if len(valid_pixels) == 0:
+        return np.zeros_like(band, dtype=np.uint8)
+    vmin, vmax = np.nanpercentile(valid_pixels, (2, 98))
+    if vmax == vmin:
+        return np.zeros_like(band, dtype=np.uint8)
+    band = np.clip(band, vmin, vmax)
+    band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
+    return band
+
+
+def _normalize_minmax(band: np.ndarray) -> np.ndarray:
+    """Normalize band using min-max stretch.
+
+    Args:
+        band: Input band data
+
+    Returns:
+        Normalized band as uint8 array
+    """
+    vmin, vmax = np.nanmin(band), np.nanmax(band)
+    if vmax == vmin:
+        return np.zeros_like(band, dtype=np.uint8)
+    band = np.clip(band, vmin, vmax)
+    band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
+    return band
+
+
+_NORMALIZATION_FUNCTIONS: dict[
+    NormalizationMethod, Callable[[np.ndarray], np.ndarray]
+] = {
+    NormalizationMethod.SENTINEL2_RGB: _normalize_sentinel2_rgb,
+    NormalizationMethod.PERCENTILE: _normalize_percentile,
+    NormalizationMethod.MINMAX: _normalize_minmax,
+}
+
+
+def normalize_band(
+    band: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
+) -> np.ndarray:
+    """Normalize band to 0-255 range.
+
+    Args:
+        band: Input band data
+        method: Normalization method (string or NormalizationMethod enum)
+            - 'sentinel2_rgb': Divide by 10 and clip (for B04/B03/B02)
+            - 'percentile': Use 2-98 percentile clipping
+            - 'minmax': Use min-max stretch
+
+    Returns:
+        Normalized band as uint8 array
+    """
+    method_enum = NormalizationMethod(method) if isinstance(method, str) else method
+    normalize_func = _NORMALIZATION_FUNCTIONS.get(method_enum)
+    if normalize_func is None:
+        raise ValueError(f"Unknown normalization method: {method_enum}")
+    return normalize_func(band)
+
+
+def normalize_array(
+    array: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
+) -> np.ndarray:
+    """Normalize a multi-band array to 0-255 range.
+
+    Args:
+        array: Input array with shape (channels, height, width) from RasterFormat.decode_raster
+        method: Normalization method (applied per-band, string or NormalizationMethod enum)
+
+    Returns:
+        Normalized array as uint8 with shape (height, width, channels)
+    """
+    if array.ndim == 3:
+        array = np.moveaxis(array, 0, -1)
+
+    normalized = np.zeros_like(array, dtype=np.uint8)
+    for i in range(array.shape[-1]):
+        normalized[..., i] = normalize_band(array[..., i], method)
+
+    return normalized
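Finally, a short usage sketch for the new normalization helpers (illustrative; the import path is a guess, since the diff does not show where the new module lives):

    import numpy as np

    # Module path assumed for illustration.
    from rslearn.utils.visualize.normalize import NormalizationMethod, normalize_array

    # Fake CHW Sentinel-2 RGB patch with reflectance-scaled values.
    chw = np.random.randint(0, 2500, size=(3, 64, 64)).astype(np.float32)

    # "sentinel2_rgb" divides by 10 and clips to 0-255; the result is HWC uint8,
    # ready for PIL.Image.fromarray or matplotlib's imshow.
    rgb = normalize_array(chw, NormalizationMethod.SENTINEL2_RGB)
    assert rgb.shape == (64, 64, 3) and rgb.dtype == np.uint8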