geoai-py 0.3.4__py2.py3-none-any.whl → 0.3.6__py2.py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.3.4"
5
+ __version__ = "0.3.6"
6
6
 
7
7
 
8
8
  import os
geoai/extract.py CHANGED
@@ -1,21 +1,29 @@
1
+ """This module provides a dataset class for object extraction from raster data"""
2
+
3
+ # Standard Library
1
4
  import os
5
+
6
+ # Third-Party Libraries
7
+ import cv2
8
+ import geopandas as gpd
9
+ import matplotlib.pyplot as plt
2
10
  import numpy as np
11
+ import rasterio
12
+ import scipy.ndimage as ndimage
3
13
  import torch
4
- import matplotlib.pyplot as plt
14
+ from huggingface_hub import hf_hub_download
15
+ from rasterio.windows import Window
5
16
  from shapely.geometry import Polygon, box
6
- import geopandas as gpd
7
17
  from tqdm import tqdm
18
+ from torchvision.models.detection import (
19
+ maskrcnn_resnet50_fpn,
20
+ fasterrcnn_resnet50_fpn_v2,
21
+ )
8
22
 
9
- import cv2
10
- from torchvision.models.detection import maskrcnn_resnet50_fpn
11
- import torchvision.transforms as T
12
- import rasterio
13
- from rasterio.windows import Window
14
- from rasterio.features import shapes
15
- from huggingface_hub import hf_hub_download
16
- import scipy.ndimage as ndimage
23
+ # Local Imports
17
24
  from .utils import get_raster_stats
18
25
 
26
+
19
27
  try:
20
28
  from torchgeo.datasets import NonGeoDataset
21
29
  except ImportError as e:
@@ -60,6 +68,7 @@ class CustomDataset(NonGeoDataset):
60
68
  chip_size=(512, 512),
61
69
  overlap=0.5,
62
70
  transforms=None,
71
+ band_indexes=None,
63
72
  verbose=False,
64
73
  ):
65
74
  """
@@ -70,6 +79,7 @@ class CustomDataset(NonGeoDataset):
70
79
  chip_size: Size of image chips to extract (height, width). Default is (512, 512).
71
80
  overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
72
81
  transforms: Transforms to apply to the image. Default is None.
82
+ band_indexes: List of band indexes to use. Default is None (use all bands).
73
83
  verbose: Whether to print detailed processing information. Default is False.
74
84
 
75
85
  Raises:
@@ -82,6 +92,7 @@ class CustomDataset(NonGeoDataset):
82
92
  self.chip_size = chip_size
83
93
  self.overlap = overlap
84
94
  self.transforms = transforms
95
+ self.band_indexes = band_indexes
85
96
  self.verbose = verbose
86
97
  self.warned_about_bands = False
87
98
 
@@ -191,7 +202,10 @@ class CustomDataset(NonGeoDataset):
191
202
  if not self.warned_about_bands and self.verbose:
192
203
  print(f"Image has {image.shape[0]} bands, using first 3 bands only")
193
204
  self.warned_about_bands = True
194
- image = image[:3]
205
+ if self.band_indexes is not None:
206
+ image = image[self.band_indexes]
207
+ else:
208
+ image = image[:3]
195
209
  elif image.shape[0] < 3:
196
210
  # If image has fewer than 3 bands, duplicate the last band to make 3
197
211
  if not self.warned_about_bands and self.verbose:
@@ -594,7 +608,7 @@ class ObjectDetector:
594
608
 
595
609
  Args:
596
610
  mask_path: Path to the object masks GeoTIFF
597
- output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
611
+ output_path: Path to save the output GeoJSON or Parquet file (default: mask_path with .geojson extension)
598
612
  simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
599
613
  mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
600
614
  min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
@@ -779,7 +793,10 @@ class ObjectDetector:
779
793
 
780
794
  # Save to file
781
795
  if output_path:
782
- gdf.to_file(output_path)
796
+ if output_path.endswith(".parquet"):
797
+ gdf.to_parquet(output_path)
798
+ else:
799
+ gdf.to_file(output_path)
783
800
  print(f"Saved {len(gdf)} objects to {output_path}")
784
801
 
785
802
  return gdf
@@ -792,6 +809,7 @@ class ObjectDetector:
792
809
  batch_size=4,
793
810
  filter_edges=True,
794
811
  edge_buffer=20,
812
+ band_indexes=None,
795
813
  **kwargs,
796
814
  ):
797
815
  """
@@ -799,10 +817,11 @@ class ObjectDetector:
799
817
 
800
818
  Args:
801
819
  raster_path: Path to input raster file
802
- output_path: Path to output GeoJSON file (optional)
820
+ output_path: Path to output GeoJSON or Parquet file (optional)
803
821
  batch_size: Batch size for processing
804
822
  filter_edges: Whether to filter out objects at the edges of the image
805
823
  edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
824
+ band_indexes: List of band indexes to use (if None, use all bands)
806
825
  **kwargs: Additional parameters:
807
826
  confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
808
827
  overlap: Overlap between adjacent tiles (0.0-1.0)
@@ -843,7 +862,10 @@ class ObjectDetector:
843
862
 
844
863
  # Create dataset
845
864
  dataset = CustomDataset(
846
- raster_path=raster_path, chip_size=chip_size, overlap=overlap
865
+ raster_path=raster_path,
866
+ chip_size=chip_size,
867
+ overlap=overlap,
868
+ band_indexes=band_indexes,
847
869
  )
848
870
  self.raster_stats = dataset.raster_stats
849
871
 
@@ -1021,7 +1043,10 @@ class ObjectDetector:
1021
1043
 
1022
1044
  # Save to file if requested
1023
1045
  if output_path:
1024
- gdf.to_file(output_path, driver="GeoJSON")
1046
+ if output_path.endswith(".parquet"):
1047
+ gdf.to_parquet(output_path)
1048
+ else:
1049
+ gdf.to_file(output_path, driver="GeoJSON")
1025
1050
  print(f"Saved {len(gdf)} objects to {output_path}")
1026
1051
 
1027
1052
  return gdf
@@ -1881,108 +1906,36 @@ class ObjectDetector:
1881
1906
  plt.savefig(sample_output, dpi=300, bbox_inches="tight")
1882
1907
  print(f"Sample visualization saved to {sample_output}")
1883
1908
 
1884
-
1885
- class BuildingFootprintExtractor(ObjectDetector):
1886
- """
1887
- Building footprint extraction using a pre-trained Mask R-CNN model.
1888
-
1889
- This class extends the
1890
- `ObjectDetector` class with additional methods for building footprint extraction."
1891
- """
1892
-
1893
- def __init__(
1894
- self,
1895
- model_path="building_footprints_usa.pth",
1896
- repo_id=None,
1897
- model=None,
1898
- device=None,
1899
- ):
1900
- """
1901
- Initialize the object extractor.
1902
-
1903
- Args:
1904
- model_path: Path to the .pth model file.
1905
- repo_id: Repo ID for loading models from the Hub.
1906
- model: Custom model to use for inference.
1907
- device: Device to use for inference ('cuda:0', 'cpu', etc.).
1908
- """
1909
- super().__init__(
1910
- model_path=model_path, repo_id=repo_id, model=model, device=device
1911
- )
1912
-
1913
- def regularize_buildings(
1914
- self,
1915
- gdf,
1916
- min_area=10,
1917
- angle_threshold=15,
1918
- orthogonality_threshold=0.3,
1919
- rectangularity_threshold=0.7,
1920
- ):
1921
- """
1922
- Regularize building footprints to enforce right angles and rectangular shapes.
1923
-
1924
- Args:
1925
- gdf: GeoDataFrame with building footprints
1926
- min_area: Minimum area in square units to keep a building
1927
- angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
1928
- orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
1929
- rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
1930
-
1931
- Returns:
1932
- GeoDataFrame with regularized building footprints
1933
- """
1934
- return self.regularize_objects(
1935
- gdf,
1936
- min_area=min_area,
1937
- angle_threshold=angle_threshold,
1938
- orthogonality_threshold=orthogonality_threshold,
1939
- rectangularity_threshold=rectangularity_threshold,
1940
- )
1941
-
1942
-
1943
- class CarDetector(ObjectDetector):
1944
- """
1945
- Car detection using a pre-trained Mask R-CNN model.
1946
-
1947
- This class extends the `ObjectDetector` class with additional methods for car detection.
1948
- """
1949
-
1950
- def __init__(
1951
- self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
1952
- ):
1953
- """
1954
- Initialize the object extractor.
1955
-
1956
- Args:
1957
- model_path: Path to the .pth model file.
1958
- repo_id: Repo ID for loading models from the Hub.
1959
- model: Custom model to use for inference.
1960
- device: Device to use for inference ('cuda:0', 'cpu', etc.).
1961
- """
1962
- super().__init__(
1963
- model_path=model_path, repo_id=repo_id, model=model, device=device
1964
- )
1965
-
1966
1909
  def generate_masks(
1967
1910
  self,
1968
1911
  raster_path,
1969
1912
  output_path=None,
1970
1913
  confidence_threshold=None,
1971
1914
  mask_threshold=None,
1915
+ min_object_area=10,
1916
+ max_object_area=float("inf"),
1972
1917
  overlap=0.25,
1973
1918
  batch_size=4,
1919
+ band_indexes=None,
1974
1920
  verbose=False,
1921
+ **kwargs,
1975
1922
  ):
1976
1923
  """
1977
1924
  Save masks with confidence values as a multi-band GeoTIFF.
1978
1925
 
1926
+ Objects with area smaller than min_object_area or larger than max_object_area
1927
+ will be filtered out.
1928
+
1979
1929
  Args:
1980
1930
  raster_path: Path to input raster
1981
1931
  output_path: Path for output GeoTIFF
1982
1932
  confidence_threshold: Minimum confidence score (0.0-1.0)
1983
1933
  mask_threshold: Threshold for mask binarization (0.0-1.0)
1934
+ min_object_area: Minimum area (in pixels) for an object to be included
1935
+ max_object_area: Maximum area (in pixels) for an object to be included
1984
1936
  overlap: Overlap between tiles (0.0-1.0)
1985
1937
  batch_size: Batch size for processing
1938
+ band_indexes: List of band indexes to use (default: all bands)
1986
1939
  verbose: Whether to print detailed processing information
1987
1940
 
1988
1941
  Returns:
@@ -1994,6 +1947,8 @@ class CarDetector(ObjectDetector):
1994
1947
  if mask_threshold is None:
1995
1948
  mask_threshold = self.mask_threshold
1996
1949
 
1950
+ chip_size = kwargs.get("chip_size", self.chip_size)
1951
+
1997
1952
  # Default output path
1998
1953
  if output_path is None:
1999
1954
  output_path = os.path.splitext(raster_path)[0] + "_masks_conf.tif"
@@ -2003,8 +1958,9 @@ class CarDetector(ObjectDetector):
2003
1958
  # Create dataset with the specified overlap
2004
1959
  dataset = CustomDataset(
2005
1960
  raster_path=raster_path,
2006
- chip_size=self.chip_size,
1961
+ chip_size=chip_size,
2007
1962
  overlap=overlap,
1963
+ band_indexes=band_indexes,
2008
1964
  verbose=verbose,
2009
1965
  )
2010
1966
 
@@ -2091,6 +2047,21 @@ class CarDetector(ObjectDetector):
2091
2047
  for mask_idx, mask in enumerate(masks):
2092
2048
  # Convert to binary mask
2093
2049
  binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
2050
+
2051
+ # Check object area - calculate number of pixels in the mask
2052
+ object_area = np.sum(binary_mask > 0)
2053
+
2054
+ # Skip objects that don't meet area criteria
2055
+ if (
2056
+ object_area < min_object_area
2057
+ or object_area > max_object_area
2058
+ ):
2059
+ if verbose:
2060
+ print(
2061
+ f"Filtering out object with area {object_area} pixels"
2062
+ )
2063
+ continue
2064
+
2094
2065
  conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
2095
2066
 
2096
2067
  # Update the mask and confidence arrays
@@ -2135,13 +2106,24 @@ class CarDetector(ObjectDetector):
2135
2106
  print(f"Masks with confidence values saved to {output_path}")
2136
2107
  return output_path
2137
2108
 
2138
- def vectorize_masks(self, masks_path, output_path=None, **kwargs):
2109
+ def vectorize_masks(
2110
+ self,
2111
+ masks_path,
2112
+ output_path=None,
2113
+ confidence_threshold=0.5,
2114
+ min_object_area=100,
2115
+ max_object_size=None,
2116
+ **kwargs,
2117
+ ):
2139
2118
  """
2140
2119
  Convert masks with confidence to vector polygons.
2141
2120
 
2142
2121
  Args:
2143
- masks_path: Path to masks GeoTIFF with confidence band
2144
- output_path: Path for output GeoJSON
2122
+ masks_path: Path to masks GeoTIFF with confidence band.
2123
+ output_path: Path for output GeoJSON.
2124
+ confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
2125
+ min_object_area: Minimum area in pixels to keep an object. Default: 100
2126
+ max_object_size: Maximum area in pixels to keep an object. Default: None
2145
2127
  **kwargs: Additional parameters
2146
2128
 
2147
2129
  Returns:
@@ -2182,6 +2164,10 @@ class CarDetector(ObjectDetector):
2182
2164
  else:
2183
2165
  confidence = 0.0
2184
2166
 
2167
+ # Skip if confidence is below threshold
2168
+ if confidence < confidence_threshold:
2169
+ continue
2170
+
2185
2171
  # Find contours
2186
2172
  contours, _ = cv2.findContours(
2187
2173
  component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
@@ -2190,9 +2176,13 @@ class CarDetector(ObjectDetector):
2190
2176
  for contour in contours:
2191
2177
  # Filter by size
2192
2178
  area = cv2.contourArea(contour)
2193
- if area < kwargs.get("min_object_area", 100):
2179
+ if area < min_object_area:
2194
2180
  continue
2195
2181
 
2182
+ if max_object_size is not None:
2183
+ if area > max_object_size:
2184
+ continue
2185
+
2196
2186
  # Get minimum area rectangle
2197
2187
  rect = cv2.minAreaRect(contour)
2198
2188
  box_points = cv2.boxPoints(rect)
@@ -2224,7 +2214,7 @@ class CarDetector(ObjectDetector):
2224
2214
  # Save to file if requested
2225
2215
  if output_path:
2226
2216
  gdf.to_file(output_path, driver="GeoJSON")
2227
- print(f"Saved {len(gdf)} cars with confidence to {output_path}")
2217
+ print(f"Saved {len(gdf)} objects with confidence to {output_path}")
2228
2218
 
2229
2219
  return gdf
2230
2220
  else:
@@ -2232,6 +2222,88 @@ class CarDetector(ObjectDetector):
2232
2222
  return None
2233
2223
 
2234
2224
 
2225
+ class BuildingFootprintExtractor(ObjectDetector):
2226
+ """
2227
+ Building footprint extraction using a pre-trained Mask R-CNN model.
2228
+
2229
+ This class extends the
2230
+ `ObjectDetector` class with additional methods for building footprint extraction."
2231
+ """
2232
+
2233
+ def __init__(
2234
+ self,
2235
+ model_path="building_footprints_usa.pth",
2236
+ repo_id=None,
2237
+ model=None,
2238
+ device=None,
2239
+ ):
2240
+ """
2241
+ Initialize the object extractor.
2242
+
2243
+ Args:
2244
+ model_path: Path to the .pth model file.
2245
+ repo_id: Repo ID for loading models from the Hub.
2246
+ model: Custom model to use for inference.
2247
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
2248
+ """
2249
+ super().__init__(
2250
+ model_path=model_path, repo_id=repo_id, model=model, device=device
2251
+ )
2252
+
2253
+ def regularize_buildings(
2254
+ self,
2255
+ gdf,
2256
+ min_area=10,
2257
+ angle_threshold=15,
2258
+ orthogonality_threshold=0.3,
2259
+ rectangularity_threshold=0.7,
2260
+ ):
2261
+ """
2262
+ Regularize building footprints to enforce right angles and rectangular shapes.
2263
+
2264
+ Args:
2265
+ gdf: GeoDataFrame with building footprints
2266
+ min_area: Minimum area in square units to keep a building
2267
+ angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
2268
+ orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
2269
+ rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
2270
+
2271
+ Returns:
2272
+ GeoDataFrame with regularized building footprints
2273
+ """
2274
+ return self.regularize_objects(
2275
+ gdf,
2276
+ min_area=min_area,
2277
+ angle_threshold=angle_threshold,
2278
+ orthogonality_threshold=orthogonality_threshold,
2279
+ rectangularity_threshold=rectangularity_threshold,
2280
+ )
2281
+
2282
+
2283
+ class CarDetector(ObjectDetector):
2284
+ """
2285
+ Car detection using a pre-trained Mask R-CNN model.
2286
+
2287
+ This class extends the `ObjectDetector` class with additional methods for car detection.
2288
+ """
2289
+
2290
+ def __init__(
2291
+ self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
2292
+ ):
2293
+ """
2294
+ Initialize the object extractor.
2295
+
2296
+ Args:
2297
+ model_path: Path to the .pth model file.
2298
+ repo_id: Repo ID for loading models from the Hub.
2299
+ model: Custom model to use for inference.
2300
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
2301
+ """
2302
+ super().__init__(
2303
+ model_path=model_path, repo_id=repo_id, model=model, device=device
2304
+ )
2305
+
2306
+
2235
2307
  class ShipDetector(ObjectDetector):
2236
2308
  """
2237
2309
  Ship detection using a pre-trained Mask R-CNN model.
@@ -2255,3 +2327,32 @@ class ShipDetector(ObjectDetector):
2255
2327
  super().__init__(
2256
2328
  model_path=model_path, repo_id=repo_id, model=model, device=device
2257
2329
  )
2330
+
2331
+
2332
+ class SolarPanelDetector(ObjectDetector):
2333
+ """
2334
+ Solar panel detection using a pre-trained Mask R-CNN model.
2335
+
2336
+ This class extends the
2337
+ `ObjectDetector` class with additional methods for solar panel detection."
2338
+ """
2339
+
2340
+ def __init__(
2341
+ self,
2342
+ model_path="solar_panel_detection.pth",
2343
+ repo_id=None,
2344
+ model=None,
2345
+ device=None,
2346
+ ):
2347
+ """
2348
+ Initialize the object extractor.
2349
+
2350
+ Args:
2351
+ model_path: Path to the .pth model file.
2352
+ repo_id: Repo ID for loading models from the Hub.
2353
+ model: Custom model to use for inference.
2354
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
2355
+ """
2356
+ super().__init__(
2357
+ model_path=model_path, repo_id=repo_id, model=model, device=device
2358
+ )
geoai/geoai.py CHANGED
@@ -2,3 +2,4 @@
2
2
 
3
3
  from .utils import *
4
4
  from .extract import *
5
+ from .segment import *