geoai-py 0.18.1__py2.py3-none-any.whl → 0.19.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/tools/__init__.py CHANGED
@@ -63,3 +63,14 @@ try:
63
63
  except ImportError:
64
64
  # OmniCloudMask not installed - functions will not be available
65
65
  pass
66
+
67
+
68
# Super Resolution integration (optional dependency).
# geoai/tools/sr.py imports torch, rasterio, requests and omegaconf at module
# level; if any of them is missing, the ImportError is swallowed here and
# super_resolution is simply not exported from this package. (A missing
# opensr-model does not raise at import time; sr.py reports it when
# super_resolution is called.)
try:
    from .sr import super_resolution
except ImportError:
    # Super resolution not installed - function will not be available
    pass
else:
    __all__.append("super_resolution")
geoai/tools/sr.py ADDED
@@ -0,0 +1,194 @@
1
+ """
2
+ Super-resolution utilities using OpenSR latent diffusion models.
3
+
4
+ This module provides functions to perform super-resolution on multispectral
5
+ GeoTIFF images using the latent diffusion models from the ESA OpenSR project:
6
+
7
+ GitHub: https://github.com/ESAOpenSR/opensr-model.git
8
+ """
9
+
10
import os
from io import StringIO
from typing import Optional, Tuple

import numpy as np
import rasterio
import requests
import torch
from omegaconf import OmegaConf
from rasterio.transform import Affine
20
+
21
+ try:
22
+ import opensr_model
23
+
24
+ OPENSR_MODEL_AVAILABLE = True
25
+ except ImportError:
26
+ OPENSR_MODEL_AVAILABLE = False
27
+
28
+
29
# Default 1-based [R, G, B, NIR] band indices used when the caller passes none.
_DEFAULT_RGB_NIR_BANDS = (3, 2, 1, 4)

# OpenSR 10 m model configuration, fetched from the upstream repository.
_OPENSR_CONFIG_URL = (
    "https://raw.githubusercontent.com/ESAOpenSR/opensr-model/refs/heads/main/"
    "opensr_model/configs/config_10m.yaml"
)

# Seconds to wait for the configuration download before giving up.
_CONFIG_DOWNLOAD_TIMEOUT = 60


def super_resolution(
    input_lr_path: str,
    output_sr_path: str,
    output_uncertainty_path: str,
    rgb_nir_bands: Optional[list[int]] = None,
    sampling_steps: int = 100,
    n_variations: int = 25,
    scale: int = 4,  # OpenSR scaling factor, e.g., 10m -> 2.5m
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Perform super-resolution on RGB+NIR bands of a multispectral GeoTIFF using OpenSR latent diffusion.

    On every call the model configuration is downloaded from the opensr-model
    GitHub repository and the pretrained weights it names are loaded. The
    model runs on CUDA when available, otherwise on CPU. Two GeoTIFFs are
    written: the super-resolved image and an uncertainty map.

    Args:
        input_lr_path (str): Path to the input low-resolution GeoTIFF.
        output_sr_path (str): Path to save the super-resolution GeoTIFF.
        output_uncertainty_path (str): Path to save the uncertainty map GeoTIFF.
        rgb_nir_bands (list[int], optional): 1-based indices of the 4 bands
            corresponding to [R, G, B, NIR]. Python ints or NumPy integer
            scalars are accepted. Defaults to [3, 2, 1, 4] when None.
        sampling_steps (int): Number of diffusion sampling steps. Default is 100.
        n_variations (int): Number of samples to compute uncertainty. Default is 25.
        scale (int, optional): Super-resolution scale factor. Default is 4.
            It is only used to rescale the affine transform of the written
            GeoTIFFs so their georeference matches the input. It does not
            change what the model produces, so it should match the model's
            actual upscaling factor.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple containing:
            - sr_image: float32 super-resolved image, i.e. the model output
              with its leading batch axis squeezed (expected (4, H', W')).
            - uncertainty: float32 uncertainty map, the model output with its
              leading batch axis squeezed (expected (H', W'); TODO confirm
              against opensr-model).

    Raises:
        ValueError: If rgb_nir_bands does not contain exactly 4 integers.
        ImportError: If the opensr-model package is not installed.
        requests.RequestException: If downloading the configuration fails
            (including timeouts).
    """
    if rgb_nir_bands is None:
        rgb_nir_bands = list(_DEFAULT_RGB_NIR_BANDS)
    if len(rgb_nir_bands) != 4:
        raise ValueError("rgb_nir_bands must be a list of 4 integers: [R, G, B, NIR]")
    # bool is a subclass of int but is never a meaningful band index, while
    # NumPy integer scalars (e.g. from np.arange) are valid band indices.
    if not all(
        isinstance(b, (int, np.integer)) and not isinstance(b, bool)
        for b in rgb_nir_bands
    ):
        raise ValueError(
            "All elements of rgb_nir_bands must be integers. Received: {}".format(
                rgb_nir_bands
            )
        )
    bands = [int(b) for b in rgb_nir_bands]

    if not OPENSR_MODEL_AVAILABLE:
        raise ImportError(
            "The 'opensr-model' package is required for super-resolution. "
            "Please install it using: pip install opensr-model\n"
            "Or install GeoAI with the sr optional dependency: pip install geoai-py[sr]"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Download configuration YAML from GitHub
    print("Downloading model configuration from:", _OPENSR_CONFIG_URL)
    try:
        # Without a timeout an unresponsive server would block this call forever.
        response = requests.get(_OPENSR_CONFIG_URL, timeout=_CONFIG_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
    except requests.RequestException as e:
        print(f"Error downloading model configuration: {e}")
        raise
    config = OmegaConf.load(StringIO(response.text))

    # Initialize latent diffusion model and load pretrained weights
    model = opensr_model.SRLatentDiffusion(config, device=device)
    model.load_pretrained(config.ckpt_version)

    # Load only the specified RGB+NIR bands
    lr_tensor, profile = load_image_tensor(
        image_path=input_lr_path, device=device, bands=bands
    )

    # Inference only: disabling autograd avoids keeping a computation graph in
    # memory and guarantees the outputs do not require grad, so .numpy() works.
    with torch.no_grad():
        sr_tensor = model.forward(lr_tensor, sampling_steps=sampling_steps)
    sr_image = sr_tensor.squeeze(0).cpu().numpy().astype(np.float32)
    save_geotiff(sr_image, profile, output_sr_path, scale)
    print("Saved super-resolution image to:", output_sr_path)

    # Compute uncertainty map
    with torch.no_grad():
        unc_tensor = model.uncertainty_map(lr_tensor, n_variations=n_variations)
    uncertainty = unc_tensor.squeeze(0).cpu().numpy().astype(np.float32)
    save_geotiff(uncertainty, profile, output_uncertainty_path, scale)
    print("Saved uncertainty map to:", output_uncertainty_path)

    return sr_image, uncertainty
107
+
108
+
109
def save_geotiff(
    data: np.ndarray, reference_profile: dict, output_path: str, scale: int = 4
):
    """
    Save a 2D or 3D NumPy array as a float32 GeoTIFF with a rescaled georeference.

    The reference profile is copied and its affine transform is shrunk by
    ``scale``. Pixel size and any rotation/shear terms are divided by
    ``scale`` while the upper-left origin is kept, so the finer output grid
    covers the same area as the reference. The file is written
    LZW-compressed, and the parent directory is created if needed.

    Args:
        data (np.ndarray): Image array to save. Can be:
            - 2D array (H, W) for a single-band image
            - 3D array (C, H, W) for multi-band images (e.g., RGB+NIR)
        reference_profile (dict): Rasterio metadata from a reference GeoTIFF.
            Used to preserve CRS, transform, and other metadata. Must contain
            a "transform" entry.
        output_path (str): Path to save the output GeoTIFF.
        scale (int, optional): Super-resolution scale factor. Default is 4.
            This adjusts the affine transform to ensure georeference matches
            the original image.

    Returns:
        None

    Raises:
        ValueError: If ``scale`` is not positive or ``data`` is not 2D/3D.

    Note:
        Writes the image to disk at the specified output path.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    elif data.ndim != 3:
        raise ValueError(
            f"data must be a 2D (H, W) or 3D (C, H, W) array, got shape {data.shape}"
        )

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    profile = reference_profile.copy()
    # Composing with a pure scaling divides every linear term (a, b, d, e) by
    # `scale` and leaves the origin (c, f) untouched. Dividing only a and e
    # would be wrong for rotated or sheared rasters.
    new_transform = profile["transform"] * Affine.scale(1 / scale)
    profile.update(
        dtype=rasterio.float32,
        count=data.shape[0],
        height=data.shape[1],
        width=data.shape[2],
        compress="lzw",
        transform=new_transform,
    )

    with rasterio.open(output_path, "w", **profile) as dst:
        # copy=False avoids a redundant copy when data is already float32.
        dst.write(data.astype(np.float32, copy=False))
160
+
161
+
162
def load_image_tensor(
    image_path: str, device: str, bands: list[int]
) -> Tuple[torch.Tensor, dict]:
    """
    Load only specified bands of a multispectral GeoTIFF as a PyTorch tensor.

    The selected bands are cast to float32, a leading batch axis is added,
    and the tensor is moved to ``device``.

    Args:
        image_path (str): Path to input GeoTIFF.
        device (str): Device to move the tensor to ('cpu' or 'cuda').
        bands (list[int]): Non-empty list of 1-based band indices to read.

    Returns:
        Tuple[torch.Tensor, dict]: Tensor of shape (1, len(bands), H, W) and
            the rasterio profile of the source image.

    Raises:
        FileNotFoundError: If input image does not exist.
        ValueError: If ``bands`` is empty or any band index is out of range.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Input image does not exist: {image_path}")

    band_list = list(bands)
    # Checked explicitly: min()/max() on an empty list would raise an opaque
    # "arg is an empty sequence" error instead.
    if not band_list:
        raise ValueError("At least one band index must be requested, got none.")

    with rasterio.open(image_path) as src:
        n_bands = src.count
        if min(band_list) < 1 or max(band_list) > n_bands:
            raise ValueError(
                f"Input image has {n_bands} bands, requested bands {bands} out of range."
            )
        image = src.read(band_list)  # shape: (len(bands), H, W)
        profile = src.profile

    # copy=False skips a redundant copy when the raster is already float32.
    image = image.astype(np.float32, copy=False)
    tensor = torch.from_numpy(image).unsqueeze(0).to(device)
    return tensor, profile
geoai/train.py CHANGED
@@ -2015,17 +2015,39 @@ def get_semantic_transform(train: bool) -> Any:
2015
2015
  """
2016
2016
  Get transforms for semantic segmentation data augmentation.
2017
2017
 
2018
+ This function returns default data augmentation transforms for training
2019
+ semantic segmentation models. The transforms include geometric transformations
2020
+ (horizontal/vertical flips, rotations) and photometric adjustments (brightness,
2021
+ contrast) that are commonly used in remote sensing tasks.
2022
+
2018
2023
  Args:
2019
2024
  train (bool): Whether to include training-specific transforms.
2025
+ If True, applies augmentations (flips, rotations, brightness/contrast adjustments).
2026
+ If False, only converts to tensor (for validation).
2020
2027
 
2021
2028
  Returns:
2022
2029
  SemanticTransforms: Composed transforms.
2030
+
2031
+ Example:
2032
+ >>> train_transform = get_semantic_transform(train=True)
2033
+ >>> val_transform = get_semantic_transform(train=False)
2023
2034
  """
2024
2035
  transforms = []
2025
2036
  transforms.append(SemanticToTensor())
2026
2037
 
2027
2038
  if train:
2039
+ # Geometric transforms - preserve spatial structure
2028
2040
  transforms.append(SemanticRandomHorizontalFlip(0.5))
2041
+ transforms.append(SemanticRandomVerticalFlip(0.5))
2042
+ transforms.append(SemanticRandomRotation90(0.5))
2043
+
2044
+ # Photometric transforms - improve model robustness
2045
+ transforms.append(
2046
+ SemanticBrightnessAdjustment(brightness_range=(0.8, 1.2), prob=0.5)
2047
+ )
2048
+ transforms.append(
2049
+ SemanticContrastAdjustment(contrast_range=(0.8, 1.2), prob=0.5)
2050
+ )
2029
2051
 
2030
2052
  return SemanticTransforms(transforms)
2031
2053