geoai-py 0.17.0__tar.gz → 0.18.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. {geoai_py-0.17.0 → geoai_py-0.18.1}/.gitignore +2 -0
  2. {geoai_py-0.17.0 → geoai_py-0.18.1}/PKG-INFO +6 -3
  3. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/__init__.py +16 -1
  4. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/geo_agents.py +11 -3
  5. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/change_detection.py +16 -1
  6. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/timm_segment.py +4 -1
  7. geoai_py-0.18.1/geoai/tools/__init__.py +65 -0
  8. geoai_py-0.18.1/geoai/tools/cloudmask.py +431 -0
  9. geoai_py-0.18.1/geoai/tools/multiclean.py +357 -0
  10. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/train.py +123 -6
  11. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/utils.py +59 -45
  12. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai_py.egg-info/PKG-INFO +6 -3
  13. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai_py.egg-info/SOURCES.txt +3 -0
  14. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai_py.egg-info/requires.txt +5 -2
  15. {geoai_py-0.17.0 → geoai_py-0.18.1}/mkdocs.yml +3 -0
  16. {geoai_py-0.17.0 → geoai_py-0.18.1}/pyproject.toml +3 -3
  17. {geoai_py-0.17.0 → geoai_py-0.18.1}/requirements.txt +3 -2
  18. {geoai_py-0.17.0 → geoai_py-0.18.1}/.dockerignore +0 -0
  19. {geoai_py-0.17.0 → geoai_py-0.18.1}/.editorconfig +0 -0
  20. {geoai_py-0.17.0 → geoai_py-0.18.1}/.pre-commit-config.yaml +0 -0
  21. {geoai_py-0.17.0 → geoai_py-0.18.1}/CITATION.cff +0 -0
  22. {geoai_py-0.17.0 → geoai_py-0.18.1}/Dockerfile +0 -0
  23. {geoai_py-0.17.0 → geoai_py-0.18.1}/LICENSE +0 -0
  24. {geoai_py-0.17.0 → geoai_py-0.18.1}/MANIFEST.in +0 -0
  25. {geoai_py-0.17.0 → geoai_py-0.18.1}/README.md +0 -0
  26. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/__init__.py +0 -0
  27. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/catalog_models.py +0 -0
  28. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/catalog_tools.py +0 -0
  29. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/map_tools.py +0 -0
  30. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/stac_models.py +0 -0
  31. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/agents/stac_tools.py +0 -0
  32. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/classify.py +0 -0
  33. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/detectron2.py +0 -0
  34. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/dinov3.py +0 -0
  35. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/download.py +0 -0
  36. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/extract.py +0 -0
  37. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/geoai.py +0 -0
  38. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/hf.py +0 -0
  39. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/map_widgets.py +0 -0
  40. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/sam.py +0 -0
  41. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/segment.py +0 -0
  42. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/segmentation.py +0 -0
  43. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai/timm_train.py +0 -0
  44. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai_py.egg-info/dependency_links.txt +0 -0
  45. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai_py.egg-info/entry_points.txt +0 -0
  46. {geoai_py-0.17.0 → geoai_py-0.18.1}/geoai_py.egg-info/top_level.txt +0 -0
  47. {geoai_py-0.17.0 → geoai_py-0.18.1}/pytest.ini +0 -0
  48. {geoai_py-0.17.0 → geoai_py-0.18.1}/requirements_docs.txt +0 -0
  49. {geoai_py-0.17.0 → geoai_py-0.18.1}/setup.cfg +0 -0
  50. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/__init__.py +0 -0
  51. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/create_test_data.py +0 -0
  52. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_classify.py +0 -0
  53. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_download.py +0 -0
  54. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_extract.py +0 -0
  55. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_fixtures.py +0 -0
  56. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_geoai.py +0 -0
  57. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_segment.py +0 -0
  58. {geoai_py-0.17.0 → geoai_py-0.18.1}/tests/test_utils.py +0 -0
@@ -18,6 +18,8 @@ docs/examples/timm_buildings/
18
18
  **/*.zip
19
19
  **/*.las
20
20
  *.geojson
21
+ *.gpkg
22
+ *.csv
21
23
  docs/examples/*.md
22
24
  *.xml
23
25
  docs/examples/output/
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: geoai-py
3
- Version: 0.17.0
3
+ Version: 0.18.1
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -24,7 +24,7 @@ Requires-Dist: ever-beta
24
24
  Requires-Dist: geopandas
25
25
  Requires-Dist: huggingface_hub
26
26
  Requires-Dist: jupyter-server-proxy
27
- Requires-Dist: leafmap
27
+ Requires-Dist: leafmap>=0.57.1
28
28
  Requires-Dist: localtileserver
29
29
  Requires-Dist: mapclassify
30
30
  Requires-Dist: maplibre
@@ -39,15 +39,18 @@ Requires-Dist: rioxarray
39
39
  Requires-Dist: scikit-image
40
40
  Requires-Dist: scikit-learn
41
41
  Requires-Dist: timm
42
+ Requires-Dist: tokenizers>=0.22.1
42
43
  Requires-Dist: torch
43
44
  Requires-Dist: torchgeo
44
45
  Requires-Dist: torchinfo
45
46
  Requires-Dist: tqdm
46
- Requires-Dist: transformers
47
+ Requires-Dist: transformers>=4.57.1
47
48
  Provides-Extra: extra
48
49
  Requires-Dist: overturemaps; extra == "extra"
49
50
  Requires-Dist: torchange; extra == "extra"
50
51
  Requires-Dist: lightly-train; extra == "extra"
52
+ Requires-Dist: multiclean; extra == "extra"
53
+ Requires-Dist: omnicloudmask; extra == "extra"
51
54
  Provides-Extra: agents
52
55
  Requires-Dist: strands-agents; extra == "agents"
53
56
  Requires-Dist: strands-agents-tools; extra == "agents"
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.17.0"
5
+ __version__ = "0.18.1"
6
6
 
7
7
 
8
8
  import os
@@ -121,3 +121,18 @@ from .timm_segment import (
121
121
  timm_semantic_segmentation,
122
122
  push_timm_model_to_hub,
123
123
  )
124
+
125
+ # Import tools subpackage
126
+ from . import tools
127
+
128
+ # Expose commonly used tools at package level for convenience
129
+ try:
130
+ from .tools import (
131
+ clean_segmentation_mask,
132
+ clean_raster,
133
+ clean_raster_batch,
134
+ compare_masks,
135
+ )
136
+ except ImportError:
137
+ # MultiClean not available (missing dependency)
138
+ pass
@@ -55,7 +55,9 @@ class UICallbackHandler:
55
55
  # Make tool names more user-friendly
56
56
  friendly_name = tool_name.replace("_", " ").title()
57
57
  self.status_widget.value = (
58
- f"<span style='color:#0a7'>⚙️ {friendly_name}...</span>"
58
+ f"<span style='color:#0a7'>"
59
+ f"<i class='fas fa-spinner fa-spin' style='font-size:1.2em'></i> "
60
+ f"{friendly_name}...</span>"
59
61
  )
60
62
 
61
63
 
@@ -396,7 +398,10 @@ class GeoAgent(Agent):
396
398
  btn_clear = widgets.Button(
397
399
  description="Clear", icon="trash", layout=widgets.Layout(width="120px")
398
400
  )
399
- status = widgets.HTML("<span style='color:#666'>Ready.</span>")
401
+ status = widgets.HTML(
402
+ "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
403
+ "<span style='color:#666'>Ready.</span>"
404
+ )
400
405
 
401
406
  examples = widgets.Dropdown(
402
407
  options=[
@@ -1017,7 +1022,10 @@ CRITICAL: Return ONLY JSON. NO explanatory text, NO made-up data."""
1017
1022
  btn_clear = widgets.Button(
1018
1023
  description="Clear", icon="trash", layout=widgets.Layout(width="120px")
1019
1024
  )
1020
- status = widgets.HTML("<span style='color:#666'>Ready to search.</span>")
1025
+ status = widgets.HTML(
1026
+ "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
1027
+ "<span style='color:#666'>Ready to search.</span>"
1028
+ )
1021
1029
 
1022
1030
  examples = widgets.Dropdown(
1023
1031
  options=[
@@ -13,7 +13,8 @@ from skimage.transform import resize
13
13
  try:
14
14
  from torchange.models.segment_any_change import AnyChange, show_change_masks
15
15
  except ImportError:
16
- print("torchange requires Python 3.11 or higher")
16
+ AnyChange = None
17
+ show_change_masks = None
17
18
 
18
19
  from .utils import download_file
19
20
 
@@ -36,6 +37,13 @@ class ChangeDetection:
36
37
 
37
38
  def _init_model(self):
38
39
  """Initialize the AnyChange model."""
40
+ if AnyChange is None:
41
+ raise ImportError(
42
+ "The 'torchange' package is required for change detection. "
43
+ "Please install it using: pip install torchange\n"
44
+ "Note: torchange requires Python 3.11 or higher."
45
+ )
46
+
39
47
  if self.sam_checkpoint is None:
40
48
  self.sam_checkpoint = download_checkpoint(self.sam_model_type)
41
49
 
@@ -551,6 +559,13 @@ class ChangeDetection:
551
559
  Returns:
552
560
  matplotlib.figure.Figure: The figure object
553
561
  """
562
+ if show_change_masks is None:
563
+ raise ImportError(
564
+ "The 'torchange' package is required for change detection visualization. "
565
+ "Please install it using: pip install torchange\n"
566
+ "Note: torchange requires Python 3.11 or higher."
567
+ )
568
+
554
569
  change_masks, img1, img2 = self.detect_changes(
555
570
  image1_path, image2_path, return_results=True
556
571
  )
@@ -241,7 +241,10 @@ class TimmSegmentationModel(pl.LightningModule):
241
241
  )
242
242
 
243
243
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
244
- optimizer, mode="min", factor=0.5, patience=5, verbose=True
244
+ optimizer,
245
+ mode="min",
246
+ factor=0.5,
247
+ patience=5,
245
248
  )
246
249
 
247
250
  return {
@@ -0,0 +1,65 @@
1
+ """
2
+ GeoAI Tools - Utility functions and integrations for geospatial AI workflows.
3
+
4
+ This subpackage contains various tools and integrations for enhancing
5
+ geospatial AI workflows, including post-processing utilities and
6
+ third-party library integrations.
7
+ """
8
+
9
+ __all__ = []
10
+
11
+ # MultiClean integration (optional dependency)
12
+ try:
13
+ from .multiclean import (
14
+ clean_segmentation_mask,
15
+ clean_raster,
16
+ clean_raster_batch,
17
+ compare_masks,
18
+ check_multiclean_available,
19
+ )
20
+
21
+ __all__.extend(
22
+ [
23
+ "clean_segmentation_mask",
24
+ "clean_raster",
25
+ "clean_raster_batch",
26
+ "compare_masks",
27
+ "check_multiclean_available",
28
+ ]
29
+ )
30
+ except ImportError:
31
+ # MultiClean not installed - functions will not be available
32
+ pass
33
+
34
+ # OmniCloudMask integration (optional dependency)
35
+ try:
36
+ from .cloudmask import (
37
+ predict_cloud_mask,
38
+ predict_cloud_mask_from_raster,
39
+ predict_cloud_mask_batch,
40
+ calculate_cloud_statistics,
41
+ create_cloud_free_mask,
42
+ check_omnicloudmask_available,
43
+ CLEAR,
44
+ THICK_CLOUD,
45
+ THIN_CLOUD,
46
+ CLOUD_SHADOW,
47
+ )
48
+
49
+ __all__.extend(
50
+ [
51
+ "predict_cloud_mask",
52
+ "predict_cloud_mask_from_raster",
53
+ "predict_cloud_mask_batch",
54
+ "calculate_cloud_statistics",
55
+ "create_cloud_free_mask",
56
+ "check_omnicloudmask_available",
57
+ "CLEAR",
58
+ "THICK_CLOUD",
59
+ "THIN_CLOUD",
60
+ "CLOUD_SHADOW",
61
+ ]
62
+ )
63
+ except ImportError:
64
+ # OmniCloudMask not installed - functions will not be available
65
+ pass
@@ -0,0 +1,431 @@
1
+ """
2
+ OmniCloudMask integration for cloud and cloud shadow detection in satellite imagery.
3
+
4
+ This module provides functions to use OmniCloudMask (https://github.com/DPIRD-DMA/OmniCloudMask)
5
+ for detecting clouds and cloud shadows in satellite imagery. OmniCloudMask performs semantic
6
+ segmentation to classify pixels into: Clear (0), Thick Cloud (1), Thin Cloud (2), Cloud Shadow (3).
7
+
8
+ Supports Sentinel-2, Landsat 8, PlanetScope, and Maxar imagery at 10-50m resolution.
9
+ """
10
+
11
+ import os
12
+ from typing import Optional, List, Tuple, Dict, Any
13
+ import numpy as np
14
+
15
+ try:
16
+ from omnicloudmask import predict_from_array
17
+
18
+ OMNICLOUDMASK_AVAILABLE = True
19
+ except ImportError:
20
+ OMNICLOUDMASK_AVAILABLE = False
21
+
22
+ try:
23
+ import rasterio
24
+
25
+ RASTERIO_AVAILABLE = True
26
+ except ImportError:
27
+ RASTERIO_AVAILABLE = False
28
+
29
+
30
+ # Cloud mask class values
31
+ CLEAR = 0
32
+ THICK_CLOUD = 1
33
+ THIN_CLOUD = 2
34
+ CLOUD_SHADOW = 3
35
+
36
+
37
+ def check_omnicloudmask_available():
38
+ """
39
+ Check if omnicloudmask is installed.
40
+
41
+ Raises:
42
+ ImportError: If omnicloudmask is not installed.
43
+ """
44
+ if not OMNICLOUDMASK_AVAILABLE:
45
+ raise ImportError(
46
+ "omnicloudmask is not installed. "
47
+ "Please install it with: pip install omnicloudmask "
48
+ "or: pip install geoai-py[extra]"
49
+ )
50
+
51
+
52
+ def predict_cloud_mask(
53
+ image: np.ndarray,
54
+ batch_size: int = 1,
55
+ inference_device: str = "cpu",
56
+ inference_dtype: str = "fp32",
57
+ patch_size: int = 1000,
58
+ export_confidence: bool = False,
59
+ model_version: int = 3,
60
+ ) -> np.ndarray:
61
+ """
62
+ Predict cloud mask from a numpy array using OmniCloudMask.
63
+
64
+ This function classifies each pixel into one of four categories:
65
+ - 0: Clear
66
+ - 1: Thick Cloud
67
+ - 2: Thin Cloud
68
+ - 3: Cloud Shadow
69
+
70
+ Args:
71
+ image (np.ndarray): Input image array with shape (3, height, width) or (height, width, 3).
72
+ Should contain Red, Green, and NIR bands. Values should be in reflectance (0-1)
73
+ or digital numbers (0-10000 typical for Sentinel-2/Landsat).
74
+ batch_size (int): Number of patches to process per inference batch. Defaults to 1.
75
+ inference_device (str): Device for inference ('cpu', 'cuda', or 'mps'). Defaults to 'cpu'.
76
+ inference_dtype (str): Data type for inference ('fp32', 'fp16', or 'bf16').
77
+ 'bf16' recommended for speed on compatible hardware. Defaults to 'fp32'.
78
+ patch_size (int): Size of patches for processing large images. Defaults to 1000.
79
+ export_confidence (bool): If True, also returns confidence map. Defaults to False.
80
+ model_version (int): Model version to use (1, 2, or 3). Defaults to 3.
81
+
82
+ Returns:
83
+ np.ndarray: Cloud mask array with shape (height, width) containing class predictions.
84
+ If export_confidence=True, returns tuple of (mask, confidence).
85
+
86
+ Raises:
87
+ ImportError: If omnicloudmask is not installed.
88
+ ValueError: If image has wrong shape or number of channels.
89
+
90
+ Example:
91
+ >>> import numpy as np
92
+ >>> from geoai.tools.cloudmask import predict_cloud_mask
93
+ >>> # Create synthetic image (3 bands: R, G, NIR)
94
+ >>> image = np.random.rand(3, 512, 512) * 10000
95
+ >>> mask = predict_cloud_mask(image)
96
+ >>> print(f"Clear pixels: {(mask == 0).sum()}")
97
+ """
98
+ check_omnicloudmask_available()
99
+
100
+ # Ensure image has correct shape (3, H, W)
101
+ if image.ndim != 3:
102
+ raise ValueError(f"Image must be 3D, got shape {image.shape}")
103
+
104
+ # Convert (H, W, 3) to (3, H, W) if needed
105
+ if image.shape[2] == 3 and image.shape[0] != 3:
106
+ image = np.transpose(image, (2, 0, 1))
107
+
108
+ if image.shape[0] != 3:
109
+ raise ValueError(
110
+ f"Image must have 3 channels (R, G, NIR), got {image.shape[0]} channels"
111
+ )
112
+
113
+ # Call OmniCloudMask
114
+ result = predict_from_array(
115
+ image,
116
+ batch_size=batch_size,
117
+ inference_device=inference_device,
118
+ inference_dtype=inference_dtype,
119
+ patch_size=patch_size,
120
+ export_confidence=export_confidence,
121
+ model_version=model_version,
122
+ )
123
+
124
+ # Handle output shape - omnicloudmask returns (1, H, W) or ((1, H, W), (1, H, W))
125
+ if export_confidence:
126
+ mask, confidence = result
127
+ # Squeeze batch dimension
128
+ mask = mask.squeeze(0) if mask.ndim == 3 else mask
129
+ confidence = confidence.squeeze(0) if confidence.ndim == 3 else confidence
130
+ return mask, confidence
131
+ else:
132
+ # Squeeze batch dimension
133
+ return result.squeeze(0) if result.ndim == 3 else result
134
+
135
+
136
+ def predict_cloud_mask_from_raster(
137
+ input_path: str,
138
+ output_path: str,
139
+ red_band: int = 1,
140
+ green_band: int = 2,
141
+ nir_band: int = 3,
142
+ batch_size: int = 1,
143
+ inference_device: str = "cpu",
144
+ inference_dtype: str = "fp32",
145
+ patch_size: int = 1000,
146
+ export_confidence: bool = False,
147
+ model_version: int = 3,
148
+ ) -> None:
149
+ """
150
+ Predict cloud mask from a GeoTIFF file and save the result.
151
+
152
+ Reads a multi-band raster, extracts RGB+NIR bands, applies OmniCloudMask,
153
+ and saves the result while preserving geospatial metadata.
154
+
155
+ Args:
156
+ input_path (str): Path to input GeoTIFF file.
157
+ output_path (str): Path to save cloud mask GeoTIFF.
158
+ red_band (int): Band index for Red (1-indexed). Defaults to 1.
159
+ green_band (int): Band index for Green (1-indexed). Defaults to 2.
160
+ nir_band (int): Band index for NIR (1-indexed). Defaults to 3.
161
+ batch_size (int): Patches per inference batch. Defaults to 1.
162
+ inference_device (str): Device ('cpu', 'cuda', 'mps'). Defaults to 'cpu'.
163
+ inference_dtype (str): Dtype ('fp32', 'fp16', 'bf16'). Defaults to 'fp32'.
164
+ patch_size (int): Patch size for large images. Defaults to 1000.
165
+ export_confidence (bool): Export confidence map. Defaults to False.
166
+ model_version (int): Model version (1, 2, or 3). Defaults to 3.
167
+
168
+ Returns:
169
+ None: Writes cloud mask to output_path.
170
+
171
+ Raises:
172
+ ImportError: If omnicloudmask or rasterio not installed.
173
+ FileNotFoundError: If input_path doesn't exist.
174
+
175
+ Example:
176
+ >>> from geoai.tools.cloudmask import predict_cloud_mask_from_raster
177
+ >>> predict_cloud_mask_from_raster(
178
+ ... "sentinel2_image.tif",
179
+ ... "cloud_mask.tif",
180
+ ... red_band=4, # Sentinel-2 band order
181
+ ... green_band=3,
182
+ ... nir_band=8
183
+ ... )
184
+ """
185
+ check_omnicloudmask_available()
186
+
187
+ if not RASTERIO_AVAILABLE:
188
+ raise ImportError(
189
+ "rasterio is required for raster operations. "
190
+ "Please install it with: pip install rasterio"
191
+ )
192
+
193
+ if not os.path.exists(input_path):
194
+ raise FileNotFoundError(f"Input file not found: {input_path}")
195
+
196
+ # Read input raster
197
+ with rasterio.open(input_path) as src:
198
+ # Read required bands
199
+ red = src.read(red_band).astype(np.float32)
200
+ green = src.read(green_band).astype(np.float32)
201
+ nir = src.read(nir_band).astype(np.float32)
202
+
203
+ # Stack into (3, H, W)
204
+ image = np.stack([red, green, nir], axis=0)
205
+
206
+ # Get metadata
207
+ profile = src.profile.copy()
208
+
209
+ # Predict cloud mask
210
+ result = predict_cloud_mask(
211
+ image,
212
+ batch_size=batch_size,
213
+ inference_device=inference_device,
214
+ inference_dtype=inference_dtype,
215
+ patch_size=patch_size,
216
+ export_confidence=export_confidence,
217
+ model_version=model_version,
218
+ )
219
+
220
+ # Handle confidence output
221
+ if export_confidence:
222
+ mask, confidence = result
223
+ else:
224
+ mask = result
225
+
226
+ # Update profile for output
227
+ profile.update(
228
+ dtype=np.uint8,
229
+ count=1,
230
+ compress="lzw",
231
+ nodata=None,
232
+ )
233
+
234
+ # Write cloud mask
235
+ output_dir = os.path.dirname(os.path.abspath(output_path))
236
+ if output_dir and output_dir != os.path.abspath(os.sep):
237
+ os.makedirs(output_dir, exist_ok=True)
238
+
239
+ with rasterio.open(output_path, "w", **profile) as dst:
240
+ dst.write(mask.astype(np.uint8), 1)
241
+
242
+ # Optionally write confidence map
243
+ if export_confidence:
244
+ confidence_path = output_path.replace(".tif", "_confidence.tif")
245
+ profile.update(dtype=np.float32)
246
+ with rasterio.open(confidence_path, "w", **profile) as dst:
247
+ dst.write(confidence, 1)
248
+
249
+
250
+ def predict_cloud_mask_batch(
251
+ input_paths: List[str],
252
+ output_dir: str,
253
+ red_band: int = 1,
254
+ green_band: int = 2,
255
+ nir_band: int = 3,
256
+ batch_size: int = 1,
257
+ inference_device: str = "cpu",
258
+ inference_dtype: str = "fp32",
259
+ patch_size: int = 1000,
260
+ export_confidence: bool = False,
261
+ model_version: int = 3,
262
+ suffix: str = "_cloudmask",
263
+ verbose: bool = True,
264
+ ) -> List[str]:
265
+ """
266
+ Predict cloud masks for multiple rasters in batch.
267
+
268
+ Processes multiple GeoTIFF files with the same cloud detection parameters
269
+ and saves results to an output directory.
270
+
271
+ Args:
272
+ input_paths (list of str): Paths to input GeoTIFF files.
273
+ output_dir (str): Directory to save cloud masks.
274
+ red_band (int): Red band index. Defaults to 1.
275
+ green_band (int): Green band index. Defaults to 2.
276
+ nir_band (int): NIR band index. Defaults to 3.
277
+ batch_size (int): Patches per batch. Defaults to 1.
278
+ inference_device (str): Device. Defaults to 'cpu'.
279
+ inference_dtype (str): Dtype. Defaults to 'fp32'.
280
+ patch_size (int): Patch size. Defaults to 1000.
281
+ export_confidence (bool): Export confidence. Defaults to False.
282
+ model_version (int): Model version (1, 2, or 3). Defaults to 3.
283
+ suffix (str): Suffix for output filenames. Defaults to '_cloudmask'.
284
+ verbose (bool): Print progress. Defaults to True.
285
+
286
+ Returns:
287
+ list of str: Paths to output cloud mask files.
288
+
289
+ Raises:
290
+ ImportError: If omnicloudmask or rasterio not installed.
291
+
292
+ Example:
293
+ >>> from geoai.tools.cloudmask import predict_cloud_mask_batch
294
+ >>> files = ["scene1.tif", "scene2.tif", "scene3.tif"]
295
+ >>> outputs = predict_cloud_mask_batch(
296
+ ... files,
297
+ ... output_dir="cloud_masks",
298
+ ... inference_device="cuda"
299
+ ... )
300
+ """
301
+ check_omnicloudmask_available()
302
+
303
+ # Create output directory
304
+ os.makedirs(output_dir, exist_ok=True)
305
+
306
+ output_paths = []
307
+
308
+ for i, input_path in enumerate(input_paths):
309
+ if verbose:
310
+ print(f"Processing {i+1}/{len(input_paths)}: {input_path}")
311
+
312
+ # Generate output filename
313
+ basename = os.path.basename(input_path)
314
+ name, ext = os.path.splitext(basename)
315
+ output_filename = f"{name}{suffix}{ext}"
316
+ output_path = os.path.join(output_dir, output_filename)
317
+
318
+ try:
319
+ # Predict cloud mask
320
+ predict_cloud_mask_from_raster(
321
+ input_path,
322
+ output_path,
323
+ red_band=red_band,
324
+ green_band=green_band,
325
+ nir_band=nir_band,
326
+ batch_size=batch_size,
327
+ inference_device=inference_device,
328
+ inference_dtype=inference_dtype,
329
+ patch_size=patch_size,
330
+ export_confidence=export_confidence,
331
+ model_version=model_version,
332
+ )
333
+
334
+ output_paths.append(output_path)
335
+
336
+ if verbose:
337
+ print(f" ✓ Saved to: {output_path}")
338
+
339
+ except Exception as e:
340
+ if verbose:
341
+ print(f" ✗ Failed: {e}")
342
+ continue
343
+
344
+ return output_paths
345
+
346
+
347
+ def calculate_cloud_statistics(
348
+ mask: np.ndarray,
349
+ ) -> Dict[str, Any]:
350
+ """
351
+ Calculate statistics from a cloud mask.
352
+
353
+ Args:
354
+ mask (np.ndarray): Cloud mask array with values 0-3.
355
+
356
+ Returns:
357
+ dict: Statistics including:
358
+ - total_pixels: Total number of pixels
359
+ - clear_pixels: Number of clear pixels
360
+ - thick_cloud_pixels: Number of thick cloud pixels
361
+ - thin_cloud_pixels: Number of thin cloud pixels
362
+ - shadow_pixels: Number of cloud shadow pixels
363
+ - clear_percent: Percentage of clear pixels
364
+ - cloud_percent: Percentage of cloudy pixels (thick + thin)
365
+ - shadow_percent: Percentage of shadow pixels
366
+
367
+ Example:
368
+ >>> from geoai.tools.cloudmask import calculate_cloud_statistics
369
+ >>> import numpy as np
370
+ >>> mask = np.random.randint(0, 4, (512, 512))
371
+ >>> stats = calculate_cloud_statistics(mask)
372
+ >>> print(f"Clear: {stats['clear_percent']:.1f}%")
373
+ """
374
+ total_pixels = mask.size
375
+
376
+ clear_pixels = (mask == CLEAR).sum()
377
+ thick_cloud_pixels = (mask == THICK_CLOUD).sum()
378
+ thin_cloud_pixels = (mask == THIN_CLOUD).sum()
379
+ shadow_pixels = (mask == CLOUD_SHADOW).sum()
380
+
381
+ cloud_pixels = thick_cloud_pixels + thin_cloud_pixels
382
+
383
+ return {
384
+ "total_pixels": int(total_pixels),
385
+ "clear_pixels": int(clear_pixels),
386
+ "thick_cloud_pixels": int(thick_cloud_pixels),
387
+ "thin_cloud_pixels": int(thin_cloud_pixels),
388
+ "shadow_pixels": int(shadow_pixels),
389
+ "clear_percent": float(clear_pixels / total_pixels * 100),
390
+ "cloud_percent": float(cloud_pixels / total_pixels * 100),
391
+ "shadow_percent": float(shadow_pixels / total_pixels * 100),
392
+ }
393
+
394
+
395
+ def create_cloud_free_mask(
396
+ mask: np.ndarray,
397
+ include_thin_clouds: bool = False,
398
+ include_shadows: bool = False,
399
+ ) -> np.ndarray:
400
+ """
401
+ Create a binary mask of cloud-free pixels.
402
+
403
+ Args:
404
+ mask (np.ndarray): Cloud mask with values 0-3.
405
+ include_thin_clouds (bool): If True, treats thin clouds as acceptable.
406
+ Defaults to False.
407
+ include_shadows (bool): If True, treats shadows as acceptable.
408
+ Defaults to False.
409
+
410
+ Returns:
411
+ np.ndarray: Binary mask where 1 = usable, 0 = not usable.
412
+
413
+ Example:
414
+ >>> from geoai.tools.cloudmask import create_cloud_free_mask
415
+ >>> import numpy as np
416
+ >>> mask = np.random.randint(0, 4, (512, 512))
417
+ >>> cloud_free = create_cloud_free_mask(mask)
418
+ >>> print(f"Usable pixels: {cloud_free.sum()}")
419
+ """
420
+ # Start with clear pixels
421
+ usable = mask == CLEAR
422
+
423
+ # Optionally include thin clouds
424
+ if include_thin_clouds:
425
+ usable = usable | (mask == THIN_CLOUD)
426
+
427
+ # Optionally include shadows
428
+ if include_shadows:
429
+ usable = usable | (mask == CLOUD_SHADOW)
430
+
431
+ return usable.astype(np.uint8)