geoai-py 0.3.5__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.3.5"
5
+ __version__ = "0.4.0"
6
6
 
7
7
 
8
8
  import os
geoai/download.py CHANGED
@@ -1,18 +1,19 @@
1
1
  """This module provides functions to download data, including NAIP imagery and building data from Overture Maps."""
2
2
 
3
+ import logging
3
4
  import os
4
- from typing import List, Tuple, Optional, Dict, Any
5
- import rioxarray
6
- import numpy as np
5
+ import subprocess
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ import geopandas as gpd
7
9
  import matplotlib.pyplot as plt
8
- from pystac_client import Client
10
+ import numpy as np
9
11
  import planetary_computer as pc
10
- import geopandas as gpd
12
+ import requests
13
+ import rioxarray
14
+ from pystac_client import Client
11
15
  from shapely.geometry import box
12
16
  from tqdm import tqdm
13
- import requests
14
- import subprocess
15
- import logging
16
17
 
17
18
  # Configure logging
18
19
  logging.basicConfig(
geoai/extract.py CHANGED
@@ -1,19 +1,26 @@
1
+ """This module provides a dataset class and object detectors for extracting objects from raster data."""
2
+
3
+ # Standard Library
1
4
  import os
5
+
6
+ # Third-Party Libraries
7
+ import cv2
8
+ import geopandas as gpd
9
+ import matplotlib.pyplot as plt
2
10
  import numpy as np
11
+ import rasterio
12
+ import scipy.ndimage as ndimage
3
13
  import torch
4
- import matplotlib.pyplot as plt
14
+ from huggingface_hub import hf_hub_download
15
+ from rasterio.windows import Window
5
16
  from shapely.geometry import Polygon, box
6
- import geopandas as gpd
17
+ from torchvision.models.detection import (
18
+ fasterrcnn_resnet50_fpn_v2,
19
+ maskrcnn_resnet50_fpn,
20
+ )
7
21
  from tqdm import tqdm
8
22
 
9
- import cv2
10
- from torchvision.models.detection import maskrcnn_resnet50_fpn
11
- import torchvision.transforms as T
12
- import rasterio
13
- from rasterio.windows import Window
14
- from rasterio.features import shapes
15
- from huggingface_hub import hf_hub_download
16
- import scipy.ndimage as ndimage
23
+ # Local Imports
17
24
  from .utils import get_raster_stats
18
25
 
19
26
  try:
@@ -60,6 +67,7 @@ class CustomDataset(NonGeoDataset):
60
67
  chip_size=(512, 512),
61
68
  overlap=0.5,
62
69
  transforms=None,
70
+ band_indexes=None,
63
71
  verbose=False,
64
72
  ):
65
73
  """
@@ -70,6 +78,7 @@ class CustomDataset(NonGeoDataset):
70
78
  chip_size: Size of image chips to extract (height, width). Default is (512, 512).
71
79
  overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
72
80
  transforms: Transforms to apply to the image. Default is None.
81
+ band_indexes: List of band indexes to select when the image has more than 3 bands. Default is None (keep the first 3 bands).
73
82
  verbose: Whether to print detailed processing information. Default is False.
74
83
 
75
84
  Raises:
@@ -82,6 +91,7 @@ class CustomDataset(NonGeoDataset):
82
91
  self.chip_size = chip_size
83
92
  self.overlap = overlap
84
93
  self.transforms = transforms
94
+ self.band_indexes = band_indexes
85
95
  self.verbose = verbose
86
96
  self.warned_about_bands = False
87
97
 
@@ -191,7 +201,10 @@ class CustomDataset(NonGeoDataset):
191
201
  if not self.warned_about_bands and self.verbose:
192
202
  print(f"Image has {image.shape[0]} bands, using first 3 bands only")
193
203
  self.warned_about_bands = True
194
- image = image[:3]
204
+ if self.band_indexes is not None:
205
+ image = image[self.band_indexes]
206
+ else:
207
+ image = image[:3]
195
208
  elif image.shape[0] < 3:
196
209
  # If image has fewer than 3 bands, duplicate the last band to make 3
197
210
  if not self.warned_about_bands and self.verbose:
@@ -256,7 +269,9 @@ class ObjectDetector:
256
269
  Object extraction using Mask R-CNN with TorchGeo.
257
270
  """
258
271
 
259
- def __init__(self, model_path=None, repo_id=None, model=None, device=None):
272
+ def __init__(
273
+ self, model_path=None, repo_id=None, model=None, num_classes=2, device=None
274
+ ):
260
275
  """
261
276
  Initialize the object extractor.
262
277
 
@@ -264,6 +279,7 @@ class ObjectDetector:
264
279
  model_path: Path to the .pth model file.
265
280
  repo_id: Hugging Face repository ID for model download.
266
281
  model: Pre-initialized model object (optional).
282
+ num_classes: Number of classes for detection (default: 2).
267
283
  device: Device to use for inference ('cuda:0', 'cpu', etc.).
268
284
  """
269
285
  # Set device
@@ -283,7 +299,7 @@ class ObjectDetector:
283
299
  self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
284
300
 
285
301
  # Initialize model
286
- self.model = self.initialize_model(model)
302
+ self.model = self.initialize_model(model, num_classes=num_classes)
287
303
 
288
304
  # Download model if needed
289
305
  if model_path is None or (not os.path.exists(model_path)):
@@ -328,11 +344,12 @@ class ObjectDetector:
328
344
  print("Please specify a local model path or ensure internet connectivity.")
329
345
  raise
330
346
 
331
- def initialize_model(self, model):
347
+ def initialize_model(self, model, num_classes=2):
332
348
  """Initialize a deep learning model for object detection.
333
349
 
334
350
  Args:
335
351
  model (torch.nn.Module): A pre-initialized model object.
352
+ num_classes (int): Number of classes for detection.
336
353
 
337
354
  Returns:
338
355
  torch.nn.Module: A deep learning model for object detection.
@@ -347,7 +364,7 @@ class ObjectDetector:
347
364
  model = maskrcnn_resnet50_fpn(
348
365
  weights=None,
349
366
  progress=False,
350
- num_classes=2, # Background + object
367
+ num_classes=num_classes, # Number of classes, including background
351
368
  weights_backbone=None,
352
369
  # These parameters ensure consistent normalization
353
370
  image_mean=image_mean,
@@ -594,7 +611,7 @@ class ObjectDetector:
594
611
 
595
612
  Args:
596
613
  mask_path: Path to the object masks GeoTIFF
597
- output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
614
+ output_path: Path to save the output GeoJSON or Parquet file (default: mask_path with .geojson extension)
598
615
  simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
599
616
  mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
600
617
  min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
@@ -779,7 +796,10 @@ class ObjectDetector:
779
796
 
780
797
  # Save to file
781
798
  if output_path:
782
- gdf.to_file(output_path)
799
+ if output_path.endswith(".parquet"):
800
+ gdf.to_parquet(output_path)
801
+ else:
802
+ gdf.to_file(output_path)
783
803
  print(f"Saved {len(gdf)} objects to {output_path}")
784
804
 
785
805
  return gdf
@@ -792,6 +812,7 @@ class ObjectDetector:
792
812
  batch_size=4,
793
813
  filter_edges=True,
794
814
  edge_buffer=20,
815
+ band_indexes=None,
795
816
  **kwargs,
796
817
  ):
797
818
  """
@@ -799,10 +820,11 @@ class ObjectDetector:
799
820
 
800
821
  Args:
801
822
  raster_path: Path to input raster file
802
- output_path: Path to output GeoJSON file (optional)
823
+ output_path: Path to output GeoJSON or Parquet file (optional)
803
824
  batch_size: Batch size for processing
804
825
  filter_edges: Whether to filter out objects at the edges of the image
805
826
  edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
827
+ band_indexes: List of band indexes to use (if None, the first 3 bands are used)
806
828
  **kwargs: Additional parameters:
807
829
  confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
808
830
  overlap: Overlap between adjacent tiles (0.0-1.0)
@@ -843,7 +865,10 @@ class ObjectDetector:
843
865
 
844
866
  # Create dataset
845
867
  dataset = CustomDataset(
846
- raster_path=raster_path, chip_size=chip_size, overlap=overlap
868
+ raster_path=raster_path,
869
+ chip_size=chip_size,
870
+ overlap=overlap,
871
+ band_indexes=band_indexes,
847
872
  )
848
873
  self.raster_stats = dataset.raster_stats
849
874
 
@@ -1021,7 +1046,10 @@ class ObjectDetector:
1021
1046
 
1022
1047
  # Save to file if requested
1023
1048
  if output_path:
1024
- gdf.to_file(output_path, driver="GeoJSON")
1049
+ if output_path.endswith(".parquet"):
1050
+ gdf.to_parquet(output_path)
1051
+ else:
1052
+ gdf.to_file(output_path, driver="GeoJSON")
1025
1053
  print(f"Saved {len(gdf)} objects to {output_path}")
1026
1054
 
1027
1055
  return gdf
@@ -1281,13 +1309,14 @@ class ObjectDetector:
1281
1309
  Returns:
1282
1310
  GeoDataFrame with regularized objects
1283
1311
  """
1312
+ import math
1313
+
1314
+ import cv2
1315
+ import geopandas as gpd
1284
1316
  import numpy as np
1285
- from shapely.geometry import Polygon, MultiPolygon, box
1286
1317
  from shapely.affinity import rotate, translate
1287
- import geopandas as gpd
1288
- import math
1318
+ from shapely.geometry import MultiPolygon, Polygon, box
1289
1319
  from tqdm import tqdm
1290
- import cv2
1291
1320
 
1292
1321
  def get_angle(p1, p2, p3):
1293
1322
  """Calculate angle between three points in degrees (0-180)"""
@@ -1887,21 +1916,30 @@ class ObjectDetector:
1887
1916
  output_path=None,
1888
1917
  confidence_threshold=None,
1889
1918
  mask_threshold=None,
1919
+ min_object_area=10,
1920
+ max_object_area=float("inf"),
1890
1921
  overlap=0.25,
1891
1922
  batch_size=4,
1923
+ band_indexes=None,
1892
1924
  verbose=False,
1893
1925
  **kwargs,
1894
1926
  ):
1895
1927
  """
1896
1928
  Save masks with confidence values as a multi-band GeoTIFF.
1897
1929
 
1930
+ Objects with area smaller than min_object_area or larger than max_object_area
1931
+ will be filtered out.
1932
+
1898
1933
  Args:
1899
1934
  raster_path: Path to input raster
1900
1935
  output_path: Path for output GeoTIFF
1901
1936
  confidence_threshold: Minimum confidence score (0.0-1.0)
1902
1937
  mask_threshold: Threshold for mask binarization (0.0-1.0)
1938
+ min_object_area: Minimum area (in pixels) for an object to be included
1939
+ max_object_area: Maximum area (in pixels) for an object to be included
1903
1940
  overlap: Overlap between tiles (0.0-1.0)
1904
1941
  batch_size: Batch size for processing
1942
+ band_indexes: List of band indexes to use (default: the first 3 bands)
1905
1943
  verbose: Whether to print detailed processing information
1906
1944
 
1907
1945
  Returns:
@@ -1926,6 +1964,7 @@ class ObjectDetector:
1926
1964
  raster_path=raster_path,
1927
1965
  chip_size=chip_size,
1928
1966
  overlap=overlap,
1967
+ band_indexes=band_indexes,
1929
1968
  verbose=verbose,
1930
1969
  )
1931
1970
 
@@ -2012,6 +2051,21 @@ class ObjectDetector:
2012
2051
  for mask_idx, mask in enumerate(masks):
2013
2052
  # Convert to binary mask
2014
2053
  binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
2054
+
2055
+ # Check object area - calculate number of pixels in the mask
2056
+ object_area = np.sum(binary_mask > 0)
2057
+
2058
+ # Skip objects that don't meet area criteria
2059
+ if (
2060
+ object_area < min_object_area
2061
+ or object_area > max_object_area
2062
+ ):
2063
+ if verbose:
2064
+ print(
2065
+ f"Filtering out object with area {object_area} pixels"
2066
+ )
2067
+ continue
2068
+
2015
2069
  conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
2016
2070
 
2017
2071
  # Update the mask and confidence arrays
@@ -2062,7 +2116,7 @@ class ObjectDetector:
2062
2116
  output_path=None,
2063
2117
  confidence_threshold=0.5,
2064
2118
  min_object_area=100,
2065
- max_object_size=None,
2119
+ max_object_area=None,
2066
2120
  **kwargs,
2067
2121
  ):
2068
2122
  """
@@ -2073,7 +2127,7 @@ class ObjectDetector:
2073
2127
  output_path: Path for output GeoJSON.
2074
2128
  confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
2075
2129
  min_object_area: Minimum area in pixels to keep an object. Default: 100
2076
- max_object_size: Maximum area in pixels to keep an object. Default: None
2130
+ max_object_area: Maximum area in pixels to keep an object. Default: None
2077
2131
  **kwargs: Additional parameters
2078
2132
 
2079
2133
  Returns:
@@ -2097,8 +2151,9 @@ class ObjectDetector:
2097
2151
  print(f"Found {num_features} connected components")
2098
2152
 
2099
2153
  # Process each component
2100
- car_polygons = []
2101
- car_confidences = []
2154
+ polygons = []
2155
+ confidences = []
2156
+ pixels = []
2102
2157
 
2103
2158
  # Add progress bar
2104
2159
  for label in tqdm(range(1, num_features + 1), desc="Processing components"):
@@ -2129,8 +2184,8 @@ class ObjectDetector:
2129
2184
  if area < min_object_area:
2130
2185
  continue
2131
2186
 
2132
- if max_object_size is not None:
2133
- if area > max_object_size:
2187
+ if max_object_area is not None:
2188
+ if area > max_object_area:
2134
2189
  continue
2135
2190
 
2136
2191
  # Get minimum area rectangle
@@ -2147,16 +2202,18 @@ class ObjectDetector:
2147
2202
  poly = Polygon(geo_points)
2148
2203
 
2149
2204
  # Add to lists
2150
- car_polygons.append(poly)
2151
- car_confidences.append(confidence)
2205
+ polygons.append(poly)
2206
+ confidences.append(confidence)
2207
+ pixels.append(area)
2152
2208
 
2153
2209
  # Create GeoDataFrame
2154
- if car_polygons:
2210
+ if polygons:
2155
2211
  gdf = gpd.GeoDataFrame(
2156
2212
  {
2157
- "geometry": car_polygons,
2158
- "confidence": car_confidences,
2159
- "class": [1] * len(car_polygons),
2213
+ "geometry": polygons,
2214
+ "confidence": confidences,
2215
+ "class": [1] * len(polygons),
2216
+ "pixels": pixels,
2160
2217
  },
2161
2218
  crs=crs,
2162
2219
  )
@@ -2164,11 +2221,11 @@ class ObjectDetector:
2164
2221
  # Save to file if requested
2165
2222
  if output_path:
2166
2223
  gdf.to_file(output_path, driver="GeoJSON")
2167
- print(f"Saved {len(gdf)} cars with confidence to {output_path}")
2224
+ print(f"Saved {len(gdf)} objects with confidence to {output_path}")
2168
2225
 
2169
2226
  return gdf
2170
2227
  else:
2171
- print("No valid car polygons found")
2228
+ print("No valid polygons found")
2172
2229
  return None
2173
2230
 
2174
2231
 
@@ -2277,3 +2334,66 @@ class ShipDetector(ObjectDetector):
2277
2334
  super().__init__(
2278
2335
  model_path=model_path, repo_id=repo_id, model=model, device=device
2279
2336
  )
2337
+
2338
+
2339
+ class SolarPanelDetector(ObjectDetector):
2340
+ """
2341
+ Solar panel detection using a pre-trained Mask R-CNN model.
2342
+
2343
+ This class extends the
2344
+ `ObjectDetector` class with additional methods for solar panel detection.
2345
+ """
2346
+
2347
+ def __init__(
2348
+ self,
2349
+ model_path="solar_panel_detection.pth",
2350
+ repo_id=None,
2351
+ model=None,
2352
+ device=None,
2353
+ ):
2354
+ """
2355
+ Initialize the object extractor.
2356
+
2357
+ Args:
2358
+ model_path: Path to the .pth model file.
2359
+ repo_id: Repo ID for loading models from the Hub.
2360
+ model: Custom model to use for inference.
2361
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
2362
+ """
2363
+ super().__init__(
2364
+ model_path=model_path, repo_id=repo_id, model=model, device=device
2365
+ )
2366
+
2367
+
2368
+ class ParkingSplotDetector(ObjectDetector):
2369
+ """
2370
+ Parking spot detection using a pre-trained Mask R-CNN model.
2371
+
2372
+ This class extends the `ObjectDetector` class with additional methods for parking spot detection.
2373
+ """
2374
+
2375
+ def __init__(
2376
+ self,
2377
+ model_path="parking_spot_detection.pth",
2378
+ repo_id=None,
2379
+ model=None,
2380
+ num_classes=3,
2381
+ device=None,
2382
+ ):
2383
+ """
2384
+ Initialize the object extractor.
2385
+
2386
+ Args:
2387
+ model_path: Path to the .pth model file.
2388
+ repo_id: Repo ID for loading models from the Hub.
2389
+ model: Custom model to use for inference.
2390
+ num_classes: Number of classes for the model. Default: 3
2391
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
2392
+ """
2393
+ super().__init__(
2394
+ model_path=model_path,
2395
+ repo_id=repo_id,
2396
+ model=model,
2397
+ num_classes=num_classes,
2398
+ device=device,
2399
+ )
geoai/geoai.py CHANGED
@@ -1,4 +1,7 @@
1
1
  """Main module."""
2
2
 
3
- from .utils import *
4
3
  from .extract import *
4
+ from .hf import *
5
+ from .segment import *
6
+ from .utils import *
7
+ from .train import train_MaskRCNN_model, object_detection