geoai-py 0.3.1__py2.py3-none-any.whl → 0.3.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.3.1"
5
+ __version__ = "0.3.2"
6
6
 
7
7
 
8
8
  import os
geoai/extract.py CHANGED
@@ -23,9 +23,9 @@ except ImportError as e:
23
23
  )
24
24
 
25
25
 
26
- class BuildingFootprintDataset(NonGeoDataset):
26
+ class CustomDataset(NonGeoDataset):
27
27
  """
28
- A TorchGeo dataset for building footprint extraction.
28
+ A TorchGeo dataset for object extraction.
29
29
  Using NonGeoDataset to avoid spatial indexing issues.
30
30
  """
31
31
 
@@ -170,18 +170,20 @@ class BuildingFootprintDataset(NonGeoDataset):
170
170
  return self.rows * self.cols
171
171
 
172
172
 
173
- class BuildingFootprintExtractor:
173
+ class ObjectDetector:
174
174
  """
175
- Building footprint extraction using Mask R-CNN with TorchGeo.
175
+ Object extraction using Mask R-CNN with TorchGeo.
176
176
  """
177
177
 
178
- def __init__(self, model_path=None, device=None):
178
+ def __init__(self, model_path=None, repo_id=None, model=None, device=None):
179
179
  """
180
- Initialize the building footprint extractor.
180
+ Initialize the object extractor.
181
181
 
182
182
  Args:
183
- model_path: Path to the .pth model file
184
- device: Device to use for inference ('cuda:0', 'cpu', etc.)
183
+ model_path: Path to the .pth model file.
184
+ repo_id: Hugging Face repository ID for model download.
185
+ model: Pre-initialized model object (optional).
186
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
185
187
  """
186
188
  # Set device
187
189
  if device is None:
@@ -189,31 +191,35 @@ class BuildingFootprintExtractor:
189
191
  else:
190
192
  self.device = torch.device(device)
191
193
 
192
- # Default parameters for building detection - these can be overridden in process_raster
194
+ # Default parameters for object detection - these can be overridden in process_raster
193
195
  self.chip_size = (512, 512) # Size of image chips for processing
194
196
  self.overlap = 0.25 # Default overlap between tiles
195
197
  self.confidence_threshold = 0.5 # Default confidence threshold
196
198
  self.nms_iou_threshold = 0.5 # IoU threshold for non-maximum suppression
197
- self.small_building_area = 100 # Minimum area in pixels to keep a building
199
+ self.small_object_area = 100 # Minimum area in pixels to keep an object
198
200
  self.mask_threshold = 0.5 # Threshold for mask binarization
199
201
  self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
200
202
 
201
203
  # Initialize model
202
- self.model = self._initialize_model()
204
+ self.model = self.initialize_model(model)
203
205
 
204
206
  # Download model if needed
205
- if model_path is None:
206
- model_path = self._download_model_from_hf()
207
+ if model_path is None or (not os.path.exists(model_path)):
208
+ model_path = self.download_model_from_hf(model_path, repo_id)
207
209
 
208
210
  # Load model weights
209
- self._load_weights(model_path)
211
+ self.load_weights(model_path)
210
212
 
211
213
  # Set model to evaluation mode
212
214
  self.model.eval()
213
215
 
214
- def _download_model_from_hf(self):
216
+ def download_model_from_hf(self, model_path=None, repo_id=None):
215
217
  """
216
- Download the USA building footprints model from Hugging Face.
218
+ Download the object detection model from Hugging Face.
219
+
220
+ Args:
221
+ model_path: Path to the model file.
222
+ repo_id: Hugging Face repository ID.
217
223
 
218
224
  Returns:
219
225
  Path to the downloaded model file
@@ -223,17 +229,14 @@ class BuildingFootprintExtractor:
223
229
  print("Model path not specified, downloading from Hugging Face...")
224
230
 
225
231
  # Define the repository ID and model filename
226
- repo_id = "giswqs/geoai" # Update with your actual username/repo
227
- filename = "building_footprints_usa.pth"
232
+ if repo_id is None:
233
+ repo_id = "giswqs/geoai"
228
234
 
229
- # Ensure cache directory exists
230
- # cache_dir = os.path.join(
231
- # os.path.expanduser("~"), ".cache", "building_footprints"
232
- # )
233
- # os.makedirs(cache_dir, exist_ok=True)
235
+ if model_path is None:
236
+ model_path = "building_footprints_usa.pth"
234
237
 
235
238
  # Download the model
236
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)
239
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
237
240
  print(f"Model downloaded to: {model_path}")
238
241
 
239
242
  return model_path
@@ -243,28 +246,36 @@ class BuildingFootprintExtractor:
243
246
  print("Please specify a local model path or ensure internet connectivity.")
244
247
  raise
245
248
 
246
- def _initialize_model(self):
247
- """Initialize Mask R-CNN model with ResNet50 backbone."""
248
- # Standard image mean and std for pre-trained models
249
- # Note: This would normally come from your config file
250
- image_mean = [0.485, 0.456, 0.406]
251
- image_std = [0.229, 0.224, 0.225]
252
-
253
- # Create model with explicit normalization parameters
254
- model = maskrcnn_resnet50_fpn(
255
- weights=None,
256
- progress=False,
257
- num_classes=2, # Background + building
258
- weights_backbone=None,
259
- # These parameters ensure consistent normalization
260
- image_mean=image_mean,
261
- image_std=image_std,
262
- )
249
+ def initialize_model(self, model):
250
+ """Initialize a deep learning model for object detection.
251
+
252
+ Args:
253
+ model (torch.nn.Module): A pre-initialized model object.
254
+
255
+ Returns:
256
+ torch.nn.Module: A deep learning model for object detection.
257
+ """
258
+
259
+ if model is None: # Initialize Mask R-CNN model with ResNet50 backbone.
260
+ # Standard image mean and std for pre-trained models
261
+ image_mean = [0.485, 0.456, 0.406]
262
+ image_std = [0.229, 0.224, 0.225]
263
+
264
+ # Create model with explicit normalization parameters
265
+ model = maskrcnn_resnet50_fpn(
266
+ weights=None,
267
+ progress=False,
268
+ num_classes=2, # Background + object
269
+ weights_backbone=None,
270
+ # These parameters ensure consistent normalization
271
+ image_mean=image_mean,
272
+ image_std=image_std,
273
+ )
263
274
 
264
275
  model.to(self.device)
265
276
  return model
266
277
 
267
- def _load_weights(self, model_path):
278
+ def load_weights(self, model_path):
268
279
  """
269
280
  Load weights from file with error handling for different formats.
270
281
 
@@ -306,7 +317,7 @@ class BuildingFootprintExtractor:
306
317
  except Exception as e:
307
318
  raise RuntimeError(f"Failed to load model: {e}")
308
319
 
309
- def _mask_to_polygons(self, mask, **kwargs):
320
+ def mask_to_polygons(self, mask, **kwargs):
310
321
  """
311
322
  Convert binary mask to polygon contours using OpenCV.
312
323
 
@@ -315,7 +326,7 @@ class BuildingFootprintExtractor:
315
326
  **kwargs: Optional parameters:
316
327
  simplify_tolerance: Tolerance for polygon simplification
317
328
  mask_threshold: Threshold for mask binarization
318
- small_building_area: Minimum area in pixels to keep a building
329
+ small_object_area: Minimum area in pixels to keep an object
319
330
 
320
331
  Returns:
321
332
  List of polygons as lists of (x, y) coordinates
@@ -324,9 +335,7 @@ class BuildingFootprintExtractor:
324
335
  # Get parameters from kwargs or use instance defaults
325
336
  simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
326
337
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
327
- small_building_area = kwargs.get(
328
- "small_building_area", self.small_building_area
329
- )
338
+ small_object_area = kwargs.get("small_object_area", self.small_object_area)
330
339
 
331
340
  # Ensure binary mask
332
341
  mask = (mask > mask_threshold).astype(np.uint8)
@@ -342,7 +351,7 @@ class BuildingFootprintExtractor:
342
351
  polygons = []
343
352
  for contour in contours:
344
353
  # Filter out too small contours
345
- if contour.shape[0] < 3 or cv2.contourArea(contour) < small_building_area:
354
+ if contour.shape[0] < 3 or cv2.contourArea(contour) < small_object_area:
346
355
  continue
347
356
 
348
357
  # Simplify contour if it has many points
@@ -356,7 +365,7 @@ class BuildingFootprintExtractor:
356
365
 
357
366
  return polygons
358
367
 
359
- def _filter_overlapping_polygons(self, gdf, **kwargs):
368
+ def filter_overlapping_polygons(self, gdf, **kwargs):
360
369
  """
361
370
  Filter overlapping polygons using non-maximum suppression.
362
371
 
@@ -413,26 +422,26 @@ class BuildingFootprintExtractor:
413
422
 
414
423
  return gdf.iloc[keep_indices]
415
424
 
416
- def filter_edge_buildings(self, gdf, raster_path, edge_buffer=10):
425
+ def filter_edge_objects(self, gdf, raster_path, edge_buffer=10):
417
426
  """
418
- Filter out building detections that fall in padding/edge areas of the image.
427
+ Filter out object detections that fall in padding/edge areas of the image.
419
428
 
420
429
  Args:
421
- gdf: GeoDataFrame with building footprint detections
430
+ gdf: GeoDataFrame with object detections
422
431
  raster_path: Path to the original raster file
423
432
  edge_buffer: Buffer in pixels to consider as edge region
424
433
 
425
434
  Returns:
426
- GeoDataFrame with filtered building footprints
435
+ GeoDataFrame with filtered objects
427
436
  """
428
437
  import rasterio
429
438
  from shapely.geometry import box
430
439
 
431
- # If no buildings detected, return empty GeoDataFrame
440
+ # If no objects detected, return empty GeoDataFrame
432
441
  if gdf is None or len(gdf) == 0:
433
442
  return gdf
434
443
 
435
- print(f"Buildings before filtering: {len(gdf)}")
444
+ print(f"Objects before filtering: {len(gdf)}")
436
445
 
437
446
  with rasterio.open(raster_path) as src:
438
447
  # Get raster bounds
@@ -461,18 +470,18 @@ class BuildingFootprintExtractor:
461
470
  else:
462
471
  inner_box = box(*inner_bounds)
463
472
 
464
- # Filter out buildings that intersect with the edge of the image
473
+ # Filter out objects that intersect with the edge of the image
465
474
  filtered_gdf = gdf[gdf.intersects(inner_box)]
466
475
 
467
- # Additional check for buildings that have >50% of their area outside the valid region
468
- valid_buildings = []
476
+ # Additional check for objects that have >50% of their area outside the valid region
477
+ valid_objects = []
469
478
  for idx, row in filtered_gdf.iterrows():
470
479
  if row.geometry.intersection(inner_box).area >= 0.5 * row.geometry.area:
471
- valid_buildings.append(idx)
480
+ valid_objects.append(idx)
472
481
 
473
- filtered_gdf = filtered_gdf.loc[valid_buildings]
482
+ filtered_gdf = filtered_gdf.loc[valid_objects]
474
483
 
475
- print(f"Buildings after filtering: {len(filtered_gdf)}")
484
+ print(f"Objects after filtering: {len(filtered_gdf)}")
476
485
 
477
486
  return filtered_gdf
478
487
 
@@ -482,28 +491,28 @@ class BuildingFootprintExtractor:
482
491
  output_path=None,
483
492
  simplify_tolerance=None,
484
493
  mask_threshold=None,
485
- small_building_area=None,
494
+ small_object_area=None,
486
495
  nms_iou_threshold=None,
487
496
  regularize=True,
488
497
  angle_threshold=15,
489
498
  rectangularity_threshold=0.7,
490
499
  ):
491
500
  """
492
- Convert a building mask GeoTIFF to vector polygons and save as GeoJSON.
501
+ Convert an object mask GeoTIFF to vector polygons and save as GeoJSON.
493
502
 
494
503
  Args:
495
- mask_path: Path to the building masks GeoTIFF
504
+ mask_path: Path to the object masks GeoTIFF
496
505
  output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
497
506
  simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
498
507
  mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
499
- small_building_area: Minimum area in pixels to keep a building (default: self.small_building_area)
508
+ small_object_area: Minimum area in pixels to keep an object (default: self.small_object_area)
500
509
  nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
501
- regularize: Whether to regularize buildings to right angles (default: True)
510
+ regularize: Whether to regularize objects to right angles (default: True)
502
511
  angle_threshold: Maximum deviation from 90 degrees for regularization (default: 15)
503
512
  rectangularity_threshold: Threshold for rectangle simplification (default: 0.7)
504
513
 
505
514
  Returns:
506
- GeoDataFrame with building footprints
515
+ GeoDataFrame with objects
507
516
  """
508
517
  # Use class defaults if parameters not provided
509
518
  simplify_tolerance = (
@@ -514,10 +523,10 @@ class BuildingFootprintExtractor:
514
523
  mask_threshold = (
515
524
  mask_threshold if mask_threshold is not None else self.mask_threshold
516
525
  )
517
- small_building_area = (
518
- small_building_area
519
- if small_building_area is not None
520
- else self.small_building_area
526
+ small_object_area = (
527
+ small_object_area
528
+ if small_object_area is not None
529
+ else self.small_object_area
521
530
  )
522
531
  nms_iou_threshold = (
523
532
  nms_iou_threshold
@@ -531,10 +540,10 @@ class BuildingFootprintExtractor:
531
540
 
532
541
  print(f"Converting mask to GeoJSON with parameters:")
533
542
  print(f"- Mask threshold: {mask_threshold}")
534
- print(f"- Min building area: {small_building_area}")
543
+ print(f"- Min object area: {small_object_area}")
535
544
  print(f"- Simplify tolerance: {simplify_tolerance}")
536
545
  print(f"- NMS IoU threshold: {nms_iou_threshold}")
537
- print(f"- Regularize buildings: {regularize}")
546
+ print(f"- Regularize objects: {regularize}")
538
547
  if regularize:
539
548
  print(f"- Angle threshold: {angle_threshold}° from 90°")
540
549
  print(f"- Rectangularity threshold: {rectangularity_threshold*100}%")
@@ -564,7 +573,7 @@ class BuildingFootprintExtractor:
564
573
  )
565
574
 
566
575
  print(
567
- f"Found {num_labels-1} potential buildings"
576
+ f"Found {num_labels-1} potential objects"
568
577
  ) # Subtract 1 for background
569
578
 
570
579
  # Create list to store polygons and confidence values
@@ -573,19 +582,19 @@ class BuildingFootprintExtractor:
573
582
 
574
583
  # Process each component (skip the first one which is background)
575
584
  for i in tqdm(range(1, num_labels)):
576
- # Extract this building
585
+ # Extract this object
577
586
  area = stats[i, cv2.CC_STAT_AREA]
578
587
 
579
588
  # Skip if too small
580
- if area < small_building_area:
589
+ if area < small_object_area:
581
590
  continue
582
591
 
583
- # Create a mask for this building
584
- building_mask = (labels == i).astype(np.uint8)
592
+ # Create a mask for this object
593
+ object_mask = (labels == i).astype(np.uint8)
585
594
 
586
595
  # Find contours
587
596
  contours, _ = cv2.findContours(
588
- building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
597
+ object_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
589
598
  )
590
599
 
591
600
  # Process each contour
@@ -633,17 +642,17 @@ class BuildingFootprintExtractor:
633
642
  {
634
643
  "geometry": all_polygons,
635
644
  "confidence": all_confidences,
636
- "class": 1, # Building class
645
+ "class": 1, # Object class
637
646
  },
638
647
  crs=crs,
639
648
  )
640
649
 
641
650
  # Apply non-maximum suppression to remove overlapping polygons
642
- gdf = self._filter_overlapping_polygons(
651
+ gdf = self.filter_overlapping_polygons(
643
652
  gdf, nms_iou_threshold=nms_iou_threshold
644
653
  )
645
654
 
646
- print(f"Building count after NMS filtering: {len(gdf)}")
655
+ print(f"Object count after NMS filtering: {len(gdf)}")
647
656
 
648
657
  # Apply regularization if requested
649
658
  if regularize and len(gdf) > 0:
@@ -661,8 +670,8 @@ class BuildingFootprintExtractor:
661
670
  # Use 10 pixels as minimum area in geographic units
662
671
  min_geo_area = 10 * avg_pixel_area
663
672
 
664
- # Regularize buildings
665
- gdf = self.regularize_buildings(
673
+ # Regularize objects
674
+ gdf = self.regularize_objects(
666
675
  gdf,
667
676
  min_area=min_geo_area,
668
677
  angle_threshold=angle_threshold,
@@ -672,7 +681,7 @@ class BuildingFootprintExtractor:
672
681
  # Save to file
673
682
  if output_path:
674
683
  gdf.to_file(output_path)
675
- print(f"Saved {len(gdf)} building footprints to {output_path}")
684
+ print(f"Saved {len(gdf)} objects to {output_path}")
676
685
 
677
686
  return gdf
678
687
 
@@ -687,25 +696,25 @@ class BuildingFootprintExtractor:
687
696
  **kwargs,
688
697
  ):
689
698
  """
690
- Process a raster file to extract building footprints with customizable parameters.
699
+ Process a raster file to extract objects with customizable parameters.
691
700
 
692
701
  Args:
693
702
  raster_path: Path to input raster file
694
703
  output_path: Path to output GeoJSON file (optional)
695
704
  batch_size: Batch size for processing
696
- filter_edges: Whether to filter out buildings at the edges of the image
697
- edge_buffer: Size of edge buffer in pixels to filter out buildings (if filter_edges=True)
705
+ filter_edges: Whether to filter out objects at the edges of the image
706
+ edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
698
707
  **kwargs: Additional parameters:
699
708
  confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
700
709
  overlap: Overlap between adjacent tiles (0.0-1.0)
701
710
  chip_size: Size of image chips for processing (height, width)
702
711
  nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
703
712
  mask_threshold: Threshold for mask binarization (0.0-1.0)
704
- small_building_area: Minimum area in pixels to keep a building
713
+ small_object_area: Minimum area in pixels to keep an object
705
714
  simplify_tolerance: Tolerance for polygon simplification
706
715
 
707
716
  Returns:
708
- GeoDataFrame with building footprints
717
+ GeoDataFrame with objects
709
718
  """
710
719
  # Get parameters from kwargs or use instance defaults
711
720
  confidence_threshold = kwargs.get(
@@ -715,9 +724,7 @@ class BuildingFootprintExtractor:
715
724
  chip_size = kwargs.get("chip_size", self.chip_size)
716
725
  nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
717
726
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
718
- small_building_area = kwargs.get(
719
- "small_building_area", self.small_building_area
720
- )
727
+ small_object_area = kwargs.get("small_object_area", self.small_object_area)
721
728
  simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
722
729
 
723
730
  # Print parameters being used
@@ -727,14 +734,14 @@ class BuildingFootprintExtractor:
727
734
  print(f"- Chip size: {chip_size}")
728
735
  print(f"- NMS IoU threshold: {nms_iou_threshold}")
729
736
  print(f"- Mask threshold: {mask_threshold}")
730
- print(f"- Min building area: {small_building_area}")
737
+ print(f"- Min object area: {small_object_area}")
731
738
  print(f"- Simplify tolerance: {simplify_tolerance}")
732
- print(f"- Filter edge buildings: {filter_edges}")
739
+ print(f"- Filter edge objects: {filter_edges}")
733
740
  if filter_edges:
734
741
  print(f"- Edge buffer size: {edge_buffer} pixels")
735
742
 
736
743
  # Create dataset
737
- dataset = BuildingFootprintDataset(raster_path=raster_path, chip_size=chip_size)
744
+ dataset = CustomDataset(raster_path=raster_path, chip_size=chip_size)
738
745
  self.raster_stats = dataset.raster_stats
739
746
 
740
747
  # Custom collate function to handle Shapely objects
@@ -854,11 +861,11 @@ class BuildingFootprintExtractor:
854
861
  binary_mask = mask[0] # Get binary mask
855
862
 
856
863
  # Convert mask to polygon with custom parameters
857
- contours = self._mask_to_polygons(
864
+ contours = self.mask_to_polygons(
858
865
  binary_mask,
859
866
  simplify_tolerance=simplify_tolerance,
860
867
  mask_threshold=mask_threshold,
861
- small_building_area=small_building_area,
868
+ small_object_area=small_object_area,
862
869
  )
863
870
 
864
871
  # Skip if no valid polygons
@@ -896,24 +903,22 @@ class BuildingFootprintExtractor:
896
903
  {
897
904
  "geometry": all_polygons,
898
905
  "confidence": all_scores,
899
- "class": 1, # Building class
906
+ "class": 1, # Object class
900
907
  },
901
908
  crs=dataset.crs,
902
909
  )
903
910
 
904
911
  # Remove overlapping polygons with custom threshold
905
- gdf = self._filter_overlapping_polygons(
906
- gdf, nms_iou_threshold=nms_iou_threshold
907
- )
912
+ gdf = self.filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)
908
913
 
909
- # Filter edge buildings if requested
914
+ # Filter edge objects if requested
910
915
  if filter_edges:
911
- gdf = self.filter_edge_buildings(gdf, raster_path, edge_buffer=edge_buffer)
916
+ gdf = self.filter_edge_objects(gdf, raster_path, edge_buffer=edge_buffer)
912
917
 
913
918
  # Save to file if requested
914
919
  if output_path:
915
920
  gdf.to_file(output_path, driver="GeoJSON")
916
- print(f"Saved {len(gdf)} building footprints to {output_path}")
921
+ print(f"Saved {len(gdf)} objects to {output_path}")
917
922
 
918
923
  return gdf
919
924
 
@@ -921,7 +926,7 @@ class BuildingFootprintExtractor:
921
926
  self, raster_path, output_path=None, batch_size=4, verbose=False, **kwargs
922
927
  ):
923
928
  """
924
- Process a raster file to extract building footprint masks and save as GeoTIFF.
929
+ Process a raster file to extract object masks and save as GeoTIFF.
925
930
 
926
931
  Args:
927
932
  raster_path: Path to input raster file
@@ -955,7 +960,7 @@ class BuildingFootprintExtractor:
955
960
  print(f"- Mask threshold: {mask_threshold}")
956
961
 
957
962
  # Create dataset
958
- dataset = BuildingFootprintDataset(
963
+ dataset = CustomDataset(
959
964
  raster_path=raster_path, chip_size=chip_size, verbose=verbose
960
965
  )
961
966
 
@@ -972,7 +977,7 @@ class BuildingFootprintExtractor:
972
977
  output_profile = src.profile.copy()
973
978
  output_profile.update(
974
979
  dtype=rasterio.uint8,
975
- count=1, # Single band for building mask
980
+ count=1, # Single band for object mask
976
981
  compress="lzw",
977
982
  nodata=0,
978
983
  )
@@ -1144,10 +1149,10 @@ class BuildingFootprintExtractor:
1144
1149
  # Write the final mask to the output file
1145
1150
  dst.write(mask_array, 1)
1146
1151
 
1147
- print(f"Building masks saved to {output_path}")
1152
+ print(f"Object masks saved to {output_path}")
1148
1153
  return output_path
1149
1154
 
1150
- def regularize_buildings(
1155
+ def regularize_objects(
1151
1156
  self,
1152
1157
  gdf,
1153
1158
  min_area=10,
@@ -1156,17 +1161,17 @@ class BuildingFootprintExtractor:
1156
1161
  rectangularity_threshold=0.7,
1157
1162
  ):
1158
1163
  """
1159
- Regularize building footprints to enforce right angles and rectangular shapes.
1164
+ Regularize objects to enforce right angles and rectangular shapes.
1160
1165
 
1161
1166
  Args:
1162
- gdf: GeoDataFrame with building footprints
1163
- min_area: Minimum area in square units to keep a building
1167
+ gdf: GeoDataFrame with objects
1168
+ min_area: Minimum area in square units to keep an object
1164
1169
  angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
1165
- orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
1166
- rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
1170
+ orthogonality_threshold: Percentage of angles that must be orthogonal for an object to be regularized
1171
+ rectangularity_threshold: Minimum area ratio to Object's oriented bounding box for rectangular simplification
1167
1172
 
1168
1173
  Returns:
1169
- GeoDataFrame with regularized building footprints
1174
+ GeoDataFrame with regularized objects
1170
1175
  """
1171
1176
  import numpy as np
1172
1177
  from shapely.geometry import Polygon, MultiPolygon, box
@@ -1281,10 +1286,10 @@ class BuildingFootprintExtractor:
1281
1286
  return rect
1282
1287
 
1283
1288
  if gdf is None or len(gdf) == 0:
1284
- print("No buildings to regularize")
1289
+ print("No Objects to regularize")
1285
1290
  return gdf
1286
1291
 
1287
- print(f"Regularizing {len(gdf)} building footprints...")
1292
+ print(f"Regularizing {len(gdf)} objects...")
1288
1293
  print(f"- Angle threshold: {angle_threshold}° from 90°")
1289
1294
  print(f"- Min orthogonality: {orthogonality_threshold*100}% of angles")
1290
1295
  print(
@@ -1295,11 +1300,11 @@ class BuildingFootprintExtractor:
1295
1300
  result_gdf = gdf.copy()
1296
1301
 
1297
1302
  # Track statistics
1298
- total_buildings = len(gdf)
1303
+ total_objects = len(gdf)
1299
1304
  regularized_count = 0
1300
1305
  rectangularized_count = 0
1301
1306
 
1302
- # Process each building
1307
+ # Process each Object
1303
1308
  for idx, row in tqdm(gdf.iterrows(), total=len(gdf)):
1304
1309
  geom = row.geometry
1305
1310
 
@@ -1314,7 +1319,7 @@ class BuildingFootprintExtractor:
1314
1319
  continue
1315
1320
  geom = list(geom.geoms)[np.argmax(areas)]
1316
1321
 
1317
- # Filter out tiny buildings
1322
+ # Filter out tiny Objects
1318
1323
  if geom.area < min_area:
1319
1324
  continue
1320
1325
 
@@ -1331,33 +1336,33 @@ class BuildingFootprintExtractor:
1331
1336
 
1332
1337
  # Decide how to regularize
1333
1338
  if rectangularity >= rectangularity_threshold:
1334
- # Building is already quite rectangular, simplify to a rectangle
1339
+ # Object is already quite rectangular, simplify to a rectangle
1335
1340
  result_gdf.at[idx, "geometry"] = oriented_box
1336
1341
  result_gdf.at[idx, "regularized"] = "rectangle"
1337
1342
  rectangularized_count += 1
1338
1343
  elif orthogonality >= orthogonality_threshold:
1339
- # Building has many orthogonal angles but isn't rectangular
1344
+ # Object has many orthogonal angles but isn't rectangular
1340
1345
  # Could implement more sophisticated regularization here
1341
1346
  # For now, we'll still use the oriented rectangle
1342
1347
  result_gdf.at[idx, "geometry"] = oriented_box
1343
1348
  result_gdf.at[idx, "regularized"] = "orthogonal"
1344
1349
  regularized_count += 1
1345
1350
  else:
1346
- # Building doesn't have clear orthogonal structure
1351
+ # Object doesn't have clear orthogonal structure
1347
1352
  # Keep original but flag as unmodified
1348
1353
  result_gdf.at[idx, "regularized"] = "original"
1349
1354
 
1350
1355
  # Report statistics
1351
1356
  print(f"Regularization completed:")
1352
- print(f"- Total buildings: {total_buildings}")
1357
+ print(f"- Total objects: {total_objects}")
1353
1358
  print(
1354
- f"- Rectangular buildings: {rectangularized_count} ({rectangularized_count/total_buildings*100:.1f}%)"
1359
+ f"- Rectangular objects: {rectangularized_count} ({rectangularized_count/total_objects*100:.1f}%)"
1355
1360
  )
1356
1361
  print(
1357
- f"- Other regularized buildings: {regularized_count} ({regularized_count/total_buildings*100:.1f}%)"
1362
+ f"- Other regularized objects: {regularized_count} ({regularized_count/total_objects*100:.1f}%)"
1358
1363
  )
1359
1364
  print(
1360
- f"- Unmodified buildings: {total_buildings-rectangularized_count-regularized_count} ({(total_buildings-rectangularized_count-regularized_count)/total_buildings*100:.1f}%)"
1365
+ f"- Unmodified objects: {total_objects-rectangularized_count-regularized_count} ({(total_objects-rectangularized_count-regularized_count)/total_objects*100:.1f}%)"
1361
1366
  )
1362
1367
 
1363
1368
  return result_gdf
@@ -1366,14 +1371,14 @@ class BuildingFootprintExtractor:
1366
1371
  self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
1367
1372
  ):
1368
1373
  """
1369
- Visualize building detection results with proper coordinate transformation.
1374
+ Visualize object detection results with proper coordinate transformation.
1370
1375
 
1371
- This function displays building footprints on top of the raster image,
1376
+ This function displays objects on top of the raster image,
1372
1377
  ensuring proper alignment between the GeoDataFrame polygons and the image.
1373
1378
 
1374
1379
  Args:
1375
1380
  raster_path: Path to input raster
1376
- gdf: GeoDataFrame with building polygons (optional)
1381
+ gdf: GeoDataFrame with object polygons (optional)
1377
1382
  output_path: Path to save visualization (optional)
1378
1383
  figsize: Figure size (width, height) in inches
1379
1384
 
@@ -1390,7 +1395,7 @@ class BuildingFootprintExtractor:
1390
1395
  gdf = self.process_raster(raster_path)
1391
1396
 
1392
1397
  if gdf is None or len(gdf) == 0:
1393
- print("No buildings to visualize")
1398
+ print("No objects to visualize")
1394
1399
  return False
1395
1400
 
1396
1401
  # Check if confidence column exists in the GeoDataFrame
@@ -1531,7 +1536,7 @@ class BuildingFootprintExtractor:
1531
1536
  print(f"Unsupported geometry type: {geometry.geom_type}")
1532
1537
  return None
1533
1538
 
1534
- # Plot each building footprint
1539
+ # Plot each object
1535
1540
  for idx, row in gdf.iterrows():
1536
1541
  try:
1537
1542
  # Convert polygon to pixel coordinates
@@ -1593,7 +1598,7 @@ class BuildingFootprintExtractor:
1593
1598
  # Remove axes
1594
1599
  ax.set_xticks([])
1595
1600
  ax.set_yticks([])
1596
- ax.set_title(f"Building Footprints (Found: {len(gdf)})")
1601
+ ax.set_title(f"objects (Found: {len(gdf)})")
1597
1602
 
1598
1603
  # Save if requested
1599
1604
  if output_path:
@@ -1603,21 +1608,21 @@ class BuildingFootprintExtractor:
1603
1608
 
1604
1609
  plt.close()
1605
1610
 
1606
- # Create a simpler visualization focused just on a subset of buildings
1611
+ # Create a simpler visualization focused just on a subset of objects
1607
1612
  if len(gdf) > 0:
1608
1613
  plt.figure(figsize=figsize)
1609
1614
  ax = plt.gca()
1610
1615
 
1611
1616
  # Choose a subset of the image to show
1612
1617
  with rasterio.open(raster_path) as src:
1613
- # Get centroid of first building
1618
+ # Get centroid of first object
1614
1619
  sample_geom = gdf.iloc[0].geometry
1615
1620
  centroid = sample_geom.centroid
1616
1621
 
1617
1622
  # Convert to pixel coordinates
1618
1623
  center_x, center_y = ~src.transform * (centroid.x, centroid.y)
1619
1624
 
1620
- # Define a window around this building
1625
+ # Define a window around this object
1621
1626
  window_size = 500 # pixels
1622
1627
  window = rasterio.windows.Window(
1623
1628
  max(0, int(center_x - window_size / 2)),
@@ -1654,7 +1659,7 @@ class BuildingFootprintExtractor:
1654
1659
  window_bounds = rasterio.windows.bounds(window, src.transform)
1655
1660
  window_box = box(*window_bounds)
1656
1661
 
1657
- # Filter buildings that intersect with this window
1662
+ # Filter objects that intersect with this window
1658
1663
  visible_gdf = gdf[gdf.intersects(window_box)]
1659
1664
 
1660
1665
  # Set up colors for sample view if confidence data exists
@@ -1676,7 +1681,7 @@ class BuildingFootprintExtractor:
1676
1681
  except Exception as e:
1677
1682
  print(f"Error setting up sample confidence visualization: {e}")
1678
1683
 
1679
- # Plot building footprints in sample view
1684
+ # Plot objects in sample view
1680
1685
  for idx, row in visible_gdf.iterrows():
1681
1686
  try:
1682
1687
  # Get window-relative pixel coordinates
@@ -1751,9 +1756,7 @@ class BuildingFootprintExtractor:
1751
1756
  print(f"Error plotting polygon in sample view: {e}")
1752
1757
 
1753
1758
  # Set title
1754
- ax.set_title(
1755
- f"Sample Area - Building Footprints (Showing: {len(visible_gdf)})"
1756
- )
1759
+ ax.set_title(f"Sample Area - objects (Showing: {len(visible_gdf)})")
1757
1760
 
1758
1761
  # Remove axes
1759
1762
  ax.set_xticks([])
@@ -1769,3 +1772,111 @@ class BuildingFootprintExtractor:
1769
1772
  plt.tight_layout()
1770
1773
  plt.savefig(sample_output, dpi=300, bbox_inches="tight")
1771
1774
  print(f"Sample visualization saved to {sample_output}")
1775
+
1776
+
1777
class BuildingFootprintExtractor(ObjectDetector):
    """
    Building footprint extraction using a pre-trained Mask R-CNN model.

    This class extends the `ObjectDetector` class with a building-specific
    default model file and a `regularize_buildings` helper.
    """

    def __init__(
        self,
        model_path="building_footprints_usa.pth",
        repo_id=None,
        model=None,
        device=None,
    ):
        """
        Initialize the building footprint extractor.

        All arguments are forwarded unchanged to `ObjectDetector.__init__`;
        this subclass only changes the default `model_path`.

        Args:
            model_path: Path to the .pth model file.
                Defaults to "building_footprints_usa.pth".
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )

    def regularize_buildings(
        self,
        gdf,
        min_area=10,
        angle_threshold=15,
        orthogonality_threshold=0.3,
        rectangularity_threshold=0.7,
    ):
        """
        Regularize building footprints to enforce right angles and rectangular shapes.

        This is a building-named convenience wrapper: every argument is passed
        unchanged to `self.regularize_objects` and its result is returned.

        Args:
            gdf: GeoDataFrame with building footprints.
            min_area: Minimum area in square units to keep a building.
            angle_threshold: Maximum deviation from 90 degrees to consider an
                angle as orthogonal (degrees).
            orthogonality_threshold: Fraction (0-1, default 0.3) of angles that
                must be orthogonal for a building to be regularized.
            rectangularity_threshold: Minimum ratio of a building's area to the
                area of its oriented bounding box for rectangular simplification.

        Returns:
            GeoDataFrame with regularized building footprints.
        """
        return self.regularize_objects(
            gdf,
            min_area=min_area,
            angle_threshold=angle_threshold,
            orthogonality_threshold=orthogonality_threshold,
            rectangularity_threshold=rectangularity_threshold,
        )
1833
+
1834
+
1835
class CarDetector(ObjectDetector):
    """
    Car detection using a pre-trained Mask R-CNN model.

    This class extends the `ObjectDetector` class and only changes the
    default model file to a car detection model; all detection behavior is
    inherited from `ObjectDetector`.
    """

    def __init__(
        self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
    ):
        """
        Initialize the car detector.

        All arguments are forwarded unchanged to `ObjectDetector.__init__`.

        Args:
            model_path: Path to the .pth model file.
                Defaults to "car_detection_usa.pth".
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )
1858
+
1859
+
1860
class ShipDetector(ObjectDetector):
    """
    Ship detection using a pre-trained Mask R-CNN model.

    This class extends the `ObjectDetector` class and only changes the
    default model file to a ship detection model; all detection behavior is
    inherited from `ObjectDetector`.
    """

    def __init__(
        self, model_path="ship_detection.pth", repo_id=None, model=None, device=None
    ):
        """
        Initialize the ship detector.

        All arguments are forwarded unchanged to `ObjectDetector.__init__`.

        Args:
            model_path: Path to the .pth model file.
                Defaults to "ship_detection.pth".
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        super().__init__(
            model_path=model_path, repo_id=repo_id, model=model, device=device
        )
geoai/utils.py CHANGED
@@ -1080,3 +1080,97 @@ def install_package(package):
1080
1080
 
1081
1081
  # Wait for process to complete
1082
1082
  process.wait()
1083
+
1084
+
1085
def create_split_map(
    left_layer: Optional[str] = "TERRAIN",
    right_layer: Optional[str] = "OpenTopoMap",
    left_args: Optional[dict] = None,
    right_args: Optional[dict] = None,
    left_array_args: Optional[dict] = None,
    right_array_args: Optional[dict] = None,
    zoom_control: Optional[bool] = True,
    fullscreen_control: Optional[bool] = True,
    layer_control: Optional[bool] = True,
    add_close_button: Optional[bool] = False,
    left_label: Optional[str] = None,
    right_label: Optional[str] = None,
    left_position: Optional[str] = "bottomleft",
    right_position: Optional[str] = "bottomright",
    widget_layout: Optional[dict] = None,
    draggable: Optional[bool] = True,
    center: Optional[List[float]] = None,
    zoom: Optional[int] = 2,
    height: Optional[str] = "600px",
    basemap: Optional[str] = None,
    basemap_args: Optional[dict] = None,
    m=None,
    **kwargs,
) -> "leafmap.Map":
    """Create (or reuse) a leafmap map and add a split-panel comparison view.

    Args:
        left_layer (str, optional): The left tile layer. Can be a local file path,
            HTTP URL, or a basemap name. Defaults to 'TERRAIN'.
        right_layer (str, optional): The right tile layer. Can be a local file path,
            HTTP URL, or a basemap name. Defaults to 'OpenTopoMap'.
        left_args (dict, optional): The arguments for the left tile layer. Defaults to {}.
        right_args (dict, optional): The arguments for the right tile layer. Defaults to {}.
        left_array_args (dict, optional): The arguments for array_to_image for the
            left layer. Defaults to {}.
        right_array_args (dict, optional): The arguments for array_to_image for the
            right layer. Defaults to {}.
        zoom_control (bool, optional): Whether to add zoom control. Defaults to True.
        fullscreen_control (bool, optional): Whether to add fullscreen control. Defaults to True.
        layer_control (bool, optional): Whether to add layer control. Defaults to True.
        add_close_button (bool, optional): Whether to add a close button. Defaults to False.
        left_label (str, optional): The label for the left layer. Defaults to None.
        right_label (str, optional): The label for the right layer. Defaults to None.
        left_position (str, optional): The position for the left label. Defaults to "bottomleft".
        right_position (str, optional): The position for the right label. Defaults to "bottomright".
        widget_layout (dict, optional): The layout for the widget. Defaults to None.
        draggable (bool, optional): Whether the split map is draggable. Defaults to True.
        center (List[float], optional): Initial center passed to `leafmap.Map` when a
            new map is created. Defaults to [20, 0].
        zoom (int, optional): Initial zoom passed to `leafmap.Map` when a new map is
            created. Defaults to 2.
        height (str, optional): Map height passed to `leafmap.Map` when a new map is
            created. Defaults to "600px".
        basemap (str, optional): Optional background layer. A name ending in
            ".tif"/".tiff" (case-insensitive) is added as a COG layer when it starts
            with "http", otherwise as a local raster; any other string is passed to
            `add_basemap`. Defaults to None (no extra basemap).
        basemap_args (dict, optional): Extra keyword arguments for the COG/raster
            basemap layer. Defaults to {}.
        m (leafmap.Map, optional): An existing map to add the split view to. When
            None, a new `leafmap.Map` is created and its layers are cleared.
        **kwargs: Extra keyword arguments forwarded to `leafmap.Map` when a new map
            is created; ignored when `m` is supplied.

    Returns:
        leafmap.Map: The map (newly created or the one passed as `m`) with the
        split view added.
    """
    # Fresh containers per call (avoid shared mutable defaults).
    if left_args is None:
        left_args = {}
    if right_args is None:
        right_args = {}
    if left_array_args is None:
        left_array_args = {}
    if right_array_args is None:
        right_array_args = {}
    if basemap_args is None:
        basemap_args = {}
    if center is None:
        center = [20, 0]

    if m is None:
        m = leafmap.Map(center=center, zoom=zoom, height=height, **kwargs)
        # Only a freshly created map is cleared; a caller-supplied map is left as-is.
        m.clear_layers()

    if isinstance(basemap, str):
        if basemap.lower().endswith((".tif", ".tiff")):
            if basemap.startswith("http"):
                m.add_cog_layer(basemap, name="Basemap", **basemap_args)
            else:
                m.add_raster(basemap, name="Basemap", **basemap_args)
        else:
            m.add_basemap(basemap)

    m.split_map(
        left_layer=left_layer,
        right_layer=right_layer,
        left_args=left_args,
        right_args=right_args,
        left_array_args=left_array_args,
        right_array_args=right_array_args,
        zoom_control=zoom_control,
        fullscreen_control=fullscreen_control,
        layer_control=layer_control,
        add_close_button=add_close_button,
        left_label=left_label,
        right_label=right_label,
        left_position=left_position,
        right_position=right_position,
        widget_layout=widget_layout,
        draggable=draggable,
    )

    return m
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: geoai-py
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -31,7 +31,6 @@ Requires-Dist: pystac-client
31
31
  Requires-Dist: rasterio
32
32
  Requires-Dist: rioxarray
33
33
  Requires-Dist: scikit-learn
34
- Requires-Dist: segment-geospatial
35
34
  Requires-Dist: torch
36
35
  Requires-Dist: torchgeo
37
36
  Requires-Dist: tqdm
@@ -0,0 +1,13 @@
1
+ geoai/__init__.py,sha256=XvH0-HD7pv6gUBC8VJzToLqq7uCcWp2Orhp9sTR7Dd4,923
2
+ geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
3
+ geoai/extract.py,sha256=DQEvZXXmSxwoAUhJsXActM1IbFOg1ljg5fArqM5JdSk,75240
4
+ geoai/geoai.py,sha256=wNwKIqwOT10tU4uiWTcNp5Gd598rRFMANIfJsGdOWKM,90
5
+ geoai/preprocess.py,sha256=zQynNxQ_nxDkCEQU-h4G1SrgqxV1c5EREMV3JeS0cC0,118701
6
+ geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
7
+ geoai/utils.py,sha256=DCHfL2G5jH6RWDJUyJYORMa5uTaQp56Kj8sQXlG-0ck,42678
8
+ geoai_py-0.3.2.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
9
+ geoai_py-0.3.2.dist-info/METADATA,sha256=-j8ngaO_hQnWueYEoZsLbbIDN2cSj-uWpbBmD7m-2mU,5720
10
+ geoai_py-0.3.2.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
11
+ geoai_py-0.3.2.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
12
+ geoai_py-0.3.2.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
13
+ geoai_py-0.3.2.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- geoai/__init__.py,sha256=D1BgGoNkd6ZiimfM11EIeaTBVauMLz_XUnzp73IAJ80,923
2
- geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
3
- geoai/extract.py,sha256=2MdfLwlxbZ4YZIgEQhPcGpMTbNDTJ5-TbdJZnPfZ4Vw,71886
4
- geoai/geoai.py,sha256=wNwKIqwOT10tU4uiWTcNp5Gd598rRFMANIfJsGdOWKM,90
5
- geoai/preprocess.py,sha256=zQynNxQ_nxDkCEQU-h4G1SrgqxV1c5EREMV3JeS0cC0,118701
6
- geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
7
- geoai/utils.py,sha256=uEZJLLnk2qOhsJLJgfJY6Fj_P0fP3FOZLJys0RwEkTs,38766
8
- geoai_py-0.3.1.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
9
- geoai_py-0.3.1.dist-info/METADATA,sha256=uBxf6OpFzAd1DlTqhrb6ge6N0-YYv4K53c86I8rTjk8,5754
10
- geoai_py-0.3.1.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
11
- geoai_py-0.3.1.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
12
- geoai_py-0.3.1.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
13
- geoai_py-0.3.1.dist-info/RECORD,,