geoai-py 0.17.0__py2.py3-none-any.whl → 0.18.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/tools/multiclean.py ADDED
@@ -0,0 +1,357 @@
+"""
+MultiClean integration utilities for cleaning segmentation results.
+
+This module provides functions to use MultiClean (https://github.com/DPIRD-DMA/MultiClean)
+for post-processing segmentation masks and classification rasters. MultiClean performs
+morphological operations to smooth edges, remove noise islands, and fill gaps.
+"""
+
+import os
+from typing import Optional, List, Union, Tuple
+import numpy as np
+
+try:
+    from multiclean import clean_array
+
+    MULTICLEAN_AVAILABLE = True
+except ImportError:
+    MULTICLEAN_AVAILABLE = False
+
+try:
+    import rasterio
+
+    RASTERIO_AVAILABLE = True
+except ImportError:
+    RASTERIO_AVAILABLE = False
+
+
+def check_multiclean_available():
+    """
+    Check if multiclean is installed.
+
+    Raises:
+        ImportError: If multiclean is not installed.
+    """
+    if not MULTICLEAN_AVAILABLE:
+        raise ImportError(
+            "multiclean is not installed. "
+            "Please install it with: pip install multiclean "
+            "or: pip install geoai-py[extra]"
+        )
+
+
+def clean_segmentation_mask(
+    mask: np.ndarray,
+    class_values: Optional[Union[int, List[int]]] = None,
+    smooth_edge_size: int = 2,
+    min_island_size: int = 100,
+    connectivity: int = 8,
+    max_workers: Optional[int] = None,
+    fill_nan: bool = False,
+) -> np.ndarray:
+    """
+    Clean a segmentation mask using MultiClean morphological operations.
+
+    This function applies three cleaning operations:
+    1. Edge smoothing - Uses morphological opening to reduce jagged boundaries
+    2. Island removal - Eliminates small connected components (noise)
+    3. Gap filling - Replaces invalid pixels with nearest valid class
+
+    Args:
+        mask (np.ndarray): 2D numpy array containing segmentation classes.
+            Can be int or float. NaN values are treated as nodata.
+        class_values (int, list of int, or None): Target class values to process.
+            If None, auto-detects unique values from the mask. Defaults to None.
+        smooth_edge_size (int): Kernel width in pixels for edge smoothing.
+            Set to 0 to disable smoothing. Defaults to 2.
+        min_island_size (int): Minimum area (in pixels) for connected components.
+            Components with area strictly less than this are removed. Defaults to 100.
+        connectivity (int): Connectivity for component detection. Use 4 or 8.
+            8-connectivity considers diagonal neighbors. Defaults to 8.
+        max_workers (int, optional): Thread pool size for parallel processing.
+            If None, uses default threading. Defaults to None.
+        fill_nan (bool): Whether to fill NaN pixels with nearest valid class.
+            Defaults to False.
+
+    Returns:
+        np.ndarray: Cleaned 2D segmentation mask with same shape as input.
+
+    Raises:
+        ImportError: If multiclean is not installed.
+        ValueError: If mask is not 2D or if connectivity is not 4 or 8.
+
+    Example:
+        >>> import numpy as np
+        >>> from geoai.tools.multiclean import clean_segmentation_mask
+        >>> mask = np.random.randint(0, 3, (512, 512))
+        >>> cleaned = clean_segmentation_mask(
+        ...     mask,
+        ...     class_values=[0, 1, 2],
+        ...     smooth_edge_size=2,
+        ...     min_island_size=50
+        ... )
+    """
+    check_multiclean_available()
+
+    if mask.ndim != 2:
+        raise ValueError(f"Mask must be 2D, got shape {mask.shape}")
+
+    if connectivity not in [4, 8]:
+        raise ValueError(f"Connectivity must be 4 or 8, got {connectivity}")
+
+    # Apply MultiClean
+    cleaned = clean_array(
+        mask,
+        class_values=class_values,
+        smooth_edge_size=smooth_edge_size,
+        min_island_size=min_island_size,
+        connectivity=connectivity,
+        max_workers=max_workers,
+        fill_nan=fill_nan,
+    )
+
+    return cleaned
+
+
+def clean_raster(
+    input_path: str,
+    output_path: str,
+    class_values: Optional[Union[int, List[int]]] = None,
+    smooth_edge_size: int = 2,
+    min_island_size: int = 100,
+    connectivity: int = 8,
+    max_workers: Optional[int] = None,
+    fill_nan: bool = False,
+    band: int = 1,
+    nodata: Optional[float] = None,
+) -> None:
+    """
+    Clean a classification raster (GeoTIFF) and save the result.
+
+    Reads a GeoTIFF file, applies MultiClean morphological operations,
+    and saves the cleaned result while preserving geospatial metadata
+    (CRS, transform, nodata value).
+
+    Args:
+        input_path (str): Path to input GeoTIFF file.
+        output_path (str): Path to save cleaned GeoTIFF file.
+        class_values (int, list of int, or None): Target class values to process.
+            If None, auto-detects unique values. Defaults to None.
+        smooth_edge_size (int): Kernel width in pixels for edge smoothing.
+            Defaults to 2.
+        min_island_size (int): Minimum area (in pixels) for components.
+            Defaults to 100.
+        connectivity (int): Connectivity for component detection (4 or 8).
+            Defaults to 8.
+        max_workers (int, optional): Thread pool size. Defaults to None.
+        fill_nan (bool): Whether to fill NaN/nodata pixels. Defaults to False.
+        band (int): Band index to read (1-indexed). Defaults to 1.
+        nodata (float, optional): Nodata value to use. If None, uses value
+            from input file. Defaults to None.
+
+    Returns:
+        None: Writes cleaned raster to output_path.
+
+    Raises:
+        ImportError: If multiclean or rasterio is not installed.
+        FileNotFoundError: If input_path does not exist.
+
+    Example:
+        >>> from geoai.tools.multiclean import clean_raster
+        >>> clean_raster(
+        ...     "segmentation_raw.tif",
+        ...     "segmentation_cleaned.tif",
+        ...     class_values=[0, 1, 2],
+        ...     smooth_edge_size=3,
+        ...     min_island_size=50
+        ... )
+    """
+    check_multiclean_available()
+
+    if not RASTERIO_AVAILABLE:
+        raise ImportError(
+            "rasterio is required for raster operations. "
+            "Please install it with: pip install rasterio"
+        )
+
+    if not os.path.exists(input_path):
+        raise FileNotFoundError(f"Input file not found: {input_path}")
+
+    # Read input raster
+    with rasterio.open(input_path) as src:
+        # Read the specified band
+        mask = src.read(band)
+
+        # Get metadata
+        profile = src.profile.copy()
+
+        # Handle nodata
+        if nodata is None:
+            nodata = src.nodata
+
+    # Convert nodata to NaN if specified
+    if nodata is not None:
+        mask = mask.astype(np.float32)
+        mask[mask == nodata] = np.nan
+
+    # Clean the mask
+    cleaned = clean_segmentation_mask(
+        mask,
+        class_values=class_values,
+        smooth_edge_size=smooth_edge_size,
+        min_island_size=min_island_size,
+        connectivity=connectivity,
+        max_workers=max_workers,
+        fill_nan=fill_nan,
+    )
+
+    # Convert NaN back to nodata if needed
+    if nodata is not None:
+        # Convert any remaining NaN values back to nodata value
+        if np.isnan(cleaned).any():
+            cleaned = np.nan_to_num(cleaned, nan=nodata)
+
+    # Update profile for output
+    profile.update(
+        dtype=cleaned.dtype,
+        count=1,
+        compress="lzw",
+        nodata=nodata,
+    )
+
+    # Write cleaned raster
+    output_dir = os.path.dirname(os.path.abspath(output_path))
+    if output_dir and output_dir != os.path.abspath(os.sep):
+        os.makedirs(output_dir, exist_ok=True)
+    with rasterio.open(output_path, "w", **profile) as dst:
+        dst.write(cleaned, 1)
+
+
+def clean_raster_batch(
+    input_paths: List[str],
+    output_dir: str,
+    class_values: Optional[Union[int, List[int]]] = None,
+    smooth_edge_size: int = 2,
+    min_island_size: int = 100,
+    connectivity: int = 8,
+    max_workers: Optional[int] = None,
+    fill_nan: bool = False,
+    band: int = 1,
+    suffix: str = "_cleaned",
+    verbose: bool = True,
+) -> List[str]:
+    """
+    Clean multiple classification rasters in batch.
+
+    Processes multiple GeoTIFF files with the same cleaning parameters
+    and saves results to an output directory.
+
+    Args:
+        input_paths (list of str): List of paths to input GeoTIFF files.
+        output_dir (str): Directory to save cleaned files.
+        class_values (int, list of int, or None): Target class values.
+            Defaults to None (auto-detect).
+        smooth_edge_size (int): Kernel width for edge smoothing. Defaults to 2.
+        min_island_size (int): Minimum component area. Defaults to 100.
+        connectivity (int): Connectivity (4 or 8). Defaults to 8.
+        max_workers (int, optional): Thread pool size. Defaults to None.
+        fill_nan (bool): Whether to fill NaN pixels. Defaults to False.
+        band (int): Band index to read (1-indexed). Defaults to 1.
+        suffix (str): Suffix to add to output filenames. Defaults to "_cleaned".
+        verbose (bool): Whether to print progress. Defaults to True.
+
+    Returns:
+        list of str: Paths to cleaned output files.
+
+    Raises:
+        ImportError: If multiclean or rasterio is not installed.
+
+    Example:
+        >>> from geoai.tools.multiclean import clean_raster_batch
+        >>> input_files = ["mask1.tif", "mask2.tif", "mask3.tif"]
+        >>> outputs = clean_raster_batch(
+        ...     input_files,
+        ...     output_dir="cleaned_masks",
+        ...     min_island_size=50
+        ... )
+    """
+    check_multiclean_available()
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    output_paths = []
+
+    for i, input_path in enumerate(input_paths):
+        if verbose:
+            print(f"Processing {i+1}/{len(input_paths)}: {input_path}")
+
+        # Generate output filename
+        basename = os.path.basename(input_path)
+        name, ext = os.path.splitext(basename)
+        output_filename = f"{name}{suffix}{ext}"
+        output_path = os.path.join(output_dir, output_filename)
+
+        try:
+            # Clean the raster
+            clean_raster(
+                input_path,
+                output_path,
+                class_values=class_values,
+                smooth_edge_size=smooth_edge_size,
+                min_island_size=min_island_size,
+                connectivity=connectivity,
+                max_workers=max_workers,
+                fill_nan=fill_nan,
+                band=band,
+            )
+
+            output_paths.append(output_path)
+
+            if verbose:
+                print(f"  ✓ Saved to: {output_path}")
+
+        except Exception as e:
+            if verbose:
+                print(f"  ✗ Failed: {e}")
+            continue
+
+    return output_paths
+
+
+def compare_masks(
+    original: np.ndarray,
+    cleaned: np.ndarray,
+) -> Tuple[int, int, float]:
+    """
+    Compare original and cleaned masks to quantify changes.
+
+    Args:
+        original (np.ndarray): Original segmentation mask.
+        cleaned (np.ndarray): Cleaned segmentation mask.
+
+    Returns:
+        tuple: (pixels_changed, total_pixels, change_percentage)
+            - pixels_changed: Number of pixels that changed value
+            - total_pixels: Total number of valid pixels
+            - change_percentage: Percentage of pixels changed
+
+    Example:
+        >>> import numpy as np
+        >>> from geoai.tools.multiclean import compare_masks
+        >>> original = np.random.randint(0, 3, (512, 512))
+        >>> cleaned = original.copy()
+        >>> changed, total, pct = compare_masks(original, cleaned)
+        >>> print(f"Changed: {pct:.2f}%")
+    """
+    # Handle NaN values. np.isnan is only defined for float arrays, so
+    # integer masks (which cannot hold NaN) are treated as fully valid.
+    def _nan_mask(arr: np.ndarray) -> np.ndarray:
+        if np.issubdtype(arr.dtype, np.floating):
+            return np.isnan(arr)
+        return np.zeros(arr.shape, dtype=bool)
+
+    valid_mask = ~(_nan_mask(original) | _nan_mask(cleaned))
+
+    # Count changed pixels
+    pixels_changed = int(np.sum((original != cleaned) & valid_mask))
+    total_pixels = int(np.sum(valid_mask))
+
+    # Calculate percentage
+    change_percentage = (
+        (pixels_changed / total_pixels * 100) if total_pixels > 0 else 0.0
+    )
+
+    return pixels_changed, total_pixels, change_percentage
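
Taken together, the new module supports a clean-then-verify workflow: clean a classification raster on disk, then reload both versions and quantify how much the cleaning changed. Below is a minimal sketch of that workflow, assuming multiclean and rasterio are installed; the .tif paths are placeholders:

    # Hedged usage sketch for the new module (paths are placeholders).
    import rasterio

    from geoai.tools.multiclean import clean_raster, compare_masks

    # Clean the raster on disk, dropping components smaller than 50 px.
    clean_raster(
        "segmentation_raw.tif",
        "segmentation_cleaned.tif",
        smooth_edge_size=2,
        min_island_size=50,
    )

    # Reload both versions and report how many pixels the cleaning touched.
    with rasterio.open("segmentation_raw.tif") as src:
        original = src.read(1)
    with rasterio.open("segmentation_cleaned.tif") as src:
        cleaned = src.read(1)

    changed, total, pct = compare_masks(original, cleaned)
    print(f"{changed}/{total} pixels changed ({pct:.2f}%)")
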
geoai/train.py CHANGED
@@ -1436,8 +1436,12 @@ def instance_segmentation_inference_on_geotiff(
     # Apply Non-Maximum Suppression to handle overlapping detections
     if len(all_detections) > 0:
         # Convert to tensors for NMS
-        boxes = torch.tensor([det["box"] for det in all_detections])
-        scores = torch.tensor([det["score"] for det in all_detections])
+        boxes = torch.tensor(
+            [det["box"] for det in all_detections], dtype=torch.float32
+        )
+        scores = torch.tensor(
+            [det["score"] for det in all_detections], dtype=torch.float32
+        )
 
         # Apply NMS with IoU threshold
         nms_threshold = 0.3  # IoU threshold for NMS
@@ -1917,6 +1921,96 @@ class SemanticRandomHorizontalFlip:
         return image, mask
 
 
+class SemanticRandomVerticalFlip:
+    """Random vertical flip transform for semantic segmentation."""
+
+    def __init__(self, prob: float = 0.5) -> None:
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Flip image and mask along height dimension
+            image = torch.flip(image, dims=[1])
+            mask = torch.flip(mask, dims=[0])
+        return image, mask
+
+
+class SemanticRandomRotation90:
+    """Random 90-degree rotation transform for semantic segmentation."""
+
+    def __init__(self, prob: float = 0.5) -> None:
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Randomly rotate by 90, 180, or 270 degrees
+            k = random.randint(1, 3)
+            image = torch.rot90(image, k, dims=[1, 2])
+            mask = torch.rot90(mask, k, dims=[0, 1])
+        return image, mask
+
+
+class SemanticBrightnessAdjustment:
+    """Random brightness adjustment transform for semantic segmentation."""
+
+    def __init__(
+        self, brightness_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
+    ) -> None:
+        """
+        Initialize brightness adjustment transform.
+
+        Args:
+            brightness_range: Tuple of (min, max) brightness factors.
+            prob: Probability of applying the transform.
+        """
+        self.brightness_range = brightness_range
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Apply random brightness adjustment
+            factor = self.brightness_range[0] + random.random() * (
+                self.brightness_range[1] - self.brightness_range[0]
+            )
+            image = torch.clamp(image * factor, 0, 1)
+        return image, mask
+
+
+class SemanticContrastAdjustment:
+    """Random contrast adjustment transform for semantic segmentation."""
+
+    def __init__(
+        self, contrast_range: Tuple[float, float] = (0.8, 1.2), prob: float = 0.5
+    ) -> None:
+        """
+        Initialize contrast adjustment transform.
+
+        Args:
+            contrast_range: Tuple of (min, max) contrast factors.
+            prob: Probability of applying the transform.
+        """
+        self.contrast_range = contrast_range
+        self.prob = prob
+
+    def __call__(
+        self, image: torch.Tensor, mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if random.random() < self.prob:
+            # Apply random contrast adjustment
+            factor = self.contrast_range[0] + random.random() * (
+                self.contrast_range[1] - self.contrast_range[0]
+            )
+            mean = image.mean(dim=(1, 2), keepdim=True)
+            image = torch.clamp((image - mean) * factor + mean, 0, 1)
+        return image, mask
+
+
 def get_semantic_transform(train: bool) -> Any:
     """
     Get transforms for semantic segmentation data augmentation.
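
All four new transform classes share the (image, mask) call signature used by the existing SemanticRandomHorizontalFlip, so they compose naturally into a single callable. Below is a minimal sketch under that assumption; the ComposeSemantic helper is illustrative, not part of geoai, and the import path assumes the classes are importable from geoai.train:

    # Illustrative sketch, not part of geoai: chain the new (image, mask)
    # transforms into one callable.
    from geoai.train import (
        SemanticBrightnessAdjustment,
        SemanticContrastAdjustment,
        SemanticRandomRotation90,
        SemanticRandomVerticalFlip,
    )

    class ComposeSemantic:
        """Apply a list of (image, mask) transforms in order."""

        def __init__(self, transforms):
            self.transforms = transforms

        def __call__(self, image, mask):
            for transform in self.transforms:
                image, mask = transform(image, mask)
            return image, mask

    augment = ComposeSemantic([
        SemanticRandomVerticalFlip(prob=0.5),
        SemanticRandomRotation90(prob=0.5),
        SemanticBrightnessAdjustment(brightness_range=(0.8, 1.2), prob=0.5),
        SemanticContrastAdjustment(contrast_range=(0.8, 1.2), prob=0.5),
    ])
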
@@ -2388,6 +2482,8 @@ def train_segmentation_model(
     resize_mode: str = "resize",
     num_workers: Optional[int] = None,
     early_stopping_patience: Optional[int] = None,
+    train_transforms: Optional[Callable] = None,
+    val_transforms: Optional[Callable] = None,
     **kwargs: Any,
 ) -> torch.nn.Module:
     """
@@ -2440,8 +2536,17 @@ def train_segmentation_model(
             'resize' - Resize images to target_size (may change aspect ratio)
             'pad' - Pad images to target_size (preserves aspect ratio). Defaults to 'resize'.
         num_workers (int): Number of workers for data loading. If None, uses 0 on macOS and Windows, 8 otherwise.
-        early_stopping_patience (int, optional): Number of epochs with no improvement after which
-            training will be stopped. If None, early stopping is disabled. Defaults to None.
+        early_stopping_patience (int, optional): Number of epochs with no improvement after which
+            training will be stopped. If None, early stopping is disabled. Defaults to None.
+        train_transforms (callable, optional): Custom transforms for training data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask).
+            Both image and mask should be torch.Tensor objects; the image tensor is expected in CHW
+            format (channels, height, width) and the mask tensor in HW format (height, width).
+            If None, uses default transforms (horizontal flip with 0.5 probability). Defaults to None.
+        val_transforms (callable, optional): Custom transforms for validation data.
+            Should be a callable that accepts (image, mask) tensors and returns transformed (image, mask),
+            with the same tensor formats as train_transforms. If None, uses default transforms
+            (no augmentation). Defaults to None.
         **kwargs: Additional arguments passed to smp.create_model().
 
     Returns:
         None: Model weights are saved to output_dir.
@@ -2584,10 +2689,22 @@ def train_segmentation_model(
         print("No resizing needed.")
 
     # Create datasets
+    # Use custom transforms if provided, otherwise use default transforms
+    train_transform = (
+        train_transforms
+        if train_transforms is not None
+        else get_semantic_transform(train=True)
+    )
+    val_transform = (
+        val_transforms
+        if val_transforms is not None
+        else get_semantic_transform(train=False)
+    )
+
     train_dataset = SemanticSegmentationDataset(
         train_imgs,
         train_labels,
-        transforms=get_semantic_transform(train=True),
+        transforms=train_transform,
         num_channels=num_channels,
         target_size=target_size,
         resize_mode=resize_mode,
@@ -2596,7 +2713,7 @@ def train_segmentation_model(
     val_dataset = SemanticSegmentationDataset(
         val_imgs,
         val_labels,
-        transforms=get_semantic_transform(train=False),
+        transforms=val_transform,
         num_channels=num_channels,
         target_size=target_size,
         resize_mode=resize_mode,
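
Downstream, train_transforms and val_transforms are the only new hooks needed to wire such a callable into training. A hedged sketch of the call follows; only the two transform parameters are confirmed by this diff, and the remaining arguments are elided into a hypothetical placeholder rather than guessed:

    from geoai.train import train_segmentation_model

    # Only train_transforms / val_transforms are confirmed by this diff; the
    # usual dataset/model arguments are elided into a placeholder dict.
    other_args = {}  # fill in your usual train_segmentation_model arguments

    model = train_segmentation_model(
        train_transforms=augment,  # e.g. the ComposeSemantic sketch above
        val_transforms=None,       # None keeps the default (no augmentation)
        **other_args,
    )
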