geoai-py 0.3.1__py2.py3-none-any.whl → 0.3.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/extract.py CHANGED
@@ -13,7 +13,7 @@ import rasterio
13
13
  from rasterio.windows import Window
14
14
  from rasterio.features import shapes
15
15
  from huggingface_hub import hf_hub_download
16
- from .preprocess import get_raster_stats
16
+ from .utils import get_raster_stats
17
17
 
18
18
  try:
19
19
  from torchgeo.datasets import NonGeoDataset
@@ -23,9 +23,9 @@ except ImportError as e:
23
23
  )
24
24
 
25
25
 
26
- class BuildingFootprintDataset(NonGeoDataset):
26
+ class CustomDataset(NonGeoDataset):
27
27
  """
28
- A TorchGeo dataset for building footprint extraction.
28
+ A TorchGeo dataset for object extraction.
29
29
  Using NonGeoDataset to avoid spatial indexing issues.
30
30
  """
31
31
 
@@ -170,18 +170,20 @@ class BuildingFootprintDataset(NonGeoDataset):
170
170
  return self.rows * self.cols
171
171
 
172
172
 
173
- class BuildingFootprintExtractor:
173
+ class ObjectDetector:
174
174
  """
175
- Building footprint extraction using Mask R-CNN with TorchGeo.
175
+ Object extraction using Mask R-CNN with TorchGeo.
176
176
  """
177
177
 
178
- def __init__(self, model_path=None, device=None):
178
+ def __init__(self, model_path=None, repo_id=None, model=None, device=None):
179
179
  """
180
- Initialize the building footprint extractor.
180
+ Initialize the object extractor.
181
181
 
182
182
  Args:
183
- model_path: Path to the .pth model file
184
- device: Device to use for inference ('cuda:0', 'cpu', etc.)
183
+ model_path: Path to the .pth model file.
184
+ repo_id: Hugging Face repository ID for model download.
185
+ model: Pre-initialized model object (optional).
186
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
185
187
  """
186
188
  # Set device
187
189
  if device is None:
@@ -189,31 +191,36 @@ class BuildingFootprintExtractor:
189
191
  else:
190
192
  self.device = torch.device(device)
191
193
 
192
- # Default parameters for building detection - these can be overridden in process_raster
194
+ # Default parameters for object detection - these can be overridden in process_raster
193
195
  self.chip_size = (512, 512) # Size of image chips for processing
194
196
  self.overlap = 0.25 # Default overlap between tiles
195
197
  self.confidence_threshold = 0.5 # Default confidence threshold
196
198
  self.nms_iou_threshold = 0.5 # IoU threshold for non-maximum suppression
197
- self.small_building_area = 100 # Minimum area in pixels to keep a building
199
+ self.min_object_area = 100 # Minimum area in pixels to keep an object
200
+ self.max_object_area = None # Maximum area in pixels to keep an object
198
201
  self.mask_threshold = 0.5 # Threshold for mask binarization
199
202
  self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
200
203
 
201
204
  # Initialize model
202
- self.model = self._initialize_model()
205
+ self.model = self.initialize_model(model)
203
206
 
204
207
  # Download model if needed
205
- if model_path is None:
206
- model_path = self._download_model_from_hf()
208
+ if model_path is None or (not os.path.exists(model_path)):
209
+ model_path = self.download_model_from_hf(model_path, repo_id)
207
210
 
208
211
  # Load model weights
209
- self._load_weights(model_path)
212
+ self.load_weights(model_path)
210
213
 
211
214
  # Set model to evaluation mode
212
215
  self.model.eval()
213
216
 
214
- def _download_model_from_hf(self):
217
+ def download_model_from_hf(self, model_path=None, repo_id=None):
215
218
  """
216
- Download the USA building footprints model from Hugging Face.
219
+ Download the object detection model from Hugging Face.
220
+
221
+ Args:
222
+ model_path: Path to the model file.
223
+ repo_id: Hugging Face repository ID.
217
224
 
218
225
  Returns:
219
226
  Path to the downloaded model file
@@ -223,17 +230,14 @@ class BuildingFootprintExtractor:
223
230
  print("Model path not specified, downloading from Hugging Face...")
224
231
 
225
232
  # Define the repository ID and model filename
226
- repo_id = "giswqs/geoai" # Update with your actual username/repo
227
- filename = "building_footprints_usa.pth"
233
+ if repo_id is None:
234
+ repo_id = "giswqs/geoai"
228
235
 
229
- # Ensure cache directory exists
230
- # cache_dir = os.path.join(
231
- # os.path.expanduser("~"), ".cache", "building_footprints"
232
- # )
233
- # os.makedirs(cache_dir, exist_ok=True)
236
+ if model_path is None:
237
+ model_path = "building_footprints_usa.pth"
234
238
 
235
239
  # Download the model
236
- model_path = hf_hub_download(repo_id=repo_id, filename=filename)
240
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
237
241
  print(f"Model downloaded to: {model_path}")
238
242
 
239
243
  return model_path
@@ -243,28 +247,36 @@ class BuildingFootprintExtractor:
243
247
  print("Please specify a local model path or ensure internet connectivity.")
244
248
  raise
245
249
 
246
- def _initialize_model(self):
247
- """Initialize Mask R-CNN model with ResNet50 backbone."""
248
- # Standard image mean and std for pre-trained models
249
- # Note: This would normally come from your config file
250
- image_mean = [0.485, 0.456, 0.406]
251
- image_std = [0.229, 0.224, 0.225]
252
-
253
- # Create model with explicit normalization parameters
254
- model = maskrcnn_resnet50_fpn(
255
- weights=None,
256
- progress=False,
257
- num_classes=2, # Background + building
258
- weights_backbone=None,
259
- # These parameters ensure consistent normalization
260
- image_mean=image_mean,
261
- image_std=image_std,
262
- )
250
+ def initialize_model(self, model):
251
+ """Initialize a deep learning model for object detection.
252
+
253
+ Args:
254
+ model (torch.nn.Module): A pre-initialized model object.
255
+
256
+ Returns:
257
+ torch.nn.Module: A deep learning model for object detection.
258
+ """
259
+
260
+ if model is None: # Initialize Mask R-CNN model with ResNet50 backbone.
261
+ # Standard image mean and std for pre-trained models
262
+ image_mean = [0.485, 0.456, 0.406]
263
+ image_std = [0.229, 0.224, 0.225]
264
+
265
+ # Create model with explicit normalization parameters
266
+ model = maskrcnn_resnet50_fpn(
267
+ weights=None,
268
+ progress=False,
269
+ num_classes=2, # Background + object
270
+ weights_backbone=None,
271
+ # These parameters ensure consistent normalization
272
+ image_mean=image_mean,
273
+ image_std=image_std,
274
+ )
263
275
 
264
276
  model.to(self.device)
265
277
  return model
266
278
 
267
- def _load_weights(self, model_path):
279
+ def load_weights(self, model_path):
268
280
  """
269
281
  Load weights from file with error handling for different formats.
270
282
 
@@ -306,7 +318,7 @@ class BuildingFootprintExtractor:
306
318
  except Exception as e:
307
319
  raise RuntimeError(f"Failed to load model: {e}")
308
320
 
309
- def _mask_to_polygons(self, mask, **kwargs):
321
+ def mask_to_polygons(self, mask, **kwargs):
310
322
  """
311
323
  Convert binary mask to polygon contours using OpenCV.
312
324
 
@@ -315,7 +327,8 @@ class BuildingFootprintExtractor:
315
327
  **kwargs: Optional parameters:
316
328
  simplify_tolerance: Tolerance for polygon simplification
317
329
  mask_threshold: Threshold for mask binarization
318
- small_building_area: Minimum area in pixels to keep a building
330
+ min_object_area: Minimum area in pixels to keep an object
331
+ max_object_area: Maximum area in pixels to keep an object
319
332
 
320
333
  Returns:
321
334
  List of polygons as lists of (x, y) coordinates
@@ -324,9 +337,8 @@ class BuildingFootprintExtractor:
324
337
  # Get parameters from kwargs or use instance defaults
325
338
  simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
326
339
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
327
- small_building_area = kwargs.get(
328
- "small_building_area", self.small_building_area
329
- )
340
+ min_object_area = kwargs.get("min_object_area", self.min_object_area)
341
+ max_object_area = kwargs.get("max_object_area", self.max_object_area)
330
342
 
331
343
  # Ensure binary mask
332
344
  mask = (mask > mask_threshold).astype(np.uint8)
@@ -342,7 +354,14 @@ class BuildingFootprintExtractor:
342
354
  polygons = []
343
355
  for contour in contours:
344
356
  # Filter out too small contours
345
- if contour.shape[0] < 3 or cv2.contourArea(contour) < small_building_area:
357
+ if contour.shape[0] < 3 or cv2.contourArea(contour) < min_object_area:
358
+ continue
359
+
360
+ # Filter out too large contours
361
+ if (
362
+ max_object_area is not None
363
+ and cv2.contourArea(contour) > max_object_area
364
+ ):
346
365
  continue
347
366
 
348
367
  # Simplify contour if it has many points
@@ -356,7 +375,7 @@ class BuildingFootprintExtractor:
356
375
 
357
376
  return polygons
358
377
 
359
- def _filter_overlapping_polygons(self, gdf, **kwargs):
378
+ def filter_overlapping_polygons(self, gdf, **kwargs):
360
379
  """
361
380
  Filter overlapping polygons using non-maximum suppression.
362
381
 
@@ -413,26 +432,26 @@ class BuildingFootprintExtractor:
413
432
 
414
433
  return gdf.iloc[keep_indices]
415
434
 
416
- def filter_edge_buildings(self, gdf, raster_path, edge_buffer=10):
435
+ def filter_edge_objects(self, gdf, raster_path, edge_buffer=10):
417
436
  """
418
- Filter out building detections that fall in padding/edge areas of the image.
437
+ Filter out object detections that fall in padding/edge areas of the image.
419
438
 
420
439
  Args:
421
- gdf: GeoDataFrame with building footprint detections
440
+ gdf: GeoDataFrame with object detections
422
441
  raster_path: Path to the original raster file
423
442
  edge_buffer: Buffer in pixels to consider as edge region
424
443
 
425
444
  Returns:
426
- GeoDataFrame with filtered building footprints
445
+ GeoDataFrame with filtered objects
427
446
  """
428
447
  import rasterio
429
448
  from shapely.geometry import box
430
449
 
431
- # If no buildings detected, return empty GeoDataFrame
450
+ # If no objects detected, return empty GeoDataFrame
432
451
  if gdf is None or len(gdf) == 0:
433
452
  return gdf
434
453
 
435
- print(f"Buildings before filtering: {len(gdf)}")
454
+ print(f"Objects before filtering: {len(gdf)}")
436
455
 
437
456
  with rasterio.open(raster_path) as src:
438
457
  # Get raster bounds
@@ -461,18 +480,18 @@ class BuildingFootprintExtractor:
461
480
  else:
462
481
  inner_box = box(*inner_bounds)
463
482
 
464
- # Filter out buildings that intersect with the edge of the image
483
+ # Filter out objects that intersect with the edge of the image
465
484
  filtered_gdf = gdf[gdf.intersects(inner_box)]
466
485
 
467
- # Additional check for buildings that have >50% of their area outside the valid region
468
- valid_buildings = []
486
+ # Additional check for objects that have >50% of their area outside the valid region
487
+ valid_objects = []
469
488
  for idx, row in filtered_gdf.iterrows():
470
489
  if row.geometry.intersection(inner_box).area >= 0.5 * row.geometry.area:
471
- valid_buildings.append(idx)
490
+ valid_objects.append(idx)
472
491
 
473
- filtered_gdf = filtered_gdf.loc[valid_buildings]
492
+ filtered_gdf = filtered_gdf.loc[valid_objects]
474
493
 
475
- print(f"Buildings after filtering: {len(filtered_gdf)}")
494
+ print(f"Objects after filtering: {len(filtered_gdf)}")
476
495
 
477
496
  return filtered_gdf
478
497
 
@@ -482,28 +501,30 @@ class BuildingFootprintExtractor:
482
501
  output_path=None,
483
502
  simplify_tolerance=None,
484
503
  mask_threshold=None,
485
- small_building_area=None,
504
+ min_object_area=None,
505
+ max_object_area=None,
486
506
  nms_iou_threshold=None,
487
507
  regularize=True,
488
508
  angle_threshold=15,
489
509
  rectangularity_threshold=0.7,
490
510
  ):
491
511
  """
492
- Convert a building mask GeoTIFF to vector polygons and save as GeoJSON.
512
+ Convert an object mask GeoTIFF to vector polygons and save as GeoJSON.
493
513
 
494
514
  Args:
495
- mask_path: Path to the building masks GeoTIFF
515
+ mask_path: Path to the object masks GeoTIFF
496
516
  output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
497
517
  simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
498
518
  mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
499
- small_building_area: Minimum area in pixels to keep a building (default: self.small_building_area)
519
+ min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
520
+ max_object_area: Maximum area in pixels to keep an object (default: self.max_object_area)
500
521
  nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
501
- regularize: Whether to regularize buildings to right angles (default: True)
522
+ regularize: Whether to regularize objects to right angles (default: True)
502
523
  angle_threshold: Maximum deviation from 90 degrees for regularization (default: 15)
503
524
  rectangularity_threshold: Threshold for rectangle simplification (default: 0.7)
504
525
 
505
526
  Returns:
506
- GeoDataFrame with building footprints
527
+ GeoDataFrame with objects
507
528
  """
508
529
  # Use class defaults if parameters not provided
509
530
  simplify_tolerance = (
@@ -514,10 +535,11 @@ class BuildingFootprintExtractor:
514
535
  mask_threshold = (
515
536
  mask_threshold if mask_threshold is not None else self.mask_threshold
516
537
  )
517
- small_building_area = (
518
- small_building_area
519
- if small_building_area is not None
520
- else self.small_building_area
538
+ min_object_area = (
539
+ min_object_area if min_object_area is not None else self.min_object_area
540
+ )
541
+ max_object_area = (
542
+ max_object_area if max_object_area is not None else self.max_object_area
521
543
  )
522
544
  nms_iou_threshold = (
523
545
  nms_iou_threshold
@@ -531,10 +553,11 @@ class BuildingFootprintExtractor:
531
553
 
532
554
  print(f"Converting mask to GeoJSON with parameters:")
533
555
  print(f"- Mask threshold: {mask_threshold}")
534
- print(f"- Min building area: {small_building_area}")
556
+ print(f"- Min object area: {min_object_area}")
557
+ print(f"- Max object area: {max_object_area}")
535
558
  print(f"- Simplify tolerance: {simplify_tolerance}")
536
559
  print(f"- NMS IoU threshold: {nms_iou_threshold}")
537
- print(f"- Regularize buildings: {regularize}")
560
+ print(f"- Regularize objects: {regularize}")
538
561
  if regularize:
539
562
  print(f"- Angle threshold: {angle_threshold}° from 90°")
540
563
  print(f"- Rectangularity threshold: {rectangularity_threshold*100}%")
@@ -564,7 +587,7 @@ class BuildingFootprintExtractor:
564
587
  )
565
588
 
566
589
  print(
567
- f"Found {num_labels-1} potential buildings"
590
+ f"Found {num_labels-1} potential objects"
568
591
  ) # Subtract 1 for background
569
592
 
570
593
  # Create list to store polygons and confidence values
@@ -573,19 +596,23 @@ class BuildingFootprintExtractor:
573
596
 
574
597
  # Process each component (skip the first one which is background)
575
598
  for i in tqdm(range(1, num_labels)):
576
- # Extract this building
599
+ # Extract this object
577
600
  area = stats[i, cv2.CC_STAT_AREA]
578
601
 
579
602
  # Skip if too small
580
- if area < small_building_area:
603
+ if area < min_object_area:
604
+ continue
605
+
606
+ # Skip if too large
607
+ if max_object_area is not None and area > max_object_area:
581
608
  continue
582
609
 
583
- # Create a mask for this building
584
- building_mask = (labels == i).astype(np.uint8)
610
+ # Create a mask for this object
611
+ object_mask = (labels == i).astype(np.uint8)
585
612
 
586
613
  # Find contours
587
614
  contours, _ = cv2.findContours(
588
- building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
615
+ object_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
589
616
  )
590
617
 
591
618
  # Process each contour
@@ -633,17 +660,17 @@ class BuildingFootprintExtractor:
633
660
  {
634
661
  "geometry": all_polygons,
635
662
  "confidence": all_confidences,
636
- "class": 1, # Building class
663
+ "class": 1, # Object class
637
664
  },
638
665
  crs=crs,
639
666
  )
640
667
 
641
668
  # Apply non-maximum suppression to remove overlapping polygons
642
- gdf = self._filter_overlapping_polygons(
669
+ gdf = self.filter_overlapping_polygons(
643
670
  gdf, nms_iou_threshold=nms_iou_threshold
644
671
  )
645
672
 
646
- print(f"Building count after NMS filtering: {len(gdf)}")
673
+ print(f"Object count after NMS filtering: {len(gdf)}")
647
674
 
648
675
  # Apply regularization if requested
649
676
  if regularize and len(gdf) > 0:
@@ -661,8 +688,8 @@ class BuildingFootprintExtractor:
661
688
  # Use 10 pixels as minimum area in geographic units
662
689
  min_geo_area = 10 * avg_pixel_area
663
690
 
664
- # Regularize buildings
665
- gdf = self.regularize_buildings(
691
+ # Regularize objects
692
+ gdf = self.regularize_objects(
666
693
  gdf,
667
694
  min_area=min_geo_area,
668
695
  angle_threshold=angle_threshold,
@@ -672,7 +699,7 @@ class BuildingFootprintExtractor:
672
699
  # Save to file
673
700
  if output_path:
674
701
  gdf.to_file(output_path)
675
- print(f"Saved {len(gdf)} building footprints to {output_path}")
702
+ print(f"Saved {len(gdf)} objects to {output_path}")
676
703
 
677
704
  return gdf
678
705
 
@@ -687,25 +714,25 @@ class BuildingFootprintExtractor:
687
714
  **kwargs,
688
715
  ):
689
716
  """
690
- Process a raster file to extract building footprints with customizable parameters.
717
+ Process a raster file to extract objects with customizable parameters.
691
718
 
692
719
  Args:
693
720
  raster_path: Path to input raster file
694
721
  output_path: Path to output GeoJSON file (optional)
695
722
  batch_size: Batch size for processing
696
- filter_edges: Whether to filter out buildings at the edges of the image
697
- edge_buffer: Size of edge buffer in pixels to filter out buildings (if filter_edges=True)
723
+ filter_edges: Whether to filter out objects at the edges of the image
724
+ edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
698
725
  **kwargs: Additional parameters:
699
726
  confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
700
727
  overlap: Overlap between adjacent tiles (0.0-1.0)
701
728
  chip_size: Size of image chips for processing (height, width)
702
729
  nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
703
730
  mask_threshold: Threshold for mask binarization (0.0-1.0)
704
- small_building_area: Minimum area in pixels to keep a building
731
+ min_object_area: Minimum area in pixels to keep an object
705
732
  simplify_tolerance: Tolerance for polygon simplification
706
733
 
707
734
  Returns:
708
- GeoDataFrame with building footprints
735
+ GeoDataFrame with objects
709
736
  """
710
737
  # Get parameters from kwargs or use instance defaults
711
738
  confidence_threshold = kwargs.get(
@@ -715,9 +742,8 @@ class BuildingFootprintExtractor:
715
742
  chip_size = kwargs.get("chip_size", self.chip_size)
716
743
  nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
717
744
  mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
718
- small_building_area = kwargs.get(
719
- "small_building_area", self.small_building_area
720
- )
745
+ min_object_area = kwargs.get("min_object_area", self.min_object_area)
746
+ max_object_area = kwargs.get("max_object_area", self.max_object_area)
721
747
  simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
722
748
 
723
749
  # Print parameters being used
@@ -727,14 +753,15 @@ class BuildingFootprintExtractor:
727
753
  print(f"- Chip size: {chip_size}")
728
754
  print(f"- NMS IoU threshold: {nms_iou_threshold}")
729
755
  print(f"- Mask threshold: {mask_threshold}")
730
- print(f"- Min building area: {small_building_area}")
756
+ print(f"- Min object area: {min_object_area}")
757
+ print(f"- Max object area: {max_object_area}")
731
758
  print(f"- Simplify tolerance: {simplify_tolerance}")
732
- print(f"- Filter edge buildings: {filter_edges}")
759
+ print(f"- Filter edge objects: {filter_edges}")
733
760
  if filter_edges:
734
761
  print(f"- Edge buffer size: {edge_buffer} pixels")
735
762
 
736
763
  # Create dataset
737
- dataset = BuildingFootprintDataset(raster_path=raster_path, chip_size=chip_size)
764
+ dataset = CustomDataset(raster_path=raster_path, chip_size=chip_size)
738
765
  self.raster_stats = dataset.raster_stats
739
766
 
740
767
  # Custom collate function to handle Shapely objects
@@ -854,11 +881,12 @@ class BuildingFootprintExtractor:
854
881
  binary_mask = mask[0] # Get binary mask
855
882
 
856
883
  # Convert mask to polygon with custom parameters
857
- contours = self._mask_to_polygons(
884
+ contours = self.mask_to_polygons(
858
885
  binary_mask,
859
886
  simplify_tolerance=simplify_tolerance,
860
887
  mask_threshold=mask_threshold,
861
- small_building_area=small_building_area,
888
+ min_object_area=min_object_area,
889
+ max_object_area=max_object_area,
862
890
  )
863
891
 
864
892
  # Skip if no valid polygons
@@ -896,24 +924,22 @@ class BuildingFootprintExtractor:
896
924
  {
897
925
  "geometry": all_polygons,
898
926
  "confidence": all_scores,
899
- "class": 1, # Building class
927
+ "class": 1, # Object class
900
928
  },
901
929
  crs=dataset.crs,
902
930
  )
903
931
 
904
932
  # Remove overlapping polygons with custom threshold
905
- gdf = self._filter_overlapping_polygons(
906
- gdf, nms_iou_threshold=nms_iou_threshold
907
- )
933
+ gdf = self.filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)
908
934
 
909
- # Filter edge buildings if requested
935
+ # Filter edge objects if requested
910
936
  if filter_edges:
911
- gdf = self.filter_edge_buildings(gdf, raster_path, edge_buffer=edge_buffer)
937
+ gdf = self.filter_edge_objects(gdf, raster_path, edge_buffer=edge_buffer)
912
938
 
913
939
  # Save to file if requested
914
940
  if output_path:
915
941
  gdf.to_file(output_path, driver="GeoJSON")
916
- print(f"Saved {len(gdf)} building footprints to {output_path}")
942
+ print(f"Saved {len(gdf)} objects to {output_path}")
917
943
 
918
944
  return gdf
919
945
 
@@ -921,7 +947,7 @@ class BuildingFootprintExtractor:
921
947
  self, raster_path, output_path=None, batch_size=4, verbose=False, **kwargs
922
948
  ):
923
949
  """
924
- Process a raster file to extract building footprint masks and save as GeoTIFF.
950
+ Process a raster file to extract object masks and save as GeoTIFF.
925
951
 
926
952
  Args:
927
953
  raster_path: Path to input raster file
@@ -955,7 +981,7 @@ class BuildingFootprintExtractor:
955
981
  print(f"- Mask threshold: {mask_threshold}")
956
982
 
957
983
  # Create dataset
958
- dataset = BuildingFootprintDataset(
984
+ dataset = CustomDataset(
959
985
  raster_path=raster_path, chip_size=chip_size, verbose=verbose
960
986
  )
961
987
 
@@ -972,7 +998,7 @@ class BuildingFootprintExtractor:
972
998
  output_profile = src.profile.copy()
973
999
  output_profile.update(
974
1000
  dtype=rasterio.uint8,
975
- count=1, # Single band for building mask
1001
+ count=1, # Single band for object mask
976
1002
  compress="lzw",
977
1003
  nodata=0,
978
1004
  )
@@ -1144,10 +1170,10 @@ class BuildingFootprintExtractor:
1144
1170
  # Write the final mask to the output file
1145
1171
  dst.write(mask_array, 1)
1146
1172
 
1147
- print(f"Building masks saved to {output_path}")
1173
+ print(f"Object masks saved to {output_path}")
1148
1174
  return output_path
1149
1175
 
1150
- def regularize_buildings(
1176
+ def regularize_objects(
1151
1177
  self,
1152
1178
  gdf,
1153
1179
  min_area=10,
@@ -1156,17 +1182,17 @@ class BuildingFootprintExtractor:
1156
1182
  rectangularity_threshold=0.7,
1157
1183
  ):
1158
1184
  """
1159
- Regularize building footprints to enforce right angles and rectangular shapes.
1185
+ Regularize objects to enforce right angles and rectangular shapes.
1160
1186
 
1161
1187
  Args:
1162
- gdf: GeoDataFrame with building footprints
1163
- min_area: Minimum area in square units to keep a building
1188
+ gdf: GeoDataFrame with objects
1189
+ min_area: Minimum area in square units to keep an object
1164
1190
  angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
1165
- orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
1166
- rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
1191
+ orthogonality_threshold: Percentage of angles that must be orthogonal for an object to be regularized
1192
+ rectangularity_threshold: Minimum area ratio to the object's oriented bounding box for rectangular simplification
1167
1193
 
1168
1194
  Returns:
1169
- GeoDataFrame with regularized building footprints
1195
+ GeoDataFrame with regularized objects
1170
1196
  """
1171
1197
  import numpy as np
1172
1198
  from shapely.geometry import Polygon, MultiPolygon, box
@@ -1281,10 +1307,10 @@ class BuildingFootprintExtractor:
1281
1307
  return rect
1282
1308
 
1283
1309
  if gdf is None or len(gdf) == 0:
1284
- print("No buildings to regularize")
1310
+ print("No Objects to regularize")
1285
1311
  return gdf
1286
1312
 
1287
- print(f"Regularizing {len(gdf)} building footprints...")
1313
+ print(f"Regularizing {len(gdf)} objects...")
1288
1314
  print(f"- Angle threshold: {angle_threshold}° from 90°")
1289
1315
  print(f"- Min orthogonality: {orthogonality_threshold*100}% of angles")
1290
1316
  print(
@@ -1295,11 +1321,11 @@ class BuildingFootprintExtractor:
1295
1321
  result_gdf = gdf.copy()
1296
1322
 
1297
1323
  # Track statistics
1298
- total_buildings = len(gdf)
1324
+ total_objects = len(gdf)
1299
1325
  regularized_count = 0
1300
1326
  rectangularized_count = 0
1301
1327
 
1302
- # Process each building
1328
+ # Process each object
1303
1329
  for idx, row in tqdm(gdf.iterrows(), total=len(gdf)):
1304
1330
  geom = row.geometry
1305
1331
 
@@ -1314,7 +1340,7 @@ class BuildingFootprintExtractor:
1314
1340
  continue
1315
1341
  geom = list(geom.geoms)[np.argmax(areas)]
1316
1342
 
1317
- # Filter out tiny buildings
1343
+ # Filter out tiny objects
1318
1344
  if geom.area < min_area:
1319
1345
  continue
1320
1346
 
@@ -1331,33 +1357,33 @@ class BuildingFootprintExtractor:
1331
1357
 
1332
1358
  # Decide how to regularize
1333
1359
  if rectangularity >= rectangularity_threshold:
1334
- # Building is already quite rectangular, simplify to a rectangle
1360
+ # Object is already quite rectangular, simplify to a rectangle
1335
1361
  result_gdf.at[idx, "geometry"] = oriented_box
1336
1362
  result_gdf.at[idx, "regularized"] = "rectangle"
1337
1363
  rectangularized_count += 1
1338
1364
  elif orthogonality >= orthogonality_threshold:
1339
- # Building has many orthogonal angles but isn't rectangular
1365
+ # Object has many orthogonal angles but isn't rectangular
1340
1366
  # Could implement more sophisticated regularization here
1341
1367
  # For now, we'll still use the oriented rectangle
1342
1368
  result_gdf.at[idx, "geometry"] = oriented_box
1343
1369
  result_gdf.at[idx, "regularized"] = "orthogonal"
1344
1370
  regularized_count += 1
1345
1371
  else:
1346
- # Building doesn't have clear orthogonal structure
1372
+ # Object doesn't have clear orthogonal structure
1347
1373
  # Keep original but flag as unmodified
1348
1374
  result_gdf.at[idx, "regularized"] = "original"
1349
1375
 
1350
1376
  # Report statistics
1351
1377
  print(f"Regularization completed:")
1352
- print(f"- Total buildings: {total_buildings}")
1378
+ print(f"- Total objects: {total_objects}")
1353
1379
  print(
1354
- f"- Rectangular buildings: {rectangularized_count} ({rectangularized_count/total_buildings*100:.1f}%)"
1380
+ f"- Rectangular objects: {rectangularized_count} ({rectangularized_count/total_objects*100:.1f}%)"
1355
1381
  )
1356
1382
  print(
1357
- f"- Other regularized buildings: {regularized_count} ({regularized_count/total_buildings*100:.1f}%)"
1383
+ f"- Other regularized objects: {regularized_count} ({regularized_count/total_objects*100:.1f}%)"
1358
1384
  )
1359
1385
  print(
1360
- f"- Unmodified buildings: {total_buildings-rectangularized_count-regularized_count} ({(total_buildings-rectangularized_count-regularized_count)/total_buildings*100:.1f}%)"
1386
+ f"- Unmodified objects: {total_objects-rectangularized_count-regularized_count} ({(total_objects-rectangularized_count-regularized_count)/total_objects*100:.1f}%)"
1361
1387
  )
1362
1388
 
1363
1389
  return result_gdf
@@ -1366,14 +1392,14 @@ class BuildingFootprintExtractor:
1366
1392
  self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
1367
1393
  ):
1368
1394
  """
1369
- Visualize building detection results with proper coordinate transformation.
1395
+ Visualize object detection results with proper coordinate transformation.
1370
1396
 
1371
- This function displays building footprints on top of the raster image,
1397
+ This function displays objects on top of the raster image,
1372
1398
  ensuring proper alignment between the GeoDataFrame polygons and the image.
1373
1399
 
1374
1400
  Args:
1375
1401
  raster_path: Path to input raster
1376
- gdf: GeoDataFrame with building polygons (optional)
1402
+ gdf: GeoDataFrame with object polygons (optional)
1377
1403
  output_path: Path to save visualization (optional)
1378
1404
  figsize: Figure size (width, height) in inches
1379
1405
 
@@ -1390,7 +1416,7 @@ class BuildingFootprintExtractor:
1390
1416
  gdf = self.process_raster(raster_path)
1391
1417
 
1392
1418
  if gdf is None or len(gdf) == 0:
1393
- print("No buildings to visualize")
1419
+ print("No objects to visualize")
1394
1420
  return False
1395
1421
 
1396
1422
  # Check if confidence column exists in the GeoDataFrame
@@ -1531,7 +1557,7 @@ class BuildingFootprintExtractor:
1531
1557
  print(f"Unsupported geometry type: {geometry.geom_type}")
1532
1558
  return None
1533
1559
 
1534
- # Plot each building footprint
1560
+ # Plot each object
1535
1561
  for idx, row in gdf.iterrows():
1536
1562
  try:
1537
1563
  # Convert polygon to pixel coordinates
@@ -1593,7 +1619,7 @@ class BuildingFootprintExtractor:
1593
1619
  # Remove axes
1594
1620
  ax.set_xticks([])
1595
1621
  ax.set_yticks([])
1596
- ax.set_title(f"Building Footprints (Found: {len(gdf)})")
1622
+ ax.set_title(f"objects (Found: {len(gdf)})")
1597
1623
 
1598
1624
  # Save if requested
1599
1625
  if output_path:
@@ -1603,21 +1629,21 @@ class BuildingFootprintExtractor:
1603
1629
 
1604
1630
  plt.close()
1605
1631
 
1606
- # Create a simpler visualization focused just on a subset of buildings
1632
+ # Create a simpler visualization focused just on a subset of objects
1607
1633
  if len(gdf) > 0:
1608
1634
  plt.figure(figsize=figsize)
1609
1635
  ax = plt.gca()
1610
1636
 
1611
1637
  # Choose a subset of the image to show
1612
1638
  with rasterio.open(raster_path) as src:
1613
- # Get centroid of first building
1639
+ # Get centroid of first object
1614
1640
  sample_geom = gdf.iloc[0].geometry
1615
1641
  centroid = sample_geom.centroid
1616
1642
 
1617
1643
  # Convert to pixel coordinates
1618
1644
  center_x, center_y = ~src.transform * (centroid.x, centroid.y)
1619
1645
 
1620
- # Define a window around this building
1646
+ # Define a window around this object
1621
1647
  window_size = 500 # pixels
1622
1648
  window = rasterio.windows.Window(
1623
1649
  max(0, int(center_x - window_size / 2)),
@@ -1654,7 +1680,7 @@ class BuildingFootprintExtractor:
1654
1680
  window_bounds = rasterio.windows.bounds(window, src.transform)
1655
1681
  window_box = box(*window_bounds)
1656
1682
 
1657
- # Filter buildings that intersect with this window
1683
+ # Filter objects that intersect with this window
1658
1684
  visible_gdf = gdf[gdf.intersects(window_box)]
1659
1685
 
1660
1686
  # Set up colors for sample view if confidence data exists
@@ -1676,7 +1702,7 @@ class BuildingFootprintExtractor:
1676
1702
  except Exception as e:
1677
1703
  print(f"Error setting up sample confidence visualization: {e}")
1678
1704
 
1679
- # Plot building footprints in sample view
1705
+ # Plot objects in sample view
1680
1706
  for idx, row in visible_gdf.iterrows():
1681
1707
  try:
1682
1708
  # Get window-relative pixel coordinates
@@ -1751,9 +1777,7 @@ class BuildingFootprintExtractor:
1751
1777
  print(f"Error plotting polygon in sample view: {e}")
1752
1778
 
1753
1779
  # Set title
1754
- ax.set_title(
1755
- f"Sample Area - Building Footprints (Showing: {len(visible_gdf)})"
1756
- )
1780
+ ax.set_title(f"Sample Area - objects (Showing: {len(visible_gdf)})")
1757
1781
 
1758
1782
  # Remove axes
1759
1783
  ax.set_xticks([])
@@ -1769,3 +1793,111 @@ class BuildingFootprintExtractor:
1769
1793
  plt.tight_layout()
1770
1794
  plt.savefig(sample_output, dpi=300, bbox_inches="tight")
1771
1795
  print(f"Sample visualization saved to {sample_output}")
1796
+
1797
+
1798
+ class BuildingFootprintExtractor(ObjectDetector):
1799
+ """
1800
+ Building footprint extraction using a pre-trained Mask R-CNN model.
1801
+
1802
+ This class extends the
1803
+ `ObjectDetector` class with additional methods for building footprint extraction.
1804
+ """
1805
+
1806
+ def __init__(
1807
+ self,
1808
+ model_path="building_footprints_usa.pth",
1809
+ repo_id=None,
1810
+ model=None,
1811
+ device=None,
1812
+ ):
1813
+ """
1814
+ Initialize the building footprint extractor.
1815
+
1816
+ Args:
1817
+ model_path: Path to the .pth model file.
1818
+ repo_id: Repo ID for loading models from the Hub.
1819
+ model: Custom model to use for inference.
1820
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
1821
+ """
1822
+ super().__init__(
1823
+ model_path=model_path, repo_id=repo_id, model=model, device=device
1824
+ )
1825
+
1826
+ def regularize_buildings(
1827
+ self,
1828
+ gdf,
1829
+ min_area=10,
1830
+ angle_threshold=15,
1831
+ orthogonality_threshold=0.3,
1832
+ rectangularity_threshold=0.7,
1833
+ ):
1834
+ """
1835
+ Regularize building footprints to enforce right angles and rectangular shapes.
1836
+
1837
+ Args:
1838
+ gdf: GeoDataFrame with building footprints
1839
+ min_area: Minimum area in square units to keep a building
1840
+ angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
1841
+ orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
1842
+ rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
1843
+
1844
+ Returns:
1845
+ GeoDataFrame with regularized building footprints
1846
+ """
1847
+ return self.regularize_objects(
1848
+ gdf,
1849
+ min_area=min_area,
1850
+ angle_threshold=angle_threshold,
1851
+ orthogonality_threshold=orthogonality_threshold,
1852
+ rectangularity_threshold=rectangularity_threshold,
1853
+ )
1854
+
1855
+
1856
+ class CarDetector(ObjectDetector):
1857
+ """
1858
+ Car detection using a pre-trained Mask R-CNN model.
1859
+
1860
+ This class extends the
1861
+ `ObjectDetector` class with a default pre-trained model for car detection.
1862
+ """
1863
+
1864
+ def __init__(
1865
+ self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
1866
+ ):
1867
+ """
1868
+ Initialize the car detector.
1869
+
1870
+ Args:
1871
+ model_path: Path to the .pth model file.
1872
+ repo_id: Repo ID for loading models from the Hub.
1873
+ model: Custom model to use for inference.
1874
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
1875
+ """
1876
+ super().__init__(
1877
+ model_path=model_path, repo_id=repo_id, model=model, device=device
1878
+ )
1879
+
1880
+
1881
+ class ShipDetector(ObjectDetector):
1882
+ """
1883
+ Ship detection using a pre-trained Mask R-CNN model.
1884
+
1885
+ This class extends the
1886
+ `ObjectDetector` class with a default pre-trained model for ship detection.
1887
+ """
1888
+
1889
+ def __init__(
1890
+ self, model_path="ship_detection.pth", repo_id=None, model=None, device=None
1891
+ ):
1892
+ """
1893
+ Initialize the ship detector.
1894
+
1895
+ Args:
1896
+ model_path: Path to the .pth model file.
1897
+ repo_id: Repo ID for loading models from the Hub.
1898
+ model: Custom model to use for inference.
1899
+ device: Device to use for inference ('cuda:0', 'cpu', etc.).
1900
+ """
1901
+ super().__init__(
1902
+ model_path=model_path, repo_id=repo_id, model=model, device=device
1903
+ )