geoai-py 0.17.0__py2.py3-none-any.whl → 0.18.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +16 -1
- geoai/agents/geo_agents.py +11 -3
- geoai/change_detection.py +16 -1
- geoai/timm_segment.py +4 -1
- geoai/tools/__init__.py +65 -0
- geoai/tools/cloudmask.py +431 -0
- geoai/tools/multiclean.py +357 -0
- geoai/train.py +123 -6
- {geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/METADATA +5 -2
- {geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/RECORD +14 -11
- {geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
@@ -2,7 +2,7 @@
 
 __author__ = """Qiusheng Wu"""
 __email__ = "giswqs@gmail.com"
-__version__ = "0.17.0"
+__version__ = "0.18.0"
 
 
 import os
@@ -121,3 +121,18 @@ from .timm_segment import (
     timm_semantic_segmentation,
     push_timm_model_to_hub,
 )
+
+# Import tools subpackage
+from . import tools
+
+# Expose commonly used tools at package level for convenience
+try:
+    from .tools import (
+        clean_segmentation_mask,
+        clean_raster,
+        clean_raster_batch,
+        compare_masks,
+    )
+except ImportError:
+    # MultiClean not available (missing dependency)
+    pass
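
The new re-export means the cleaning helpers can be reached directly from the top-level package when the optional dependency is installed; otherwise the names are simply absent. A minimal sketch of how a caller can detect this (assumes geoai-py 0.18.0 is installed):

```python
import geoai

# The names are only bound when the optional MultiClean backend imported cleanly.
if hasattr(geoai, "clean_segmentation_mask"):
    print("MultiClean helpers available at package level")
else:
    print("Install the optional backend: pip install multiclean")
```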
geoai/agents/geo_agents.py
CHANGED
@@ -55,7 +55,9 @@ class UICallbackHandler:
         # Make tool names more user-friendly
         friendly_name = tool_name.replace("_", " ").title()
         self.status_widget.value = (
-            f"<span style='color:#0a7'
+            f"<span style='color:#0a7'>"
+            f"<i class='fas fa-spinner fa-spin' style='font-size:1.2em'></i> "
+            f"{friendly_name}...</span>"
         )
 
 
@@ -396,7 +398,10 @@ class GeoAgent(Agent):
         btn_clear = widgets.Button(
             description="Clear", icon="trash", layout=widgets.Layout(width="120px")
         )
-        status = widgets.HTML(
+        status = widgets.HTML(
+            "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
+            "<span style='color:#666'>Ready.</span>"
+        )
 
         examples = widgets.Dropdown(
             options=[
@@ -1017,7 +1022,10 @@ CRITICAL: Return ONLY JSON. NO explanatory text, NO made-up data."""
         btn_clear = widgets.Button(
             description="Clear", icon="trash", layout=widgets.Layout(width="120px")
         )
-        status = widgets.HTML(
+        status = widgets.HTML(
+            "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
+            "<span style='color:#666'>Ready to search.</span>"
+        )
 
         examples = widgets.Dropdown(
             options=[
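
The status widgets now inline a Font Awesome stylesheet so the `fa-spinner` icon renders inside notebook output. A standalone ipywidgets sketch of the same idea (the icon and text are illustrative, not geoai code):

```python
import ipywidgets as widgets

status = widgets.HTML(
    "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
    "<span style='color:#0a7'>"
    "<i class='fas fa-spinner fa-spin' style='font-size:1.2em'></i> Working...</span>"
)
status  # display in a Jupyter cell to see the spinning icon
```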
geoai/change_detection.py
CHANGED
@@ -13,7 +13,8 @@ from skimage.transform import resize
 try:
     from torchange.models.segment_any_change import AnyChange, show_change_masks
 except ImportError:
-
+    AnyChange = None
+    show_change_masks = None
 
 from .utils import download_file
 
@@ -36,6 +37,13 @@ class ChangeDetection:
 
     def _init_model(self):
         """Initialize the AnyChange model."""
+        if AnyChange is None:
+            raise ImportError(
+                "The 'torchange' package is required for change detection. "
+                "Please install it using: pip install torchange\n"
+                "Note: torchange requires Python 3.11 or higher."
+            )
+
         if self.sam_checkpoint is None:
             self.sam_checkpoint = download_checkpoint(self.sam_model_type)
 
@@ -551,6 +559,13 @@ class ChangeDetection:
         Returns:
             matplotlib.figure.Figure: The figure object
         """
+        if show_change_masks is None:
+            raise ImportError(
+                "The 'torchange' package is required for change detection visualization. "
+                "Please install it using: pip install torchange\n"
+                "Note: torchange requires Python 3.11 or higher."
+            )
+
         change_masks, img1, img2 = self.detect_changes(
             image1_path, image2_path, return_results=True
         )
geoai/timm_segment.py
CHANGED
@@ -241,7 +241,10 @@ class TimmSegmentationModel(pl.LightningModule):
         )
 
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer,
+            optimizer,
+            mode="min",
+            factor=0.5,
+            patience=5,
         )
 
         return {
geoai/tools/__init__.py
ADDED
@@ -0,0 +1,65 @@
+"""
+GeoAI Tools - Utility functions and integrations for geospatial AI workflows.
+
+This subpackage contains various tools and integrations for enhancing
+geospatial AI workflows, including post-processing utilities and
+third-party library integrations.
+"""
+
+__all__ = []
+
+# MultiClean integration (optional dependency)
+try:
+    from .multiclean import (
+        clean_segmentation_mask,
+        clean_raster,
+        clean_raster_batch,
+        compare_masks,
+        check_multiclean_available,
+    )
+
+    __all__.extend(
+        [
+            "clean_segmentation_mask",
+            "clean_raster",
+            "clean_raster_batch",
+            "compare_masks",
+            "check_multiclean_available",
+        ]
+    )
+except ImportError:
+    # MultiClean not installed - functions will not be available
+    pass
+
+# OmniCloudMask integration (optional dependency)
+try:
+    from .cloudmask import (
+        predict_cloud_mask,
+        predict_cloud_mask_from_raster,
+        predict_cloud_mask_batch,
+        calculate_cloud_statistics,
+        create_cloud_free_mask,
+        check_omnicloudmask_available,
+        CLEAR,
+        THICK_CLOUD,
+        THIN_CLOUD,
+        CLOUD_SHADOW,
+    )
+
+    __all__.extend(
+        [
+            "predict_cloud_mask",
+            "predict_cloud_mask_from_raster",
+            "predict_cloud_mask_batch",
+            "calculate_cloud_statistics",
+            "create_cloud_free_mask",
+            "check_omnicloudmask_available",
+            "CLEAR",
+            "THICK_CLOUD",
+            "THIN_CLOUD",
+            "CLOUD_SHADOW",
+        ]
+    )
+except ImportError:
+    # OmniCloudMask not installed - functions will not be available
+    pass
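
Because both integrations are optional, what ends up in `geoai.tools.__all__` depends on the environment. A quick way to probe what is available (illustrative only):

```python
from geoai import tools

print(sorted(tools.__all__))  # names registered by the optional backends

try:
    tools.check_omnicloudmask_available()
    print("OmniCloudMask backend ready")
except (ImportError, AttributeError) as err:
    # AttributeError if the cloudmask module itself could not be imported
    print(f"OmniCloudMask not available: {err}")
```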
geoai/tools/cloudmask.py
ADDED
@@ -0,0 +1,431 @@
+"""
+OmniCloudMask integration for cloud and cloud shadow detection in satellite imagery.
+
+This module provides functions to use OmniCloudMask (https://github.com/DPIRD-DMA/OmniCloudMask)
+for detecting clouds and cloud shadows in satellite imagery. OmniCloudMask performs semantic
+segmentation to classify pixels into: Clear (0), Thick Cloud (1), Thin Cloud (2), Cloud Shadow (3).
+
+Supports Sentinel-2, Landsat 8, PlanetScope, and Maxar imagery at 10-50m resolution.
+"""
+
+import os
+from typing import Optional, List, Tuple, Dict, Any
+import numpy as np
+
+try:
+    from omnicloudmask import predict_from_array
+
+    OMNICLOUDMASK_AVAILABLE = True
+except ImportError:
+    OMNICLOUDMASK_AVAILABLE = False
+
+try:
+    import rasterio
+
+    RASTERIO_AVAILABLE = True
+except ImportError:
+    RASTERIO_AVAILABLE = False
+
+
+# Cloud mask class values
+CLEAR = 0
+THICK_CLOUD = 1
+THIN_CLOUD = 2
+CLOUD_SHADOW = 3
+
+
+def check_omnicloudmask_available():
+    """
+    Check if omnicloudmask is installed.
+
+    Raises:
+        ImportError: If omnicloudmask is not installed.
+    """
+    if not OMNICLOUDMASK_AVAILABLE:
+        raise ImportError(
+            "omnicloudmask is not installed. "
+            "Please install it with: pip install omnicloudmask "
+            "or: pip install geoai-py[extra]"
+        )
+
+
+def predict_cloud_mask(
+    image: np.ndarray,
+    batch_size: int = 1,
+    inference_device: str = "cpu",
+    inference_dtype: str = "fp32",
+    patch_size: int = 1000,
+    export_confidence: bool = False,
+    model_version: int = 3,
+) -> np.ndarray:
+    """
+    Predict cloud mask from a numpy array using OmniCloudMask.
+
+    This function classifies each pixel into one of four categories:
+    - 0: Clear
+    - 1: Thick Cloud
+    - 2: Thin Cloud
+    - 3: Cloud Shadow
+
+    Args:
+        image (np.ndarray): Input image array with shape (3, height, width) or (height, width, 3).
+            Should contain Red, Green, and NIR bands. Values should be in reflectance (0-1)
+            or digital numbers (0-10000 typical for Sentinel-2/Landsat).
+        batch_size (int): Number of patches to process per inference batch. Defaults to 1.
+        inference_device (str): Device for inference ('cpu', 'cuda', or 'mps'). Defaults to 'cpu'.
+        inference_dtype (str): Data type for inference ('fp32', 'fp16', or 'bf16').
+            'bf16' recommended for speed on compatible hardware. Defaults to 'fp32'.
+        patch_size (int): Size of patches for processing large images. Defaults to 1000.
+        export_confidence (bool): If True, also returns confidence map. Defaults to False.
+        model_version (int): Model version to use (1, 2, or 3). Defaults to 3.
+
+    Returns:
+        np.ndarray: Cloud mask array with shape (height, width) containing class predictions.
+            If export_confidence=True, returns tuple of (mask, confidence).
+
+    Raises:
+        ImportError: If omnicloudmask is not installed.
+        ValueError: If image has wrong shape or number of channels.
+
+    Example:
+        >>> import numpy as np
+        >>> from geoai.tools.cloudmask import predict_cloud_mask
+        >>> # Create synthetic image (3 bands: R, G, NIR)
+        >>> image = np.random.rand(3, 512, 512) * 10000
+        >>> mask = predict_cloud_mask(image)
+        >>> print(f"Clear pixels: {(mask == 0).sum()}")
+    """
+    check_omnicloudmask_available()
+
+    # Ensure image has correct shape (3, H, W)
+    if image.ndim != 3:
+        raise ValueError(f"Image must be 3D, got shape {image.shape}")
+
+    # Convert (H, W, 3) to (3, H, W) if needed
+    if image.shape[2] == 3 and image.shape[0] != 3:
+        image = np.transpose(image, (2, 0, 1))
+
+    if image.shape[0] != 3:
+        raise ValueError(
+            f"Image must have 3 channels (R, G, NIR), got {image.shape[0]} channels"
+        )
+
+    # Call OmniCloudMask
+    result = predict_from_array(
+        image,
+        batch_size=batch_size,
+        inference_device=inference_device,
+        inference_dtype=inference_dtype,
+        patch_size=patch_size,
+        export_confidence=export_confidence,
+        model_version=model_version,
+    )
+
+    # Handle output shape - omnicloudmask returns (1, H, W) or ((1, H, W), (1, H, W))
+    if export_confidence:
+        mask, confidence = result
+        # Squeeze batch dimension
+        mask = mask.squeeze(0) if mask.ndim == 3 else mask
+        confidence = confidence.squeeze(0) if confidence.ndim == 3 else confidence
+        return mask, confidence
+    else:
+        # Squeeze batch dimension
+        return result.squeeze(0) if result.ndim == 3 else result
+
+
+def predict_cloud_mask_from_raster(
+    input_path: str,
+    output_path: str,
+    red_band: int = 1,
+    green_band: int = 2,
+    nir_band: int = 3,
+    batch_size: int = 1,
+    inference_device: str = "cpu",
+    inference_dtype: str = "fp32",
+    patch_size: int = 1000,
+    export_confidence: bool = False,
+    model_version: int = 3,
+) -> None:
+    """
+    Predict cloud mask from a GeoTIFF file and save the result.
+
+    Reads a multi-band raster, extracts RGB+NIR bands, applies OmniCloudMask,
+    and saves the result while preserving geospatial metadata.
+
+    Args:
+        input_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save cloud mask GeoTIFF.
+        red_band (int): Band index for Red (1-indexed). Defaults to 1.
+        green_band (int): Band index for Green (1-indexed). Defaults to 2.
+        nir_band (int): Band index for NIR (1-indexed). Defaults to 3.
+        batch_size (int): Patches per inference batch. Defaults to 1.
+        inference_device (str): Device ('cpu', 'cuda', 'mps'). Defaults to 'cpu'.
+        inference_dtype (str): Dtype ('fp32', 'fp16', 'bf16'). Defaults to 'fp32'.
+        patch_size (int): Patch size for large images. Defaults to 1000.
+        export_confidence (bool): Export confidence map. Defaults to False.
+        model_version (str): Model version ('1.0', '2.0', '3.0'). Defaults to '3.0'.
+
+    Returns:
+        None: Writes cloud mask to output_path.
+
+    Raises:
+        ImportError: If omnicloudmask or rasterio not installed.
+        FileNotFoundError: If input_path doesn't exist.
+
+    Example:
+        >>> from geoai.tools.cloudmask import predict_cloud_mask_from_raster
+        >>> predict_cloud_mask_from_raster(
+        ...     "sentinel2_image.tif",
+        ...     "cloud_mask.tif",
+        ...     red_band=4,  # Sentinel-2 band order
+        ...     green_band=3,
+        ...     nir_band=8
+        ... )
+    """
+    check_omnicloudmask_available()
+
+    if not RASTERIO_AVAILABLE:
+        raise ImportError(
+            "rasterio is required for raster operations. "
+            "Please install it with: pip install rasterio"
+        )
+
+    if not os.path.exists(input_path):
+        raise FileNotFoundError(f"Input file not found: {input_path}")
+
+    # Read input raster
+    with rasterio.open(input_path) as src:
+        # Read required bands
+        red = src.read(red_band).astype(np.float32)
+        green = src.read(green_band).astype(np.float32)
+        nir = src.read(nir_band).astype(np.float32)
+
+        # Stack into (3, H, W)
+        image = np.stack([red, green, nir], axis=0)
+
+        # Get metadata
+        profile = src.profile.copy()
+
+    # Predict cloud mask
+    result = predict_cloud_mask(
+        image,
+        batch_size=batch_size,
+        inference_device=inference_device,
+        inference_dtype=inference_dtype,
+        patch_size=patch_size,
+        export_confidence=export_confidence,
+        model_version=model_version,
+    )
+
+    # Handle confidence output
+    if export_confidence:
+        mask, confidence = result
+    else:
+        mask = result
+
+    # Update profile for output
+    profile.update(
+        dtype=np.uint8,
+        count=1,
+        compress="lzw",
+        nodata=None,
+    )
+
+    # Write cloud mask
+    output_dir = os.path.dirname(os.path.abspath(output_path))
+    if output_dir and output_dir != os.path.abspath(os.sep):
+        os.makedirs(output_dir, exist_ok=True)
+
+    with rasterio.open(output_path, "w", **profile) as dst:
+        dst.write(mask.astype(np.uint8), 1)
+
+    # Optionally write confidence map
+    if export_confidence:
+        confidence_path = output_path.replace(".tif", "_confidence.tif")
+        profile.update(dtype=np.float32)
+        with rasterio.open(confidence_path, "w", **profile) as dst:
+            dst.write(confidence, 1)
+
+
+def predict_cloud_mask_batch(
+    input_paths: List[str],
+    output_dir: str,
+    red_band: int = 1,
+    green_band: int = 2,
+    nir_band: int = 3,
+    batch_size: int = 1,
+    inference_device: str = "cpu",
+    inference_dtype: str = "fp32",
+    patch_size: int = 1000,
+    export_confidence: bool = False,
+    model_version: int = 3,
+    suffix: str = "_cloudmask",
+    verbose: bool = True,
+) -> List[str]:
+    """
+    Predict cloud masks for multiple rasters in batch.
+
+    Processes multiple GeoTIFF files with the same cloud detection parameters
+    and saves results to an output directory.
+
+    Args:
+        input_paths (list of str): Paths to input GeoTIFF files.
+        output_dir (str): Directory to save cloud masks.
+        red_band (int): Red band index. Defaults to 1.
+        green_band (int): Green band index. Defaults to 2.
+        nir_band (int): NIR band index. Defaults to 3.
+        batch_size (int): Patches per batch. Defaults to 1.
+        inference_device (str): Device. Defaults to 'cpu'.
+        inference_dtype (str): Dtype. Defaults to 'fp32'.
+        patch_size (int): Patch size. Defaults to 1000.
+        export_confidence (bool): Export confidence. Defaults to False.
+        model_version (str): Model version. Defaults to '3.0'.
+        suffix (str): Suffix for output filenames. Defaults to '_cloudmask'.
+        verbose (bool): Print progress. Defaults to True.
+
+    Returns:
+        list of str: Paths to output cloud mask files.
+
+    Raises:
+        ImportError: If omnicloudmask or rasterio not installed.
+
+    Example:
+        >>> from geoai.tools.cloudmask import predict_cloud_mask_batch
+        >>> files = ["scene1.tif", "scene2.tif", "scene3.tif"]
+        >>> outputs = predict_cloud_mask_batch(
+        ...     files,
+        ...     output_dir="cloud_masks",
+        ...     inference_device="cuda"
+        ... )
+    """
+    check_omnicloudmask_available()
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    output_paths = []
+
+    for i, input_path in enumerate(input_paths):
+        if verbose:
+            print(f"Processing {i+1}/{len(input_paths)}: {input_path}")
+
+        # Generate output filename
+        basename = os.path.basename(input_path)
+        name, ext = os.path.splitext(basename)
+        output_filename = f"{name}{suffix}{ext}"
+        output_path = os.path.join(output_dir, output_filename)
+
+        try:
+            # Predict cloud mask
+            predict_cloud_mask_from_raster(
+                input_path,
+                output_path,
+                red_band=red_band,
+                green_band=green_band,
+                nir_band=nir_band,
+                batch_size=batch_size,
+                inference_device=inference_device,
+                inference_dtype=inference_dtype,
+                patch_size=patch_size,
+                export_confidence=export_confidence,
+                model_version=model_version,
+            )
+
+            output_paths.append(output_path)
+
+            if verbose:
+                print(f"  ✓ Saved to: {output_path}")
+
+        except Exception as e:
+            if verbose:
+                print(f"  ✗ Failed: {e}")
+            continue
+
+    return output_paths
+
+
+def calculate_cloud_statistics(
+    mask: np.ndarray,
+) -> Dict[str, Any]:
+    """
+    Calculate statistics from a cloud mask.
+
+    Args:
+        mask (np.ndarray): Cloud mask array with values 0-3.
+
+    Returns:
+        dict: Statistics including:
+            - total_pixels: Total number of pixels
+            - clear_pixels: Number of clear pixels
+            - thick_cloud_pixels: Number of thick cloud pixels
+            - thin_cloud_pixels: Number of thin cloud pixels
+            - shadow_pixels: Number of cloud shadow pixels
+            - clear_percent: Percentage of clear pixels
+            - cloud_percent: Percentage of cloudy pixels (thick + thin)
+            - shadow_percent: Percentage of shadow pixels
+
+    Example:
+        >>> from geoai.tools.cloudmask import calculate_cloud_statistics
+        >>> import numpy as np
+        >>> mask = np.random.randint(0, 4, (512, 512))
+        >>> stats = calculate_cloud_statistics(mask)
+        >>> print(f"Clear: {stats['clear_percent']:.1f}%")
+    """
+    total_pixels = mask.size
+
+    clear_pixels = (mask == CLEAR).sum()
+    thick_cloud_pixels = (mask == THICK_CLOUD).sum()
+    thin_cloud_pixels = (mask == THIN_CLOUD).sum()
+    shadow_pixels = (mask == CLOUD_SHADOW).sum()
+
+    cloud_pixels = thick_cloud_pixels + thin_cloud_pixels
+
+    return {
+        "total_pixels": int(total_pixels),
+        "clear_pixels": int(clear_pixels),
+        "thick_cloud_pixels": int(thick_cloud_pixels),
+        "thin_cloud_pixels": int(thin_cloud_pixels),
+        "shadow_pixels": int(shadow_pixels),
+        "clear_percent": float(clear_pixels / total_pixels * 100),
+        "cloud_percent": float(cloud_pixels / total_pixels * 100),
+        "shadow_percent": float(shadow_pixels / total_pixels * 100),
+    }
+
+
+def create_cloud_free_mask(
+    mask: np.ndarray,
+    include_thin_clouds: bool = False,
+    include_shadows: bool = False,
+) -> np.ndarray:
+    """
+    Create a binary mask of cloud-free pixels.
+
+    Args:
+        mask (np.ndarray): Cloud mask with values 0-3.
+        include_thin_clouds (bool): If True, treats thin clouds as acceptable.
+            Defaults to False.
+        include_shadows (bool): If True, treats shadows as acceptable.
+            Defaults to False.
+
+    Returns:
+        np.ndarray: Binary mask where 1 = usable, 0 = not usable.
+
+    Example:
+        >>> from geoai.tools.cloudmask import create_cloud_free_mask
+        >>> import numpy as np
+        >>> mask = np.random.randint(0, 4, (512, 512))
+        >>> cloud_free = create_cloud_free_mask(mask)
+        >>> print(f"Usable pixels: {cloud_free.sum()}")
+    """
+    # Start with clear pixels
+    usable = mask == CLEAR
+
+    # Optionally include thin clouds
+    if include_thin_clouds:
+        usable = usable | (mask == THIN_CLOUD)
+
+    # Optionally include shadows
+    if include_shadows:
+        usable = usable | (mask == CLOUD_SHADOW)
+
+    return usable.astype(np.uint8)
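
Putting the new functions together: a hedged end-to-end sketch that masks a Sentinel-2 style GeoTIFF, summarizes the result, and derives a usable-pixel mask (file names and band indices are illustrative; requires `omnicloudmask` and `rasterio`):

```python
import rasterio
from geoai.tools.cloudmask import (
    predict_cloud_mask_from_raster,
    calculate_cloud_statistics,
    create_cloud_free_mask,
)

# 1) Run OmniCloudMask on a raster, writing a single-band uint8 mask
predict_cloud_mask_from_raster(
    "sentinel2_scene.tif",        # illustrative input path
    "scene_cloudmask.tif",
    red_band=4, green_band=3, nir_band=8,  # Sentinel-2 band order
    inference_device="cpu",
)

# 2) Summarize cloud cover and build a binary usable-pixel mask
with rasterio.open("scene_cloudmask.tif") as src:
    mask = src.read(1)

stats = calculate_cloud_statistics(mask)
print(f"Clear: {stats['clear_percent']:.1f}%  Cloud: {stats['cloud_percent']:.1f}%")

usable = create_cloud_free_mask(mask, include_thin_clouds=True)
```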
geoai/tools/multiclean.py
ADDED
@@ -0,0 +1,357 @@
+"""
+MultiClean integration utilities for cleaning segmentation results.
+
+This module provides functions to use MultiClean (https://github.com/DPIRD-DMA/MultiClean)
+for post-processing segmentation masks and classification rasters. MultiClean performs
+morphological operations to smooth edges, remove noise islands, and fill gaps.
+"""
+
+import os
+from typing import Optional, List, Union, Tuple
+import numpy as np
+
+try:
+    from multiclean import clean_array
+
+    MULTICLEAN_AVAILABLE = True
+except ImportError:
+    MULTICLEAN_AVAILABLE = False
+
+try:
+    import rasterio
+
+    RASTERIO_AVAILABLE = True
+except ImportError:
+    RASTERIO_AVAILABLE = False
+
+
+def check_multiclean_available():
+    """
+    Check if multiclean is installed.
+
+    Raises:
+        ImportError: If multiclean is not installed.
+    """
+    if not MULTICLEAN_AVAILABLE:
+        raise ImportError(
+            "multiclean is not installed. "
+            "Please install it with: pip install multiclean "
+            "or: pip install geoai-py[extra]"
+        )
+
+
+def clean_segmentation_mask(
+    mask: np.ndarray,
+    class_values: Optional[Union[int, List[int]]] = None,
+    smooth_edge_size: int = 2,
+    min_island_size: int = 100,
+    connectivity: int = 8,
+    max_workers: Optional[int] = None,
+    fill_nan: bool = False,
+) -> np.ndarray:
+    """
+    Clean a segmentation mask using MultiClean morphological operations.
+
+    This function applies three cleaning operations:
+    1. Edge smoothing - Uses morphological opening to reduce jagged boundaries
+    2. Island removal - Eliminates small connected components (noise)
+    3. Gap filling - Replaces invalid pixels with nearest valid class
+
+    Args:
+        mask (np.ndarray): 2D numpy array containing segmentation classes.
+            Can be int or float. NaN values are treated as nodata.
+        class_values (int, list of int, or None): Target class values to process.
+            If None, auto-detects unique values from the mask. Defaults to None.
+        smooth_edge_size (int): Kernel width in pixels for edge smoothing.
+            Set to 0 to disable smoothing. Defaults to 2.
+        min_island_size (int): Minimum area (in pixels) for connected components.
+            Components with area strictly less than this are removed. Defaults to 100.
+        connectivity (int): Connectivity for component detection. Use 4 or 8.
+            8-connectivity considers diagonal neighbors. Defaults to 8.
+        max_workers (int, optional): Thread pool size for parallel processing.
+            If None, uses default threading. Defaults to None.
+        fill_nan (bool): Whether to fill NaN pixels with nearest valid class.
+            Defaults to False.
+
+    Returns:
+        np.ndarray: Cleaned 2D segmentation mask with same shape as input.
+
+    Raises:
+        ImportError: If multiclean is not installed.
+        ValueError: If mask is not 2D or if connectivity is not 4 or 8.
+
+    Example:
+        >>> import numpy as np
+        >>> from geoai.tools.multiclean import clean_segmentation_mask
+        >>> mask = np.random.randint(0, 3, (512, 512))
+        >>> cleaned = clean_segmentation_mask(
+        ...     mask,
+        ...     class_values=[0, 1, 2],
+        ...     smooth_edge_size=2,
+        ...     min_island_size=50
+        ... )
+    """
+    check_multiclean_available()
+
+    if mask.ndim != 2:
+        raise ValueError(f"Mask must be 2D, got shape {mask.shape}")
+
+    if connectivity not in [4, 8]:
+        raise ValueError(f"Connectivity must be 4 or 8, got {connectivity}")
+
+    # Apply MultiClean
+    cleaned = clean_array(
+        mask,
+        class_values=class_values,
+        smooth_edge_size=smooth_edge_size,
+        min_island_size=min_island_size,
+        connectivity=connectivity,
+        max_workers=max_workers,
+        fill_nan=fill_nan,
+    )
+
+    return cleaned
+
+
+def clean_raster(
+    input_path: str,
+    output_path: str,
+    class_values: Optional[Union[int, List[int]]] = None,
+    smooth_edge_size: int = 2,
+    min_island_size: int = 100,
+    connectivity: int = 8,
+    max_workers: Optional[int] = None,
+    fill_nan: bool = False,
+    band: int = 1,
+    nodata: Optional[float] = None,
+) -> None:
+    """
+    Clean a classification raster (GeoTIFF) and save the result.
+
+    Reads a GeoTIFF file, applies MultiClean morphological operations,
+    and saves the cleaned result while preserving geospatial metadata
+    (CRS, transform, nodata value).
+
+    Args:
+        input_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save cleaned GeoTIFF file.
+        class_values (int, list of int, or None): Target class values to process.
+            If None, auto-detects unique values. Defaults to None.
+        smooth_edge_size (int): Kernel width in pixels for edge smoothing.
+            Defaults to 2.
+        min_island_size (int): Minimum area (in pixels) for components.
+            Defaults to 100.
+        connectivity (int): Connectivity for component detection (4 or 8).
+            Defaults to 8.
+        max_workers (int, optional): Thread pool size. Defaults to None.
+        fill_nan (bool): Whether to fill NaN/nodata pixels. Defaults to False.
+        band (int): Band index to read (1-indexed). Defaults to 1.
+        nodata (float, optional): Nodata value to use. If None, uses value
+            from input file. Defaults to None.
+
+    Returns:
+        None: Writes cleaned raster to output_path.
+
+    Raises:
+        ImportError: If multiclean or rasterio is not installed.
+        FileNotFoundError: If input_path does not exist.
+
+    Example:
+        >>> from geoai.tools.multiclean import clean_raster
+        >>> clean_raster(
+        ...     "segmentation_raw.tif",
+        ...     "segmentation_cleaned.tif",
+        ...     class_values=[0, 1, 2],
+        ...     smooth_edge_size=3,
+        ...     min_island_size=50
+        ... )
+    """
+    check_multiclean_available()
+
+    if not RASTERIO_AVAILABLE:
+        raise ImportError(
+            "rasterio is required for raster operations. "
+            "Please install it with: pip install rasterio"
+        )
+
+    if not os.path.exists(input_path):
+        raise FileNotFoundError(f"Input file not found: {input_path}")
+
+    # Read input raster
+    with rasterio.open(input_path) as src:
+        # Read the specified band
+        mask = src.read(band)
+
+        # Get metadata
+        profile = src.profile.copy()
+
+        # Handle nodata
+        if nodata is None:
+            nodata = src.nodata
+
+    # Convert nodata to NaN if specified
+    if nodata is not None:
+        mask = mask.astype(np.float32)
+        mask[mask == nodata] = np.nan
+
+    # Clean the mask
+    cleaned = clean_segmentation_mask(
+        mask,
+        class_values=class_values,
+        smooth_edge_size=smooth_edge_size,
+        min_island_size=min_island_size,
+        connectivity=connectivity,
+        max_workers=max_workers,
+        fill_nan=fill_nan,
+    )
+
+    # Convert NaN back to nodata if needed
+    if nodata is not None:
+        # Convert any remaining NaN values back to nodata value
+        if np.isnan(cleaned).any():
+            cleaned = np.nan_to_num(cleaned, nan=nodata)
+
+    # Update profile for output
+    profile.update(
+        dtype=cleaned.dtype,
+        count=1,
+        compress="lzw",
+        nodata=nodata,
+    )
+
+    # Write cleaned raster
+    output_dir = os.path.dirname(os.path.abspath(output_path))
+    if output_dir and output_dir != os.path.abspath(os.sep):
+        os.makedirs(output_dir, exist_ok=True)
+    with rasterio.open(output_path, "w", **profile) as dst:
+        dst.write(cleaned, 1)
+
+
+def clean_raster_batch(
+    input_paths: List[str],
+    output_dir: str,
+    class_values: Optional[Union[int, List[int]]] = None,
+    smooth_edge_size: int = 2,
+    min_island_size: int = 100,
+    connectivity: int = 8,
+    max_workers: Optional[int] = None,
+    fill_nan: bool = False,
+    band: int = 1,
+    suffix: str = "_cleaned",
+    verbose: bool = True,
+) -> List[str]:
+    """
+    Clean multiple classification rasters in batch.
+
+    Processes multiple GeoTIFF files with the same cleaning parameters
+    and saves results to an output directory.
+
+    Args:
+        input_paths (list of str): List of paths to input GeoTIFF files.
+        output_dir (str): Directory to save cleaned files.
+        class_values (int, list of int, or None): Target class values.
+            Defaults to None (auto-detect).
+        smooth_edge_size (int): Kernel width for edge smoothing. Defaults to 2.
+        min_island_size (int): Minimum component area. Defaults to 100.
+        connectivity (int): Connectivity (4 or 8). Defaults to 8.
+        max_workers (int, optional): Thread pool size. Defaults to None.
+        fill_nan (bool): Whether to fill NaN pixels. Defaults to False.
+        band (int): Band index to read (1-indexed). Defaults to 1.
+        suffix (str): Suffix to add to output filenames. Defaults to "_cleaned".
+        verbose (bool): Whether to print progress. Defaults to True.
+
+    Returns:
+        list of str: Paths to cleaned output files.
+
+    Raises:
+        ImportError: If multiclean or rasterio is not installed.
+
+    Example:
+        >>> from geoai.tools.multiclean import clean_raster_batch
+        >>> input_files = ["mask1.tif", "mask2.tif", "mask3.tif"]
+        >>> outputs = clean_raster_batch(
+        ...     input_files,
+        ...     output_dir="cleaned_masks",
+        ...     min_island_size=50
+        ... )
+    """
+    check_multiclean_available()
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    output_paths = []
+
+    for i, input_path in enumerate(input_paths):
+        if verbose:
+            print(f"Processing {i+1}/{len(input_paths)}: {input_path}")
+
+        # Generate output filename
+        basename = os.path.basename(input_path)
+        name, ext = os.path.splitext(basename)
+        output_filename = f"{name}{suffix}{ext}"
+        output_path = os.path.join(output_dir, output_filename)
+
+        try:
+            # Clean the raster
+            clean_raster(
+                input_path,
+                output_path,
+                class_values=class_values,
+                smooth_edge_size=smooth_edge_size,
+                min_island_size=min_island_size,
+                connectivity=connectivity,
+                max_workers=max_workers,
+                fill_nan=fill_nan,
+                band=band,
+            )
+
+            output_paths.append(output_path)
+
+            if verbose:
+                print(f"  ✓ Saved to: {output_path}")
+
+        except Exception as e:
+            if verbose:
+                print(f"  ✗ Failed: {e}")
+            continue
+
+    return output_paths
+
+
+def compare_masks(
+    original: np.ndarray,
+    cleaned: np.ndarray,
+) -> Tuple[int, int, float]:
+    """
+    Compare original and cleaned masks to quantify changes.
+
+    Args:
+        original (np.ndarray): Original segmentation mask.
+        cleaned (np.ndarray): Cleaned segmentation mask.
+
+    Returns:
+        tuple: (pixels_changed, total_pixels, change_percentage)
+            - pixels_changed: Number of pixels that changed value
+            - total_pixels: Total number of valid pixels
+            - change_percentage: Percentage of pixels changed
+
+    Example:
+        >>> import numpy as np
+        >>> from geoai.tools.multiclean import compare_masks
+        >>> original = np.random.randint(0, 3, (512, 512))
+        >>> cleaned = original.copy()
+        >>> changed, total, pct = compare_masks(original, cleaned)
+        >>> print(f"Changed: {pct:.2f}%")
+    """
+    # Handle NaN values
+    valid_mask = ~(np.isnan(original) | np.isnan(cleaned))
+
+    # Count changed pixels
+    pixels_changed = np.sum((original != cleaned) & valid_mask)
+    total_pixels = np.sum(valid_mask)
+
+    # Calculate percentage
+    change_percentage = (pixels_changed / total_pixels * 100) if total_pixels > 0 else 0
+
+    return pixels_changed, total_pixels, change_percentage
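
A short usage sketch for the raster path, comparing before and after (paths are illustrative; requires `multiclean` and `rasterio`). The masks are cast to float because `compare_masks` uses `np.isnan` internally:

```python
import rasterio
from geoai.tools.multiclean import clean_raster, compare_masks

clean_raster(
    "segmentation_raw.tif",       # illustrative input
    "segmentation_cleaned.tif",
    smooth_edge_size=2,
    min_island_size=100,
)

with rasterio.open("segmentation_raw.tif") as a, rasterio.open("segmentation_cleaned.tif") as b:
    changed, total, pct = compare_masks(
        a.read(1).astype(float), b.read(1).astype(float)
    )
print(f"{changed} of {total} pixels changed ({pct:.2f}%)")
```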
geoai/train.py
CHANGED
@@ -1436,8 +1436,12 @@ def instance_segmentation_inference_on_geotiff(
     # Apply Non-Maximum Suppression to handle overlapping detections
     if len(all_detections) > 0:
         # Convert to tensors for NMS
-        boxes = torch.tensor(
-
+        boxes = torch.tensor(
+            [det["box"] for det in all_detections], dtype=torch.float32
+        )
+        scores = torch.tensor(
+            [det["score"] for det in all_detections], dtype=torch.float32
+        )
 
         # Apply NMS with IoU threshold
         nms_threshold = 0.3  # IoU threshold for NMS
@@ -1917,6 +1921,96 @@ class SemanticRandomHorizontalFlip:
         return image, mask
 
 
+class SemanticRandomVerticalFlip:
+    """Random vertical flip transform for semantic segmentation."""
+
+    def __init__(self, prob: float = 0.5) -> None:
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Flip image and mask along height dimension
+            image = torch.flip(image, dims=[1])
+            mask = torch.flip(mask, dims=[0])
+        return image, mask
+
+
+class SemanticRandomRotation90:
+    """Random 90-degree rotation transform for semantic segmentation."""
+
+    def __init__(self, prob: float = 0.5) -> None:
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Randomly rotate by 90, 180, or 270 degrees
+            k = random.randint(1, 3)
+            image = torch.rot90(image, k, dims=[1, 2])
+            mask = torch.rot90(mask, k, dims=[0, 1])
+        return image, mask
+
+
+class SemanticBrightnessAdjustment:
+    """Random brightness adjustment transform for semantic segmentation."""
+
+    def __init__(
+        self, brightness_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
+    ) -> None:
+        """
+        Initialize brightness adjustment transform.
+
+        Args:
+            brightness_range: Tuple of (min, max) brightness factors.
+            prob: Probability of applying the transform.
+        """
+        self.brightness_range = brightness_range
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Apply random brightness adjustment
+            factor = self.brightness_range[0] + random.random() * (
+                self.brightness_range[1] - self.brightness_range[0]
+            )
+            image = torch.clamp(image * factor, 0, 1)
+        return image, mask
+
+
+class SemanticContrastAdjustment:
+    """Random contrast adjustment transform for semantic segmentation."""
+
+    def __init__(
+        self, contrast_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
+    ) -> None:
+        """
+        Initialize contrast adjustment transform.
+
+        Args:
+            contrast_range: Tuple of (min, max) contrast factors.
+            prob: Probability of applying the transform.
+        """
+        self.contrast_range = contrast_range
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Apply random contrast adjustment
+            factor = self.contrast_range[0] + random.random() * (
+                self.contrast_range[1] - self.contrast_range[0]
+            )
+            mean = image.mean(dim=(1, 2), keepdim=True)
+            image = torch.clamp((image - mean) * factor + mean, 0, 1)
+        return image, mask
+
+
 def get_semantic_transform(train: bool) -> Any:
     """
     Get transforms for semantic segmentation data augmentation.
@@ -2388,6 +2482,8 @@ def train_segmentation_model(
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
     early_stopping_patience: Optional[int] = None,
+    train_transforms: Optional[Callable] = None,
+    val_transforms: Optional[Callable] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2440,8 +2536,17 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
-
-
+            Both image and mask should be torch.Tensor objects. The image tensor is expected to be in
+            CHW format (channels, height, width), and the mask tensor in HW format (height, width).
+            If None, uses default transforms (horizontal flip with 0.5 probability). Defaults to None.
+        val_transforms (callable, optional): Custom transforms for validation data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask).
+            The image tensor is expected to be in CHW format (channels, height, width), and the mask tensor in HW format (height, width).
+            Both image and mask should be torch.Tensor objects. If None, uses default transforms
+            (horizontal flip with 0.5 probability). Defaults to None.
+        val_transforms (callable, optional): Custom transforms for validation data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask).
+            If None, uses default transforms (no augmentation). Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
     Returns:
         None: Model weights are saved to output_dir.
@@ -2584,10 +2689,22 @@ def train_segmentation_model(
         print("No resizing needed.")
 
     # Create datasets
+    # Use custom transforms if provided, otherwise use default transforms
+    train_transform = (
+        train_transforms
+        if train_transforms is not None
+        else get_semantic_transform(train=True)
+    )
+    val_transform = (
+        val_transforms
+        if val_transforms is not None
+        else get_semantic_transform(train=False)
+    )
+
     train_dataset = SemanticSegmentationDataset(
         train_imgs,
         train_labels,
-        transforms=
+        transforms=train_transform,
         num_channels=num_channels,
        target_size=target_size,
         resize_mode=resize_mode,
@@ -2596,7 +2713,7 @@ def train_segmentation_model(
     val_dataset = SemanticSegmentationDataset(
         val_imgs,
         val_labels,
-        transforms=
+        transforms=val_transform,
         num_channels=num_channels,
         target_size=target_size,
         resize_mode=resize_mode,
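
The new `train_transforms` / `val_transforms` hooks accept any callable with the `(image, mask) -> (image, mask)` contract, so the transform classes added above can be chained. A hedged sketch of a composed pipeline (the `ComposeSemantic` helper is not part of geoai, and the remaining `train_segmentation_model` arguments are not shown here):

```python
import torch
from geoai.train import (
    SemanticRandomHorizontalFlip,
    SemanticRandomVerticalFlip,
    SemanticRandomRotation90,
    SemanticBrightnessAdjustment,
)

class ComposeSemantic:
    """Chain (image, mask) transforms; illustrative helper, not part of geoai."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask):
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask

train_tfms = ComposeSemantic([
    SemanticRandomHorizontalFlip(prob=0.5),
    SemanticRandomVerticalFlip(prob=0.5),
    SemanticRandomRotation90(prob=0.5),
    SemanticBrightnessAdjustment((0.8, 1.2), prob=0.5),
])

# Quick check of the (CHW image, HW mask) contract
img, msk = train_tfms(torch.rand(3, 256, 256), torch.zeros(256, 256, dtype=torch.long))

# Pass train_transforms=train_tfms (and optionally val_transforms=...) to
# geoai.train.train_segmentation_model(...) alongside your usual arguments.
```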
{geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: geoai-py
-Version: 0.17.0
+Version: 0.18.0
 Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
 Author-email: Qiusheng Wu <giswqs@gmail.com>
 License: MIT License
@@ -39,15 +39,18 @@ Requires-Dist: rioxarray
 Requires-Dist: scikit-image
 Requires-Dist: scikit-learn
 Requires-Dist: timm
+Requires-Dist: tokenizers>=0.22.1
 Requires-Dist: torch
 Requires-Dist: torchgeo
 Requires-Dist: torchinfo
 Requires-Dist: tqdm
-Requires-Dist: transformers
+Requires-Dist: transformers>=4.57.1
 Provides-Extra: extra
 Requires-Dist: overturemaps; extra == "extra"
 Requires-Dist: torchange; extra == "extra"
 Requires-Dist: lightly-train; extra == "extra"
+Requires-Dist: multiclean; extra == "extra"
+Requires-Dist: omnicloudmask; extra == "extra"
 Provides-Extra: agents
 Requires-Dist: strands-agents; extra == "agents"
 Requires-Dist: strands-agents-tools; extra == "agents"
{geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-geoai/__init__.py,sha256=
-geoai/change_detection.py,sha256=
+geoai/__init__.py,sha256=9wW42cjIMMY_ec5s-DOO04XgQHXVre8CisL72wvJCIU,4620
+geoai/change_detection.py,sha256=pdQofnPRiwoES8vMln2vHghRnpeTdsmqLir74dnqZYU,60389
 geoai/classify.py,sha256=0DcComVR6vKU4qWtH2oHVeXc7ZTcV0mFvdXRtlNmolo,35637
 geoai/detectron2.py,sha256=dOOFM9M9-6PV8q2A4-mnIPrz7yTo-MpEvDiAW34nl0w,14610
 geoai/dinov3.py,sha256=u4Lulihhvs4wTgi84RjRw8jWQpB8omQSl-dVVryNVus,40377
@@ -11,20 +11,23 @@ geoai/map_widgets.py,sha256=QLmkILsztNaRXRULHKOd7Glb7S0pEWXSK9-P8S5AuzQ,5856
 geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
 geoai/segment.py,sha256=yBGTxA-ti8lBpk7WVaBOp6yP23HkaulKJQk88acrmZ0,43788
 geoai/segmentation.py,sha256=7yEzBSKCyHW1dNssoK0rdvhxi2IXsIQIFSga817KdI4,11535
-geoai/timm_segment.py,sha256=
+geoai/timm_segment.py,sha256=GfvWmxT6t1S99-iZOf8PlsCkwodIUyrt0AwO_j6dCjE,38470
 geoai/timm_train.py,sha256=y_Sm9Fwe7bTsHEKdtPee5rGY7s01CbkAZKP1TwUDXlU,20551
-geoai/train.py,sha256=
+geoai/train.py,sha256=Ef-lCCQvaMWl3wvhi-IYYi9sdR4YBqMt9QkfiRAUlkQ,174762
 geoai/utils.py,sha256=AUdVj1tt864UFxJtsatpUmXRV9-Lw4f4tbdyjqj0c3c,360240
 geoai/agents/__init__.py,sha256=5xtb_dGpI26nPFcAm8Dj7O4bLskqr1xTw2BRQqbgH4w,285
 geoai/agents/catalog_models.py,sha256=19E-PiE7FvpGEiOi4gDMKPf257FOhLseuVGWJbOjrDs,2089
 geoai/agents/catalog_tools.py,sha256=psVw7-di65hhnJUFqWXFoOkbGaG2_sHrQhA5vdXp3x4,33597
-geoai/agents/geo_agents.py,sha256=
+geoai/agents/geo_agents.py,sha256=2O25JTNN2qr9YXhjHJfXFMByA7B3yyhGGzWyzroIDYg,59302
 geoai/agents/map_tools.py,sha256=OK5uB0VUHjjUnc-DYRy2CQ__kyUIARSCPBucGabO0Xw,60669
 geoai/agents/stac_models.py,sha256=N2kv7HHdAKT8lgYheEd98QK4l0UcbgpNTLOWW24ayBs,2573
 geoai/agents/stac_tools.py,sha256=ILUg2xFRXVZ9WHOfPeJBvPSFT7lRsPLnGMZhnpDZ1co,16107
-
-
-
-geoai_py-0.
-geoai_py-0.
-geoai_py-0.
+geoai/tools/__init__.py,sha256=McC49tQjxrTha1TS69IeM3rRvqVQP3H1NdAZPZPpKEI,1683
+geoai/tools/cloudmask.py,sha256=qzvqVa8FAEgd8mePXBaV5Ptx4fHhwfS1BsYL0JAZBjM,14500
+geoai/tools/multiclean.py,sha256=TVwmWgeQyGIyUuCe10b6pGCtgIl8TkZmcgVXPimn9uM,11949
+geoai_py-0.18.0.dist-info/licenses/LICENSE,sha256=TlBm8mRusRVB9yF2NTg-STcb71v69-XZaKaPdshqP2I,1074
+geoai_py-0.18.0.dist-info/METADATA,sha256=iqGhGU99OgqXBXfMF5r4GfyrByGQFFzx3PHBOAxxYG0,11255
+geoai_py-0.18.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
+geoai_py-0.18.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
+geoai_py-0.18.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
+geoai_py-0.18.0.dist-info/RECORD,,
{geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/WHEEL
File without changes
{geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/entry_points.txt
File without changes
{geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/licenses/LICENSE
File without changes
{geoai_py-0.17.0.dist-info → geoai_py-0.18.0.dist-info}/top_level.txt
File without changes