geoai-py 0.3.5__py2.py3-none-any.whl → 0.3.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.3.5"
5
+ __version__ = "0.3.6"
6
6
 
7
7
 
8
8
  import os
geoai/extract.py CHANGED
@@ -1,21 +1,29 @@
1
+ """This module provides a dataset class for object extraction from raster data"""
2
+
3
+ # Standard Library
1
4
  import os
5
+
6
+ # Third-Party Libraries
7
+ import cv2
8
+ import geopandas as gpd
9
+ import matplotlib.pyplot as plt
2
10
  import numpy as np
11
+ import rasterio
12
+ import scipy.ndimage as ndimage
3
13
  import torch
4
- import matplotlib.pyplot as plt
14
+ from huggingface_hub import hf_hub_download
15
+ from rasterio.windows import Window
5
16
  from shapely.geometry import Polygon, box
6
- import geopandas as gpd
7
17
  from tqdm import tqdm
18
+ from torchvision.models.detection import (
19
+ maskrcnn_resnet50_fpn,
20
+ fasterrcnn_resnet50_fpn_v2,
21
+ )
8
22
 
9
- import cv2
10
- from torchvision.models.detection import maskrcnn_resnet50_fpn
11
- import torchvision.transforms as T
12
- import rasterio
13
- from rasterio.windows import Window
14
- from rasterio.features import shapes
15
- from huggingface_hub import hf_hub_download
16
- import scipy.ndimage as ndimage
23
+ # Local Imports
17
24
  from .utils import get_raster_stats
18
25
 
26
+
19
27
  try:
20
28
  from torchgeo.datasets import NonGeoDataset
21
29
  except ImportError as e:
@@ -60,6 +68,7 @@ class CustomDataset(NonGeoDataset):
60
68
  chip_size=(512, 512),
61
69
  overlap=0.5,
62
70
  transforms=None,
71
+ band_indexes=None,
63
72
  verbose=False,
64
73
  ):
65
74
  """
@@ -70,6 +79,7 @@ class CustomDataset(NonGeoDataset):
70
79
  chip_size: Size of image chips to extract (height, width). Default is (512, 512).
71
80
  overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
72
81
  transforms: Transforms to apply to the image. Default is None.
82
+ band_indexes: List of band indexes to use. Default is None (use all bands).
73
83
  verbose: Whether to print detailed processing information. Default is False.
74
84
 
75
85
  Raises:
@@ -82,6 +92,7 @@ class CustomDataset(NonGeoDataset):
82
92
  self.chip_size = chip_size
83
93
  self.overlap = overlap
84
94
  self.transforms = transforms
95
+ self.band_indexes = band_indexes
85
96
  self.verbose = verbose
86
97
  self.warned_about_bands = False
87
98
 
@@ -191,7 +202,10 @@ class CustomDataset(NonGeoDataset):
191
202
  if not self.warned_about_bands and self.verbose:
192
203
  print(f"Image has {image.shape[0]} bands, using first 3 bands only")
193
204
  self.warned_about_bands = True
194
- image = image[:3]
205
+ if self.band_indexes is not None:
206
+ image = image[self.band_indexes]
207
+ else:
208
+ image = image[:3]
195
209
  elif image.shape[0] < 3:
196
210
  # If image has fewer than 3 bands, duplicate the last band to make 3
197
211
  if not self.warned_about_bands and self.verbose:
@@ -594,7 +608,7 @@ class ObjectDetector:
594
608
 
595
609
  Args:
596
610
  mask_path: Path to the object masks GeoTIFF
597
- output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
611
+ output_path: Path to save the output GeoJSON or Parquet file (default: mask_path with .geojson extension)
598
612
  simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
599
613
  mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
600
614
  min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
@@ -779,7 +793,10 @@ class ObjectDetector:
779
793
 
780
794
  # Save to file
781
795
  if output_path:
782
- gdf.to_file(output_path)
796
+ if output_path.endswith(".parquet"):
797
+ gdf.to_parquet(output_path)
798
+ else:
799
+ gdf.to_file(output_path)
783
800
  print(f"Saved {len(gdf)} objects to {output_path}")
784
801
 
785
802
  return gdf
@@ -792,6 +809,7 @@ class ObjectDetector:
792
809
  batch_size=4,
793
810
  filter_edges=True,
794
811
  edge_buffer=20,
812
+ band_indexes=None,
795
813
  **kwargs,
796
814
  ):
797
815
  """
@@ -799,10 +817,11 @@ class ObjectDetector:
799
817
 
800
818
  Args:
801
819
  raster_path: Path to input raster file
802
- output_path: Path to output GeoJSON file (optional)
820
+ output_path: Path to output GeoJSON or Parquet file (optional)
803
821
  batch_size: Batch size for processing
804
822
  filter_edges: Whether to filter out objects at the edges of the image
805
823
  edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
824
+ band_indexes: List of band indexes to use (if None, use all bands)
806
825
  **kwargs: Additional parameters:
807
826
  confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
808
827
  overlap: Overlap between adjacent tiles (0.0-1.0)
@@ -843,7 +862,10 @@ class ObjectDetector:
843
862
 
844
863
  # Create dataset
845
864
  dataset = CustomDataset(
846
- raster_path=raster_path, chip_size=chip_size, overlap=overlap
865
+ raster_path=raster_path,
866
+ chip_size=chip_size,
867
+ overlap=overlap,
868
+ band_indexes=band_indexes,
847
869
  )
848
870
  self.raster_stats = dataset.raster_stats
849
871
 
@@ -1021,7 +1043,10 @@ class ObjectDetector:
1021
1043
 
1022
1044
  # Save to file if requested
1023
1045
  if output_path:
1024
- gdf.to_file(output_path, driver="GeoJSON")
1046
+ if output_path.endswith(".parquet"):
1047
+ gdf.to_parquet(output_path)
1048
+ else:
1049
+ gdf.to_file(output_path, driver="GeoJSON")
1025
1050
  print(f"Saved {len(gdf)} objects to {output_path}")
1026
1051
 
1027
1052
  return gdf
@@ -1887,21 +1912,30 @@ class ObjectDetector:
1887
1912
  output_path=None,
1888
1913
  confidence_threshold=None,
1889
1914
  mask_threshold=None,
1915
+ min_object_area=10,
1916
+ max_object_area=float("inf"),
1890
1917
  overlap=0.25,
1891
1918
  batch_size=4,
1919
+ band_indexes=None,
1892
1920
  verbose=False,
1893
1921
  **kwargs,
1894
1922
  ):
1895
1923
  """
1896
1924
  Save masks with confidence values as a multi-band GeoTIFF.
1897
1925
 
1926
+ Objects with area smaller than min_object_area or larger than max_object_area
1927
+ will be filtered out.
1928
+
1898
1929
  Args:
1899
1930
  raster_path: Path to input raster
1900
1931
  output_path: Path for output GeoTIFF
1901
1932
  confidence_threshold: Minimum confidence score (0.0-1.0)
1902
1933
  mask_threshold: Threshold for mask binarization (0.0-1.0)
1934
+ min_object_area: Minimum area (in pixels) for an object to be included
1935
+ max_object_area: Maximum area (in pixels) for an object to be included
1903
1936
  overlap: Overlap between tiles (0.0-1.0)
1904
1937
  batch_size: Batch size for processing
1938
+ band_indexes: List of band indexes to use (default: all bands)
1905
1939
  verbose: Whether to print detailed processing information
1906
1940
 
1907
1941
  Returns:
@@ -1926,6 +1960,7 @@ class ObjectDetector:
1926
1960
  raster_path=raster_path,
1927
1961
  chip_size=chip_size,
1928
1962
  overlap=overlap,
1963
+ band_indexes=band_indexes,
1929
1964
  verbose=verbose,
1930
1965
  )
1931
1966
 
@@ -2012,6 +2047,21 @@ class ObjectDetector:
2012
2047
  for mask_idx, mask in enumerate(masks):
2013
2048
  # Convert to binary mask
2014
2049
  binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
2050
+
2051
+ # Check object area - calculate number of pixels in the mask
2052
+ object_area = np.sum(binary_mask > 0)
2053
+
2054
+ # Skip objects that don't meet area criteria
2055
+ if (
2056
+ object_area < min_object_area
2057
+ or object_area > max_object_area
2058
+ ):
2059
+ if verbose:
2060
+ print(
2061
+ f"Filtering out object with area {object_area} pixels"
2062
+ )
2063
+ continue
2064
+
2015
2065
  conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
2016
2066
 
2017
2067
  # Update the mask and confidence arrays
@@ -2164,7 +2214,7 @@ class ObjectDetector:
2164
2214
  # Save to file if requested
2165
2215
  if output_path:
2166
2216
  gdf.to_file(output_path, driver="GeoJSON")
2167
- print(f"Saved {len(gdf)} cars with confidence to {output_path}")
2217
+ print(f"Saved {len(gdf)} objects with confidence to {output_path}")
2168
2218
 
2169
2219
  return gdf
2170
2220
  else:
@@ -2277,3 +2327,32 @@ class ShipDetector(ObjectDetector):
2277
2327
  super().__init__(
2278
2328
  model_path=model_path, repo_id=repo_id, model=model, device=device
2279
2329
  )
2330
+
2331
+
2332
class SolarPanelDetector(ObjectDetector):
    """
    Solar panel detection using a pre-trained Mask R-CNN model.

    This class extends the `ObjectDetector` class with additional methods
    for solar panel detection.
    """

    def __init__(
        self,
        model_path="solar_panel_detection.pth",
        repo_id=None,
        model=None,
        device=None,
    ):
        """
        Initialize the solar panel detector.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        # All heavy lifting (model download/loading, device selection) is
        # delegated to the ObjectDetector base class.
        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )
geoai/geoai.py CHANGED
@@ -2,3 +2,4 @@
2
2
 
3
3
  from .utils import *
4
4
  from .extract import *
5
+ from .segment import *
geoai/segment.py ADDED
@@ -0,0 +1,305 @@
1
+ """This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
2
+
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ from tqdm import tqdm
7
+ from PIL import Image
8
+ import rasterio
9
+ from rasterio.windows import Window
10
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
11
+
12
+
13
class CLIPSegmentation:
    """
    A class for segmenting high-resolution satellite imagery using text prompts with CLIP-based models.

    This segmenter utilizes the CLIP-Seg model to perform semantic segmentation based on text prompts.
    It can process large GeoTIFF files by tiling them and handles proper georeferencing in the output.

    Args:
        model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
        device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
        tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
        overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.

    Attributes:
        processor (CLIPSegProcessor): The processor for the CLIP-Seg model.
        model (CLIPSegForImageSegmentation): The CLIP-Seg model for segmentation.
        device (str): The device being used ('cuda' or 'cpu').
        tile_size (int): Size of tiles for processing.
        overlap (int): Overlap between tiles.
    """

    def __init__(
        self,
        model_name="CIDAS/clipseg-rd64-refined",
        device=None,
        tile_size=512,
        overlap=32,
    ):
        """
        Initialize the CLIPSegmentation instance with the specified model and settings.

        Args:
            model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
            device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
            tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
            overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.
        """
        self.tile_size = tile_size
        self.overlap = overlap

        # Prefer CUDA automatically when no device was requested.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Load model and processor from the Hugging Face Hub (or local cache).
        self.processor = CLIPSegProcessor.from_pretrained(model_name)
        self.model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(
            self.device
        )

        print(f"Model loaded on {self.device}")

    @staticmethod
    def _tile_to_rgb(tile_data):
        """
        Convert a (bands, height, width) tile into a (height, width, 3) RGB array.

        Bands beyond the first three are dropped; a single band is replicated
        to three channels. Values are min/max stretched to 0-255 uint8.
        A zero-range (constant) tile is returned unscaled — this guards the
        divide-by-zero that would otherwise produce NaN garbage.

        Args:
            tile_data: Raster tile as a (bands, height, width) numpy array.

        Returns:
            (height, width, 3) array, uint8 when stretching was applied.
        """
        if tile_data.shape[0] > 3:
            # Use the first three bands as an RGB representation.
            rgb_tile = tile_data[:3].transpose(1, 2, 0)
        elif tile_data.shape[0] == 1:
            # Create RGB from grayscale.
            rgb_tile = np.repeat(tile_data[0][:, :, np.newaxis], 3, axis=2)
        else:
            # Already 3-channel, assume RGB.
            rgb_tile = tile_data.transpose(1, 2, 0)

        value_range = rgb_tile.max() - rgb_tile.min()
        if rgb_tile.max() > 0 and value_range > 0:
            rgb_tile = (
                (rgb_tile - rgb_tile.min()) / value_range * 255
            ).astype(np.uint8)
        return rgb_tile

    def segment_image(
        self, input_path, output_path, text_prompt, threshold=0.5, smoothing_sigma=1.0
    ):
        """
        Segment a GeoTIFF image using the provided text prompt.

        The function processes the image in tiles and saves the result as a GeoTIFF with two bands:
        - Band 1: Binary segmentation mask (0 or 1)
        - Band 2: Probability scores (0.0 to 1.0)

        Args:
            input_path (str): Path to the input GeoTIFF file.
            output_path (str): Path where the output GeoTIFF will be saved.
            text_prompt (str): Text description of what to segment (e.g., "water", "buildings").
            threshold (float): Threshold for binary segmentation (0.0 to 1.0). Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.

        Returns:
            str: Path to the saved output file.
        """
        with rasterio.open(input_path) as src:
            meta = src.meta
            height = src.height
            width = src.width

            # Output keeps the input georeferencing; two float32 bands.
            out_meta = meta.copy()
            out_meta.update({"count": 2, "dtype": "float32", "nodata": None})

            probabilities = np.zeros((height, width), dtype=np.float32)

            # Tiles are laid out on an "effective" grid; each tile is then
            # expanded by `overlap` pixels on every side so seams can be
            # discarded when stitching.
            effective_tile_size = self.tile_size - 2 * self.overlap
            n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
            n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
            total_tiles = n_tiles_x * n_tiles_y

            with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
                for y in range(n_tiles_y):
                    for x in range(n_tiles_x):
                        # Tile coordinates with overlap, clamped to the raster.
                        x_start = max(0, x * effective_tile_size - self.overlap)
                        y_start = max(0, y * effective_tile_size - self.overlap)
                        x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
                        y_end = min(
                            height, (y + 1) * effective_tile_size + self.overlap
                        )

                        tile_width = x_end - x_start
                        tile_height = y_end - y_start

                        window = Window(x_start, y_start, tile_width, tile_height)
                        tile_data = src.read(window=window)

                        try:
                            rgb_tile = self._tile_to_rgb(tile_data)
                            pil_image = Image.fromarray(rgb_tile)

                            # Shrink (keeping aspect ratio) if the tile
                            # exceeds the model input size.
                            if (
                                pil_image.width > self.tile_size
                                or pil_image.height > self.tile_size
                            ):
                                pil_image.thumbnail(
                                    (self.tile_size, self.tile_size), Image.LANCZOS
                                )

                            inputs = self.processor(
                                text=text_prompt, images=pil_image, return_tensors="pt"
                            ).to(self.device)

                            with torch.no_grad():
                                outputs = self.model(**inputs)

                            logits = outputs.logits[0]

                            # Convert logits to probabilities with sigmoid.
                            probs = torch.sigmoid(logits).cpu().numpy()

                            # Resize model output back to the tile's pixel grid.
                            if probs.shape != (tile_height, tile_width):
                                # Bicubic interpolation for smoother results.
                                probs_resized = np.array(
                                    Image.fromarray(probs).resize(
                                        (tile_width, tile_height), Image.BICUBIC
                                    )
                                )
                            else:
                                probs_resized = probs

                            # Optional Gaussian smoothing to reduce blockiness.
                            try:
                                from scipy.ndimage import gaussian_filter

                                probs_resized = gaussian_filter(
                                    probs_resized, sigma=smoothing_sigma
                                )
                            except ImportError:
                                pass  # scipy unavailable; skip smoothing

                            # Keep only the non-overlapping core of each tile
                            # (tiles on the raster edge keep their full extent).
                            valid_x_start = self.overlap if x > 0 else 0
                            valid_y_start = self.overlap if y > 0 else 0
                            valid_x_end = (
                                tile_width - self.overlap
                                if x < n_tiles_x - 1
                                else tile_width
                            )
                            valid_y_end = (
                                tile_height - self.overlap
                                if y < n_tiles_y - 1
                                else tile_height
                            )

                            dest_x_start = x_start + valid_x_start
                            dest_y_start = y_start + valid_y_start
                            dest_x_end = x_start + valid_x_end
                            dest_y_end = y_start + valid_y_end

                            probabilities[
                                dest_y_start:dest_y_end, dest_x_start:dest_x_end
                            ] = probs_resized[
                                valid_y_start:valid_y_end, valid_x_start:valid_x_end
                            ]

                        except Exception as e:
                            # Best-effort: report and continue with next tile.
                            print(f"Error processing tile at ({x}, {y}): {str(e)}")

                        pbar.update(1)

        # Binarize the stitched probability map.
        segmentation = (probabilities >= threshold).astype(np.float32)

        # Write the output GeoTIFF with labeled bands.
        with rasterio.open(output_path, "w", **out_meta) as dst:
            dst.write(segmentation, 1)
            dst.write(probabilities, 2)
            dst.set_band_description(1, "Binary Segmentation")
            dst.set_band_description(2, "Probability Scores")

        print(f"Segmentation saved to {output_path}")
        return output_path

    def segment_image_batch(
        self,
        input_paths,
        output_dir,
        text_prompt,
        threshold=0.5,
        smoothing_sigma=1.0,
        suffix="_segmented",
    ):
        """
        Segment multiple GeoTIFF images using the provided text prompt.

        Args:
            input_paths (list): List of paths to input GeoTIFF files.
            output_dir (str): Directory where output GeoTIFFs will be saved.
            text_prompt (str): Text description of what to segment.
            threshold (float): Threshold for binary segmentation. Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.
            suffix (str): Suffix to add to output filenames. Defaults to "_segmented".

        Returns:
            list: Paths to all saved output files.
        """
        os.makedirs(output_dir, exist_ok=True)

        output_paths = []
        for input_path in tqdm(input_paths, desc="Processing files"):
            # Derive the output name from the input file name.
            filename = os.path.basename(input_path)
            base_name, ext = os.path.splitext(filename)
            output_path = os.path.join(output_dir, f"{base_name}{suffix}{ext}")

            result_path = self.segment_image(
                input_path, output_path, text_prompt, threshold, smoothing_sigma
            )
            output_paths.append(result_path)

        return output_paths