rslearn 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/local_files.py +2 -2
  3. rslearn/data_sources/utils.py +204 -64
  4. rslearn/dataset/materialize.py +5 -1
  5. rslearn/models/clay/clay.py +3 -3
  6. rslearn/models/detr/detr.py +4 -1
  7. rslearn/models/dinov3.py +0 -1
  8. rslearn/models/olmoearth_pretrain/model.py +3 -1
  9. rslearn/models/pooling_decoder.py +1 -1
  10. rslearn/models/prithvi.py +0 -1
  11. rslearn/models/simple_time_series.py +97 -35
  12. rslearn/train/data_module.py +5 -0
  13. rslearn/train/dataset.py +186 -49
  14. rslearn/train/dataset_index.py +156 -0
  15. rslearn/train/model_context.py +16 -0
  16. rslearn/train/tasks/detection.py +1 -18
  17. rslearn/train/tasks/per_pixel_regression.py +13 -13
  18. rslearn/train/tasks/segmentation.py +27 -32
  19. rslearn/train/transforms/concatenate.py +17 -27
  20. rslearn/train/transforms/crop.py +8 -19
  21. rslearn/train/transforms/flip.py +4 -10
  22. rslearn/train/transforms/mask.py +9 -15
  23. rslearn/train/transforms/normalize.py +31 -82
  24. rslearn/train/transforms/pad.py +7 -13
  25. rslearn/train/transforms/resize.py +5 -22
  26. rslearn/train/transforms/select_bands.py +16 -36
  27. rslearn/train/transforms/sentinel1.py +4 -16
  28. rslearn/utils/colors.py +20 -0
  29. rslearn/vis/__init__.py +1 -0
  30. rslearn/vis/normalization.py +127 -0
  31. rslearn/vis/render_raster_label.py +96 -0
  32. rslearn/vis/render_sensor_image.py +27 -0
  33. rslearn/vis/render_vector_label.py +439 -0
  34. rslearn/vis/utils.py +99 -0
  35. rslearn/vis/vis_server.py +574 -0
  36. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/METADATA +14 -1
  37. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/RECORD +42 -33
  38. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/WHEEL +1 -1
  39. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/entry_points.txt +0 -0
  40. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/LICENSE +0 -0
  41. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/licenses/NOTICE +0 -0
  42. {rslearn-0.0.24.dist-info → rslearn-0.0.26.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  """Normalization transforms."""
2
2
 
3
+ import warnings
3
4
  from typing import Any
4
5
 
5
6
  import torch
@@ -35,14 +36,17 @@ class Normalize(Transform):
35
36
  bands: optionally restrict the normalization to these band indices. If set,
36
37
  mean and std must either be one value, or have length equal to the
37
38
  number of band indices passed here.
38
- num_bands: the number of bands per image, to distinguish different images
39
- in a time series. If set, then the bands list is repeated for each
40
- image, e.g. if bands=[2] then we apply normalization on images[2],
41
- images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
42
- is not set, then we apply the mean and std on each image in the time
43
- series.
39
+ num_bands: deprecated, no longer used. Will be removed after 2026-04-01.
44
40
  """
45
41
  super().__init__()
42
+
43
+ if num_bands is not None:
44
+ warnings.warn(
45
+ "num_bands is deprecated and no longer used. "
46
+ "It will be removed after 2026-04-01.",
47
+ FutureWarning,
48
+ )
49
+
46
50
  self.mean = torch.tensor(mean)
47
51
  self.std = torch.tensor(std)
48
52
 
@@ -55,92 +59,37 @@ class Normalize(Transform):
55
59
 
56
60
  self.selectors = selectors
57
61
  self.bands = torch.tensor(bands) if bands is not None else None
58
- self.num_bands = num_bands
59
62
 
60
- def apply_image(
61
- self, image: torch.Tensor | RasterImage
62
- ) -> torch.Tensor | RasterImage:
63
+ def apply_image(self, image: RasterImage) -> RasterImage:
63
64
  """Normalize the specified image.
64
65
 
65
66
  Args:
66
67
  image: the image to transform.
67
68
  """
68
-
69
- def _repeat_mean_and_std(
70
- image_channels: int, num_bands: int | None, is_raster_image: bool
71
- ) -> tuple[torch.Tensor, torch.Tensor]:
72
- """Get mean and std tensor that are suitable for applying on the image."""
73
- # We only need to repeat the tensor if both of these are true:
74
- # - The mean/std are not just one scalar.
75
- # - self.num_bands is set, otherwise we treat the input as a single image.
76
- if len(self.mean.shape) == 0:
77
- return self.mean, self.std
78
- if num_bands is None:
79
- return self.mean, self.std
80
- num_images = image_channels // num_bands
81
- if is_raster_image:
82
- # add an extra T dimension, CTHW
83
- return self.mean.repeat(num_images)[
84
- :, None, None, None
85
- ], self.std.repeat(num_images)[:, None, None, None]
86
- else:
87
- # add an extra T dimension, CTHW
88
- return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
89
- num_images
90
- )[:, None, None]
69
+ # Get mean/std with singleton dims for broadcasting over CTHW.
70
+ if len(self.mean.shape) == 0:
71
+ # Scalar - broadcasts naturally.
72
+ mean, std = self.mean, self.std
73
+ else:
74
+ # Vector of length C - add singleton dims for T, H, W.
75
+ mean = self.mean[:, None, None, None]
76
+ std = self.std[:, None, None, None]
91
77
 
92
78
  if self.bands is not None:
93
- # User has provided band indices to normalize.
94
- # If num_bands is set, then we repeat these for each image in the input
95
- # image time series.
96
- band_indices = self.bands
97
- if self.num_bands:
98
- num_images = image.shape[0] // self.num_bands
99
- band_indices = torch.cat(
100
- [
101
- band_indices + image_idx * self.num_bands
102
- for image_idx in range(num_images)
103
- ],
104
- dim=0,
79
+ # Normalize only specific band indices.
80
+ image.image[self.bands] = (image.image[self.bands] - mean) / std
81
+ if self.valid_min is not None:
82
+ image.image[self.bands] = torch.clamp(
83
+ image.image[self.bands],
84
+ min=self.valid_min,
85
+ max=self.valid_max,
105
86
  )
106
-
107
- # We use len(self.bands) here because that is how many bands per timestep
108
- # we are actually processing with the mean/std.
109
- mean, std = _repeat_mean_and_std(
110
- image_channels=len(band_indices),
111
- num_bands=len(self.bands),
112
- is_raster_image=isinstance(image, RasterImage),
113
- )
114
- if isinstance(image, torch.Tensor):
115
- image[band_indices] = (image[band_indices] - mean) / std
116
- if self.valid_min is not None:
117
- image[band_indices] = torch.clamp(
118
- image[band_indices], min=self.valid_min, max=self.valid_max
119
- )
120
- else:
121
- image.image[band_indices] = (image.image[band_indices] - mean) / std
122
- if self.valid_min is not None:
123
- image.image[band_indices] = torch.clamp(
124
- image.image[band_indices],
125
- min=self.valid_min,
126
- max=self.valid_max,
127
- )
128
87
  else:
129
- mean, std = _repeat_mean_and_std(
130
- image_channels=image.shape[0],
131
- num_bands=self.num_bands,
132
- is_raster_image=isinstance(image, RasterImage),
133
- )
134
- if isinstance(image, torch.Tensor):
135
- image = (image - mean) / std
136
- if self.valid_min is not None:
137
- image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
138
- else:
139
- image.image = (image.image - mean) / std
140
- if self.valid_min is not None:
141
- image.image = torch.clamp(
142
- image.image, min=self.valid_min, max=self.valid_max
143
- )
88
+ image.image = (image.image - mean) / std
89
+ if self.valid_min is not None:
90
+ image.image = torch.clamp(
91
+ image.image, min=self.valid_min, max=self.valid_max
92
+ )
144
93
  return image
145
94
 
146
95
  def forward(
@@ -50,9 +50,7 @@ class Pad(Transform):
50
50
  """
51
51
  return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
52
52
 
53
- def apply_image(
54
- self, image: RasterImage | torch.Tensor, state: dict[str, bool]
55
- ) -> RasterImage | torch.Tensor:
53
+ def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
56
54
  """Apply the sampled state on the specified image.
57
55
 
58
56
  Args:
@@ -105,16 +103,12 @@ class Pad(Transform):
105
103
  horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
106
104
  vertical_pad = (vertical_half, vertical_extra - vertical_half)
107
105
 
108
- if isinstance(image, RasterImage):
109
- image.image = apply_padding(
110
- image.image, True, horizontal_pad[0], horizontal_pad[1]
111
- )
112
- image.image = apply_padding(
113
- image.image, False, vertical_pad[0], vertical_pad[1]
114
- )
115
- else:
116
- image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
117
- image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
106
+ image.image = apply_padding(
107
+ image.image, True, horizontal_pad[0], horizontal_pad[1]
108
+ )
109
+ image.image = apply_padding(
110
+ image.image, False, vertical_pad[0], vertical_pad[1]
111
+ )
118
112
  return image
119
113
 
120
114
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
@@ -2,7 +2,6 @@
2
2
 
3
3
  from typing import Any
4
4
 
5
- import torch
6
5
  import torchvision
7
6
  from torchvision.transforms import InterpolationMode
8
7
 
@@ -40,32 +39,16 @@ class Resize(Transform):
40
39
  self.selectors = selectors
41
40
  self.interpolation = INTERPOLATION_MODES[interpolation]
42
41
 
43
- def apply_resize(
44
- self, image: torch.Tensor | RasterImage
45
- ) -> torch.Tensor | RasterImage:
42
+ def apply_resize(self, image: RasterImage) -> RasterImage:
46
43
  """Apply resizing on the specified image.
47
44
 
48
- If the image is 2D, it is unsqueezed to 3D and then squeezed
49
- back after resizing.
50
-
51
45
  Args:
52
46
  image: the image to transform.
53
47
  """
54
- if isinstance(image, torch.Tensor):
55
- if image.dim() == 2:
56
- image = image.unsqueeze(0) # (H, W) -> (1, H, W)
57
- result = torchvision.transforms.functional.resize(
58
- image, self.target_size, self.interpolation
59
- )
60
- return result.squeeze(0) # (1, H, W) -> (H, W)
61
- return torchvision.transforms.functional.resize(
62
- image, self.target_size, self.interpolation
63
- )
64
- else:
65
- image.image = torchvision.transforms.functional.resize(
66
- image.image, self.target_size, self.interpolation
67
- )
68
- return image
48
+ image.image = torchvision.transforms.functional.resize(
49
+ image.image, self.target_size, self.interpolation
50
+ )
51
+ return image
69
52
 
70
53
  def forward(
71
54
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -1,9 +1,8 @@
1
1
  """The SelectBands transform."""
2
2
 
3
+ import warnings
3
4
  from typing import Any
4
5
 
5
- from rslearn.train.model_context import RasterImage
6
-
7
6
  from .transform import Transform, read_selector, write_selector
8
7
 
9
8
 
@@ -17,60 +16,41 @@ class SelectBands(Transform):
17
16
  output_selector: str = "image",
18
17
  num_bands_per_timestep: int | None = None,
19
18
  ):
20
- """Initialize a new Concatenate.
19
+ """Initialize a new SelectBands.
21
20
 
22
21
  Args:
23
- band_indices: the bands to select.
22
+ band_indices: the bands to select from the channel dimension.
24
23
  input_selector: the selector to read the input image.
25
24
  output_selector: the output selector under which to save the output image.
26
- num_bands_per_timestep: the number of bands per image, to distinguish
27
- between stacked images in an image time series. If set, then the
28
- band_indices are selected for each image in the time series.
25
+ num_bands_per_timestep: deprecated, no longer used. Will be removed after
26
+ 2026-04-01.
29
27
  """
30
28
  super().__init__()
29
+
30
+ if num_bands_per_timestep is not None:
31
+ warnings.warn(
32
+ "num_bands_per_timestep is deprecated and no longer used. "
33
+ "It will be removed after 2026-04-01.",
34
+ FutureWarning,
35
+ )
36
+
31
37
  self.input_selector = input_selector
32
38
  self.output_selector = output_selector
33
39
  self.band_indices = band_indices
34
- self.num_bands_per_timestep = num_bands_per_timestep
35
40
 
36
41
  def forward(
37
42
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
38
43
  ) -> tuple[dict[str, Any], dict[str, Any]]:
39
- """Apply concatenation over the inputs and targets.
44
+ """Apply band selection over the inputs and targets.
40
45
 
41
46
  Args:
42
47
  input_dict: the input
43
48
  target_dict: the target
44
49
 
45
50
  Returns:
46
- normalized (input_dicts, target_dicts) tuple
51
+ (input_dicts, target_dicts) tuple with selected bands
47
52
  """
48
53
  image = read_selector(input_dict, target_dict, self.input_selector)
49
- num_bands_per_timestep = (
50
- self.num_bands_per_timestep
51
- if self.num_bands_per_timestep is not None
52
- else image.shape[0]
53
- )
54
- if isinstance(image, RasterImage):
55
- assert num_bands_per_timestep == image.shape[0], (
56
- "Expect a seperate dimension for timesteps in RasterImages."
57
- )
58
-
59
- if image.shape[0] % num_bands_per_timestep != 0:
60
- raise ValueError(
61
- f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
62
- )
63
-
64
- # Copy the band indices for each timestep in the input.
65
- wanted_bands: list[int] = []
66
- for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
67
- wanted_bands.extend(
68
- [(start_channel_idx + band_idx) for band_idx in self.band_indices]
69
- )
70
-
71
- if isinstance(image, RasterImage):
72
- image.image = image.image[wanted_bands]
73
- else:
74
- image = image[wanted_bands]
54
+ image.image = image.image[self.band_indices]
75
55
  write_selector(input_dict, target_dict, self.output_selector, image)
76
56
  return input_dict, target_dict
@@ -33,31 +33,19 @@ class Sentinel1ToDecibels(Transform):
33
33
  self.from_decibels = from_decibels
34
34
  self.epsilon = epsilon
35
35
 
36
- def apply_image(
37
- self, image: torch.Tensor | RasterImage
38
- ) -> torch.Tensor | RasterImage:
36
+ def apply_image(self, image: RasterImage) -> RasterImage:
39
37
  """Normalize the specified image.
40
38
 
41
39
  Args:
42
40
  image: the image to transform.
43
41
  """
44
- if isinstance(image, torch.Tensor):
45
- image_to_process = image
46
- else:
47
- image_to_process = image.image
48
42
  if self.from_decibels:
49
43
  # Decibels to linear scale.
50
- image_to_process = torch.pow(10.0, image_to_process / 10.0)
44
+ image.image = torch.pow(10.0, image.image / 10.0)
51
45
  else:
52
46
  # Linear scale to decibels.
53
- image_to_process = 10 * torch.log10(
54
- torch.clamp(image_to_process, min=self.epsilon)
55
- )
56
- if isinstance(image, torch.Tensor):
57
- return image_to_process
58
- else:
59
- image.image = image_to_process
60
- return image
47
+ image.image = 10 * torch.log10(torch.clamp(image.image, min=self.epsilon))
48
+ return image
61
49
 
62
50
  def forward(
63
51
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -0,0 +1,20 @@
1
+ """Default color palette for visualizations."""
2
+
3
+ DEFAULT_COLORS = [
4
+ (0, 0, 0),
5
+ (255, 0, 0),
6
+ (0, 255, 0),
7
+ (0, 0, 255),
8
+ (255, 255, 0),
9
+ (0, 255, 255),
10
+ (255, 0, 255),
11
+ (0, 128, 0),
12
+ (255, 160, 122),
13
+ (139, 69, 19),
14
+ (128, 128, 128),
15
+ (255, 255, 255),
16
+ (143, 188, 143),
17
+ (95, 158, 160),
18
+ (255, 200, 0),
19
+ (128, 0, 0),
20
+ ]
@@ -0,0 +1 @@
1
+ """Visualization module for rslearn datasets."""
@@ -0,0 +1,127 @@
1
+ """Normalization functions for raster data visualization."""
2
+
3
+ from collections.abc import Callable
4
+ from enum import StrEnum
5
+
6
+ import numpy as np
7
+
8
+ from rslearn.log_utils import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
+ class NormalizationMethod(StrEnum):
14
+ """Normalization methods for raster data visualization."""
15
+
16
+ SENTINEL2_RGB = "sentinel2_rgb"
17
+ """Divide by 10 and clip (for Sentinel-2 B04/B03/B02 bands)."""
18
+
19
+ PERCENTILE = "percentile"
20
+ """Use 2-98 percentile clipping."""
21
+
22
+ MINMAX = "minmax"
23
+ """Use min-max stretch."""
24
+
25
+
26
+ def _normalize_sentinel2_rgb(band: np.ndarray) -> np.ndarray:
27
+ """Normalize band using Sentinel-2 RGB method (divide by 10 and clip).
28
+
29
+ Args:
30
+ band: Input band data
31
+
32
+ Returns:
33
+ Normalized band as uint8 array
34
+ """
35
+ band = band / 10.0
36
+ band = np.clip(band, 0, 255).astype(np.uint8)
37
+ return band
38
+
39
+
40
+ def _normalize_percentile(band: np.ndarray) -> np.ndarray:
41
+ """Normalize band using 2-98 percentile clipping.
42
+
43
+ Args:
44
+ band: Input band data
45
+
46
+ Returns:
47
+ Normalized band as uint8 array
48
+ """
49
+ valid_pixels = band[~np.isnan(band)]
50
+ if len(valid_pixels) == 0:
51
+ return np.zeros_like(band, dtype=np.uint8)
52
+ vmin, vmax = np.nanpercentile(valid_pixels, (2, 98))
53
+ if vmax == vmin:
54
+ return np.zeros_like(band, dtype=np.uint8)
55
+ band = np.clip(band, vmin, vmax)
56
+ band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
57
+ return band
58
+
59
+
60
+ def _normalize_minmax(band: np.ndarray) -> np.ndarray:
61
+ """Normalize band using min-max stretch.
62
+
63
+ Args:
64
+ band: Input band data
65
+
66
+ Returns:
67
+ Normalized band as uint8 array
68
+ """
69
+ vmin, vmax = np.nanmin(band), np.nanmax(band)
70
+ if vmax == vmin:
71
+ return np.zeros_like(band, dtype=np.uint8)
72
+ band = np.clip(band, vmin, vmax)
73
+ band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
74
+ return band
75
+
76
+
77
# Dispatch table mapping each NormalizationMethod to its implementation.
# normalize_band() resolves methods through this dict; a new method must be
# added both to the enum and here.
_NORMALIZATION_FUNCTIONS: dict[
    NormalizationMethod, Callable[[np.ndarray], np.ndarray]
] = {
    NormalizationMethod.SENTINEL2_RGB: _normalize_sentinel2_rgb,
    NormalizationMethod.PERCENTILE: _normalize_percentile,
    NormalizationMethod.MINMAX: _normalize_minmax,
}
84
+
85
+
86
def normalize_band(
    band: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
) -> np.ndarray:
    """Normalize band to 0-255 range.

    Args:
        band: Input band data
        method: Normalization method (string or NormalizationMethod enum)
            - 'sentinel2_rgb': Divide by 10 and clip (for B04/B03/B02)
            - 'percentile': Use 2-98 percentile clipping
            - 'minmax': Use min-max stretch

    Returns:
        Normalized band as uint8 array

    Raises:
        ValueError: if the method is not a known normalization method.
    """
    # Coerce string names to the enum; NormalizationMethod() raises
    # ValueError for unrecognized names.
    if isinstance(method, str):
        method = NormalizationMethod(method)
    normalize_func = _NORMALIZATION_FUNCTIONS.get(method)
    if normalize_func is None:
        raise ValueError(f"Unknown normalization method: {method}")
    return normalize_func(band)
106
+
107
+
108
def normalize_array(
    array: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
) -> np.ndarray:
    """Normalize a multi-band array to 0-255 range.

    Args:
        array: Input array with shape (channels, height, width) from
            RasterFormat.decode_raster
        method: Normalization method (applied per-band, string or
            NormalizationMethod enum)

    Returns:
        Normalized array as uint8 with shape (height, width, channels)
    """
    # Move channels-first input to channels-last before per-band processing.
    if array.ndim == 3:
        array = np.moveaxis(array, 0, -1)

    # Normalize each band independently and reassemble along the last axis.
    bands = [
        normalize_band(array[..., band_idx], method)
        for band_idx in range(array.shape[-1])
    ]
    return np.stack(bands, axis=-1)
@@ -0,0 +1,96 @@
1
+ """Functions for rendering raster label masks (e.g., segmentation masks)."""
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from rasterio.warp import Resampling
6
+
7
+ from rslearn.config import DType, LayerConfig
8
+ from rslearn.dataset import Window
9
+ from rslearn.log_utils import get_logger
10
+ from rslearn.train.dataset import DataInput, read_raster_layer_for_data_input
11
+ from rslearn.utils.geometry import PixelBounds, ResolutionFactor
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
+ def render_raster_label(
17
+ label_array: np.ndarray,
18
+ label_colors: dict[str, tuple[int, int, int]],
19
+ layer_config: LayerConfig,
20
+ ) -> np.ndarray:
21
+ """Render a raster label array as a colored mask numpy array.
22
+
23
+ Args:
24
+ label_array: Raster label array with shape (bands, height, width) - typically single band
25
+ label_colors: Dictionary mapping label class names to RGB color tuples
26
+ layer_config: LayerConfig object (to access class_names if available)
27
+
28
+ Returns:
29
+ Array with shape (height, width, 3) as uint8
30
+ """
31
+ if label_array.ndim == 3:
32
+ label_values = label_array[0, :, :]
33
+ else:
34
+ label_values = label_array
35
+
36
+ height, width = label_values.shape
37
+ mask_img = np.zeros((height, width, 3), dtype=np.uint8)
38
+ valid_mask = ~np.isnan(label_values)
39
+
40
+ if not layer_config.class_names:
41
+ raise ValueError(
42
+ "class_names must be specified in config for raster label layer"
43
+ )
44
+
45
+ label_int = label_values.astype(np.int32)
46
+ for idx in range(len(layer_config.class_names)):
47
+ class_name = layer_config.class_names[idx]
48
+ color = label_colors.get(str(class_name), (0, 0, 0))
49
+ mask = (label_int == idx) & valid_mask
50
+ mask_img[mask] = color
51
+
52
+ img = Image.fromarray(mask_img, mode="RGB")
53
+ return np.array(img)
54
+
55
+
56
+ def read_raster_layer(
57
+ window: Window,
58
+ layer_name: str,
59
+ layer_config: LayerConfig,
60
+ band_names: list[str],
61
+ group_idx: int = 0,
62
+ bounds: PixelBounds | None = None,
63
+ ) -> np.ndarray:
64
+ """Read a raster layer for visualization.
65
+
66
+ This reads bands from potentially multiple band sets to get the requested bands.
67
+ Uses read_raster_layer_for_data_input from rslearn.train.dataset.
68
+
69
+ Args:
70
+ window: The window to read from
71
+ layer_name: The layer name
72
+ layer_config: The layer configuration
73
+ band_names: List of band names to read (e.g., ["B04", "B03", "B02"])
74
+ group_idx: The item group index (default 0)
75
+ bounds: Optional bounds to read. If None, uses window.bounds
76
+
77
+ Returns:
78
+ Array with shape (bands, height, width) as float32
79
+ """
80
+ if bounds is None:
81
+ bounds = window.bounds
82
+
83
+ data_input = DataInput(
84
+ data_type="raster",
85
+ layers=[layer_name],
86
+ bands=band_names,
87
+ dtype=DType.FLOAT32,
88
+ resolution_factor=ResolutionFactor(), # Default 1/1, no scaling
89
+ resampling=Resampling.nearest,
90
+ )
91
+
92
+ image_tensor = read_raster_layer_for_data_input(
93
+ window, bounds, layer_name, group_idx, layer_config, data_input
94
+ )
95
+
96
+ return image_tensor.numpy().astype(np.float32)
@@ -0,0 +1,27 @@
1
+ """Functions for rendering raster sensor images (e.g., Sentinel-2, Landsat)."""
2
+
3
+ import numpy as np
4
+
5
+ from .normalization import normalize_array
6
+
7
+
8
def render_sensor_image(
    array: np.ndarray,
    normalization_method: str,
) -> np.ndarray:
    """Render a raster sensor image array as a numpy array.

    Args:
        array: Array with shape (channels, height, width) from
            RasterFormat.decode_raster
        normalization_method: Normalization method to apply

    Returns:
        Array with shape (height, width, channels) as uint8
    """
    rendered = normalize_array(array, normalization_method)

    # Keep at most the first 3 channels so the result can be displayed as RGB.
    if rendered.shape[-1] > 3:
        return rendered[:, :, :3]
    return rendered