geoai-py 0.3.1__py2.py3-none-any.whl → 0.3.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/extract.py +275 -143
- geoai/geoai.py +0 -1
- geoai/preprocess.py +20 -6
- geoai/utils.py +3882 -13
- {geoai_py-0.3.1.dist-info → geoai_py-0.3.3.dist-info}/METADATA +11 -5
- geoai_py-0.3.3.dist-info/RECORD +13 -0
- geoai_py-0.3.1.dist-info/RECORD +0 -13
- {geoai_py-0.3.1.dist-info → geoai_py-0.3.3.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.1.dist-info → geoai_py-0.3.3.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.1.dist-info → geoai_py-0.3.3.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.1.dist-info → geoai_py-0.3.3.dist-info}/top_level.txt +0 -0
geoai/extract.py
CHANGED
|
@@ -13,7 +13,7 @@ import rasterio
|
|
|
13
13
|
from rasterio.windows import Window
|
|
14
14
|
from rasterio.features import shapes
|
|
15
15
|
from huggingface_hub import hf_hub_download
|
|
16
|
-
from .
|
|
16
|
+
from .utils import get_raster_stats
|
|
17
17
|
|
|
18
18
|
try:
|
|
19
19
|
from torchgeo.datasets import NonGeoDataset
|
|
@@ -23,9 +23,9 @@ except ImportError as e:
|
|
|
23
23
|
)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class
|
|
26
|
+
class CustomDataset(NonGeoDataset):
|
|
27
27
|
"""
|
|
28
|
-
A TorchGeo dataset for
|
|
28
|
+
A TorchGeo dataset for object extraction.
|
|
29
29
|
Using NonGeoDataset to avoid spatial indexing issues.
|
|
30
30
|
"""
|
|
31
31
|
|
|
@@ -170,18 +170,20 @@ class BuildingFootprintDataset(NonGeoDataset):
|
|
|
170
170
|
return self.rows * self.cols
|
|
171
171
|
|
|
172
172
|
|
|
173
|
-
class
|
|
173
|
+
class ObjectDetector:
|
|
174
174
|
"""
|
|
175
|
-
|
|
175
|
+
Object extraction using Mask R-CNN with TorchGeo.
|
|
176
176
|
"""
|
|
177
177
|
|
|
178
|
-
def __init__(self, model_path=None, device=None):
|
|
178
|
+
def __init__(self, model_path=None, repo_id=None, model=None, device=None):
|
|
179
179
|
"""
|
|
180
|
-
Initialize the
|
|
180
|
+
Initialize the object extractor.
|
|
181
181
|
|
|
182
182
|
Args:
|
|
183
|
-
model_path: Path to the .pth model file
|
|
184
|
-
|
|
183
|
+
model_path: Path to the .pth model file.
|
|
184
|
+
repo_id: Hugging Face repository ID for model download.
|
|
185
|
+
model: Pre-initialized model object (optional).
|
|
186
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
185
187
|
"""
|
|
186
188
|
# Set device
|
|
187
189
|
if device is None:
|
|
@@ -189,31 +191,36 @@ class BuildingFootprintExtractor:
|
|
|
189
191
|
else:
|
|
190
192
|
self.device = torch.device(device)
|
|
191
193
|
|
|
192
|
-
# Default parameters for
|
|
194
|
+
# Default parameters for object detection - these can be overridden in process_raster
|
|
193
195
|
self.chip_size = (512, 512) # Size of image chips for processing
|
|
194
196
|
self.overlap = 0.25 # Default overlap between tiles
|
|
195
197
|
self.confidence_threshold = 0.5 # Default confidence threshold
|
|
196
198
|
self.nms_iou_threshold = 0.5 # IoU threshold for non-maximum suppression
|
|
197
|
-
self.
|
|
199
|
+
self.min_object_area = 100 # Minimum area in pixels to keep an object
|
|
200
|
+
self.max_object_area = None # Maximum area in pixels to keep an object
|
|
198
201
|
self.mask_threshold = 0.5 # Threshold for mask binarization
|
|
199
202
|
self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
|
|
200
203
|
|
|
201
204
|
# Initialize model
|
|
202
|
-
self.model = self.
|
|
205
|
+
self.model = self.initialize_model(model)
|
|
203
206
|
|
|
204
207
|
# Download model if needed
|
|
205
|
-
if model_path is None:
|
|
206
|
-
model_path = self.
|
|
208
|
+
if model_path is None or (not os.path.exists(model_path)):
|
|
209
|
+
model_path = self.download_model_from_hf(model_path, repo_id)
|
|
207
210
|
|
|
208
211
|
# Load model weights
|
|
209
|
-
self.
|
|
212
|
+
self.load_weights(model_path)
|
|
210
213
|
|
|
211
214
|
# Set model to evaluation mode
|
|
212
215
|
self.model.eval()
|
|
213
216
|
|
|
214
|
-
def
|
|
217
|
+
def download_model_from_hf(self, model_path=None, repo_id=None):
|
|
215
218
|
"""
|
|
216
|
-
Download the
|
|
219
|
+
Download the object detection model from Hugging Face.
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
model_path: Path to the model file.
|
|
223
|
+
repo_id: Hugging Face repository ID.
|
|
217
224
|
|
|
218
225
|
Returns:
|
|
219
226
|
Path to the downloaded model file
|
|
@@ -223,17 +230,14 @@ class BuildingFootprintExtractor:
|
|
|
223
230
|
print("Model path not specified, downloading from Hugging Face...")
|
|
224
231
|
|
|
225
232
|
# Define the repository ID and model filename
|
|
226
|
-
repo_id
|
|
227
|
-
|
|
233
|
+
if repo_id is None:
|
|
234
|
+
repo_id = "giswqs/geoai"
|
|
228
235
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
# os.path.expanduser("~"), ".cache", "building_footprints"
|
|
232
|
-
# )
|
|
233
|
-
# os.makedirs(cache_dir, exist_ok=True)
|
|
236
|
+
if model_path is None:
|
|
237
|
+
model_path = "building_footprints_usa.pth"
|
|
234
238
|
|
|
235
239
|
# Download the model
|
|
236
|
-
model_path = hf_hub_download(repo_id=repo_id, filename=
|
|
240
|
+
model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
|
|
237
241
|
print(f"Model downloaded to: {model_path}")
|
|
238
242
|
|
|
239
243
|
return model_path
|
|
@@ -243,28 +247,36 @@ class BuildingFootprintExtractor:
|
|
|
243
247
|
print("Please specify a local model path or ensure internet connectivity.")
|
|
244
248
|
raise
|
|
245
249
|
|
|
246
|
-
def
|
|
247
|
-
"""Initialize
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
250
|
+
def initialize_model(self, model):
|
|
251
|
+
"""Initialize a deep learning model for object detection.
|
|
252
|
+
|
|
253
|
+
Args:
|
|
254
|
+
model (torch.nn.Module): A pre-initialized model object.
|
|
255
|
+
|
|
256
|
+
Returns:
|
|
257
|
+
torch.nn.Module: A deep learning model for object detection.
|
|
258
|
+
"""
|
|
259
|
+
|
|
260
|
+
if model is None: # Initialize Mask R-CNN model with ResNet50 backbone.
|
|
261
|
+
# Standard image mean and std for pre-trained models
|
|
262
|
+
image_mean = [0.485, 0.456, 0.406]
|
|
263
|
+
image_std = [0.229, 0.224, 0.225]
|
|
264
|
+
|
|
265
|
+
# Create model with explicit normalization parameters
|
|
266
|
+
model = maskrcnn_resnet50_fpn(
|
|
267
|
+
weights=None,
|
|
268
|
+
progress=False,
|
|
269
|
+
num_classes=2, # Background + object
|
|
270
|
+
weights_backbone=None,
|
|
271
|
+
# These parameters ensure consistent normalization
|
|
272
|
+
image_mean=image_mean,
|
|
273
|
+
image_std=image_std,
|
|
274
|
+
)
|
|
263
275
|
|
|
264
276
|
model.to(self.device)
|
|
265
277
|
return model
|
|
266
278
|
|
|
267
|
-
def
|
|
279
|
+
def load_weights(self, model_path):
|
|
268
280
|
"""
|
|
269
281
|
Load weights from file with error handling for different formats.
|
|
270
282
|
|
|
@@ -306,7 +318,7 @@ class BuildingFootprintExtractor:
|
|
|
306
318
|
except Exception as e:
|
|
307
319
|
raise RuntimeError(f"Failed to load model: {e}")
|
|
308
320
|
|
|
309
|
-
def
|
|
321
|
+
def mask_to_polygons(self, mask, **kwargs):
|
|
310
322
|
"""
|
|
311
323
|
Convert binary mask to polygon contours using OpenCV.
|
|
312
324
|
|
|
@@ -315,7 +327,8 @@ class BuildingFootprintExtractor:
|
|
|
315
327
|
**kwargs: Optional parameters:
|
|
316
328
|
simplify_tolerance: Tolerance for polygon simplification
|
|
317
329
|
mask_threshold: Threshold for mask binarization
|
|
318
|
-
|
|
330
|
+
min_object_area: Minimum area in pixels to keep an object
|
|
331
|
+
max_object_area: Maximum area in pixels to keep an object
|
|
319
332
|
|
|
320
333
|
Returns:
|
|
321
334
|
List of polygons as lists of (x, y) coordinates
|
|
@@ -324,9 +337,8 @@ class BuildingFootprintExtractor:
|
|
|
324
337
|
# Get parameters from kwargs or use instance defaults
|
|
325
338
|
simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
|
|
326
339
|
mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
)
|
|
340
|
+
min_object_area = kwargs.get("min_object_area", self.min_object_area)
|
|
341
|
+
max_object_area = kwargs.get("max_object_area", self.max_object_area)
|
|
330
342
|
|
|
331
343
|
# Ensure binary mask
|
|
332
344
|
mask = (mask > mask_threshold).astype(np.uint8)
|
|
@@ -342,7 +354,14 @@ class BuildingFootprintExtractor:
|
|
|
342
354
|
polygons = []
|
|
343
355
|
for contour in contours:
|
|
344
356
|
# Filter out too small contours
|
|
345
|
-
if contour.shape[0] < 3 or cv2.contourArea(contour) <
|
|
357
|
+
if contour.shape[0] < 3 or cv2.contourArea(contour) < min_object_area:
|
|
358
|
+
continue
|
|
359
|
+
|
|
360
|
+
# Filter out too large contours
|
|
361
|
+
if (
|
|
362
|
+
max_object_area is not None
|
|
363
|
+
and cv2.contourArea(contour) > max_object_area
|
|
364
|
+
):
|
|
346
365
|
continue
|
|
347
366
|
|
|
348
367
|
# Simplify contour if it has many points
|
|
@@ -356,7 +375,7 @@ class BuildingFootprintExtractor:
|
|
|
356
375
|
|
|
357
376
|
return polygons
|
|
358
377
|
|
|
359
|
-
def
|
|
378
|
+
def filter_overlapping_polygons(self, gdf, **kwargs):
|
|
360
379
|
"""
|
|
361
380
|
Filter overlapping polygons using non-maximum suppression.
|
|
362
381
|
|
|
@@ -413,26 +432,26 @@ class BuildingFootprintExtractor:
|
|
|
413
432
|
|
|
414
433
|
return gdf.iloc[keep_indices]
|
|
415
434
|
|
|
416
|
-
def
|
|
435
|
+
def filter_edge_objects(self, gdf, raster_path, edge_buffer=10):
|
|
417
436
|
"""
|
|
418
|
-
Filter out
|
|
437
|
+
Filter out object detections that fall in padding/edge areas of the image.
|
|
419
438
|
|
|
420
439
|
Args:
|
|
421
|
-
gdf: GeoDataFrame with
|
|
440
|
+
gdf: GeoDataFrame with object detections
|
|
422
441
|
raster_path: Path to the original raster file
|
|
423
442
|
edge_buffer: Buffer in pixels to consider as edge region
|
|
424
443
|
|
|
425
444
|
Returns:
|
|
426
|
-
GeoDataFrame with filtered
|
|
445
|
+
GeoDataFrame with filtered objects
|
|
427
446
|
"""
|
|
428
447
|
import rasterio
|
|
429
448
|
from shapely.geometry import box
|
|
430
449
|
|
|
431
|
-
# If no
|
|
450
|
+
# If no objects detected, return empty GeoDataFrame
|
|
432
451
|
if gdf is None or len(gdf) == 0:
|
|
433
452
|
return gdf
|
|
434
453
|
|
|
435
|
-
print(f"
|
|
454
|
+
print(f"Objects before filtering: {len(gdf)}")
|
|
436
455
|
|
|
437
456
|
with rasterio.open(raster_path) as src:
|
|
438
457
|
# Get raster bounds
|
|
@@ -461,18 +480,18 @@ class BuildingFootprintExtractor:
|
|
|
461
480
|
else:
|
|
462
481
|
inner_box = box(*inner_bounds)
|
|
463
482
|
|
|
464
|
-
# Filter out
|
|
483
|
+
# Filter out objects that intersect with the edge of the image
|
|
465
484
|
filtered_gdf = gdf[gdf.intersects(inner_box)]
|
|
466
485
|
|
|
467
|
-
# Additional check for
|
|
468
|
-
|
|
486
|
+
# Additional check for objects that have >50% of their area outside the valid region
|
|
487
|
+
valid_objects = []
|
|
469
488
|
for idx, row in filtered_gdf.iterrows():
|
|
470
489
|
if row.geometry.intersection(inner_box).area >= 0.5 * row.geometry.area:
|
|
471
|
-
|
|
490
|
+
valid_objects.append(idx)
|
|
472
491
|
|
|
473
|
-
filtered_gdf = filtered_gdf.loc[
|
|
492
|
+
filtered_gdf = filtered_gdf.loc[valid_objects]
|
|
474
493
|
|
|
475
|
-
print(f"
|
|
494
|
+
print(f"Objects after filtering: {len(filtered_gdf)}")
|
|
476
495
|
|
|
477
496
|
return filtered_gdf
|
|
478
497
|
|
|
@@ -482,28 +501,30 @@ class BuildingFootprintExtractor:
|
|
|
482
501
|
output_path=None,
|
|
483
502
|
simplify_tolerance=None,
|
|
484
503
|
mask_threshold=None,
|
|
485
|
-
|
|
504
|
+
min_object_area=None,
|
|
505
|
+
max_object_area=None,
|
|
486
506
|
nms_iou_threshold=None,
|
|
487
507
|
regularize=True,
|
|
488
508
|
angle_threshold=15,
|
|
489
509
|
rectangularity_threshold=0.7,
|
|
490
510
|
):
|
|
491
511
|
"""
|
|
492
|
-
Convert
|
|
512
|
+
Convert an object mask GeoTIFF to vector polygons and save as GeoJSON.
|
|
493
513
|
|
|
494
514
|
Args:
|
|
495
|
-
mask_path: Path to the
|
|
515
|
+
mask_path: Path to the object masks GeoTIFF
|
|
496
516
|
output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
|
|
497
517
|
simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
|
|
498
518
|
mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
|
|
499
|
-
|
|
519
|
+
min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
|
|
520
|
+
max_object_area: Maximum area in pixels to keep an object (default: self.max_object_area)
|
|
500
521
|
nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
|
|
501
|
-
regularize: Whether to regularize
|
|
522
|
+
regularize: Whether to regularize objects to right angles (default: True)
|
|
502
523
|
angle_threshold: Maximum deviation from 90 degrees for regularization (default: 15)
|
|
503
524
|
rectangularity_threshold: Threshold for rectangle simplification (default: 0.7)
|
|
504
525
|
|
|
505
526
|
Returns:
|
|
506
|
-
GeoDataFrame with
|
|
527
|
+
GeoDataFrame with objects
|
|
507
528
|
"""
|
|
508
529
|
# Use class defaults if parameters not provided
|
|
509
530
|
simplify_tolerance = (
|
|
@@ -514,10 +535,11 @@ class BuildingFootprintExtractor:
|
|
|
514
535
|
mask_threshold = (
|
|
515
536
|
mask_threshold if mask_threshold is not None else self.mask_threshold
|
|
516
537
|
)
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
538
|
+
min_object_area = (
|
|
539
|
+
min_object_area if min_object_area is not None else self.min_object_area
|
|
540
|
+
)
|
|
541
|
+
max_object_area = (
|
|
542
|
+
max_object_area if max_object_area is not None else self.max_object_area
|
|
521
543
|
)
|
|
522
544
|
nms_iou_threshold = (
|
|
523
545
|
nms_iou_threshold
|
|
@@ -531,10 +553,11 @@ class BuildingFootprintExtractor:
|
|
|
531
553
|
|
|
532
554
|
print(f"Converting mask to GeoJSON with parameters:")
|
|
533
555
|
print(f"- Mask threshold: {mask_threshold}")
|
|
534
|
-
print(f"- Min
|
|
556
|
+
print(f"- Min object area: {min_object_area}")
|
|
557
|
+
print(f"- Max object area: {max_object_area}")
|
|
535
558
|
print(f"- Simplify tolerance: {simplify_tolerance}")
|
|
536
559
|
print(f"- NMS IoU threshold: {nms_iou_threshold}")
|
|
537
|
-
print(f"- Regularize
|
|
560
|
+
print(f"- Regularize objects: {regularize}")
|
|
538
561
|
if regularize:
|
|
539
562
|
print(f"- Angle threshold: {angle_threshold}° from 90°")
|
|
540
563
|
print(f"- Rectangularity threshold: {rectangularity_threshold*100}%")
|
|
@@ -564,7 +587,7 @@ class BuildingFootprintExtractor:
|
|
|
564
587
|
)
|
|
565
588
|
|
|
566
589
|
print(
|
|
567
|
-
f"Found {num_labels-1} potential
|
|
590
|
+
f"Found {num_labels-1} potential objects"
|
|
568
591
|
) # Subtract 1 for background
|
|
569
592
|
|
|
570
593
|
# Create list to store polygons and confidence values
|
|
@@ -573,19 +596,23 @@ class BuildingFootprintExtractor:
|
|
|
573
596
|
|
|
574
597
|
# Process each component (skip the first one which is background)
|
|
575
598
|
for i in tqdm(range(1, num_labels)):
|
|
576
|
-
# Extract this
|
|
599
|
+
# Extract this object
|
|
577
600
|
area = stats[i, cv2.CC_STAT_AREA]
|
|
578
601
|
|
|
579
602
|
# Skip if too small
|
|
580
|
-
if area <
|
|
603
|
+
if area < min_object_area:
|
|
604
|
+
continue
|
|
605
|
+
|
|
606
|
+
# Skip if too large
|
|
607
|
+
if max_object_area is not None and area > max_object_area:
|
|
581
608
|
continue
|
|
582
609
|
|
|
583
|
-
# Create a mask for this
|
|
584
|
-
|
|
610
|
+
# Create a mask for this object
|
|
611
|
+
object_mask = (labels == i).astype(np.uint8)
|
|
585
612
|
|
|
586
613
|
# Find contours
|
|
587
614
|
contours, _ = cv2.findContours(
|
|
588
|
-
|
|
615
|
+
object_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
|
589
616
|
)
|
|
590
617
|
|
|
591
618
|
# Process each contour
|
|
@@ -633,17 +660,17 @@ class BuildingFootprintExtractor:
|
|
|
633
660
|
{
|
|
634
661
|
"geometry": all_polygons,
|
|
635
662
|
"confidence": all_confidences,
|
|
636
|
-
"class": 1, #
|
|
663
|
+
"class": 1, # Object class
|
|
637
664
|
},
|
|
638
665
|
crs=crs,
|
|
639
666
|
)
|
|
640
667
|
|
|
641
668
|
# Apply non-maximum suppression to remove overlapping polygons
|
|
642
|
-
gdf = self.
|
|
669
|
+
gdf = self.filter_overlapping_polygons(
|
|
643
670
|
gdf, nms_iou_threshold=nms_iou_threshold
|
|
644
671
|
)
|
|
645
672
|
|
|
646
|
-
print(f"
|
|
673
|
+
print(f"Object count after NMS filtering: {len(gdf)}")
|
|
647
674
|
|
|
648
675
|
# Apply regularization if requested
|
|
649
676
|
if regularize and len(gdf) > 0:
|
|
@@ -661,8 +688,8 @@ class BuildingFootprintExtractor:
|
|
|
661
688
|
# Use 10 pixels as minimum area in geographic units
|
|
662
689
|
min_geo_area = 10 * avg_pixel_area
|
|
663
690
|
|
|
664
|
-
# Regularize
|
|
665
|
-
gdf = self.
|
|
691
|
+
# Regularize objects
|
|
692
|
+
gdf = self.regularize_objects(
|
|
666
693
|
gdf,
|
|
667
694
|
min_area=min_geo_area,
|
|
668
695
|
angle_threshold=angle_threshold,
|
|
@@ -672,7 +699,7 @@ class BuildingFootprintExtractor:
|
|
|
672
699
|
# Save to file
|
|
673
700
|
if output_path:
|
|
674
701
|
gdf.to_file(output_path)
|
|
675
|
-
print(f"Saved {len(gdf)}
|
|
702
|
+
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
676
703
|
|
|
677
704
|
return gdf
|
|
678
705
|
|
|
@@ -687,25 +714,25 @@ class BuildingFootprintExtractor:
|
|
|
687
714
|
**kwargs,
|
|
688
715
|
):
|
|
689
716
|
"""
|
|
690
|
-
Process a raster file to extract
|
|
717
|
+
Process a raster file to extract objects with customizable parameters.
|
|
691
718
|
|
|
692
719
|
Args:
|
|
693
720
|
raster_path: Path to input raster file
|
|
694
721
|
output_path: Path to output GeoJSON file (optional)
|
|
695
722
|
batch_size: Batch size for processing
|
|
696
|
-
filter_edges: Whether to filter out
|
|
697
|
-
edge_buffer: Size of edge buffer in pixels to filter out
|
|
723
|
+
filter_edges: Whether to filter out objects at the edges of the image
|
|
724
|
+
edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
|
|
698
725
|
**kwargs: Additional parameters:
|
|
699
726
|
confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
|
|
700
727
|
overlap: Overlap between adjacent tiles (0.0-1.0)
|
|
701
728
|
chip_size: Size of image chips for processing (height, width)
|
|
702
729
|
nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
|
|
703
730
|
mask_threshold: Threshold for mask binarization (0.0-1.0)
|
|
704
|
-
|
|
731
|
+
min_object_area: Minimum area in pixels to keep an object
|
|
705
732
|
simplify_tolerance: Tolerance for polygon simplification
|
|
706
733
|
|
|
707
734
|
Returns:
|
|
708
|
-
GeoDataFrame with
|
|
735
|
+
GeoDataFrame with objects
|
|
709
736
|
"""
|
|
710
737
|
# Get parameters from kwargs or use instance defaults
|
|
711
738
|
confidence_threshold = kwargs.get(
|
|
@@ -715,9 +742,8 @@ class BuildingFootprintExtractor:
|
|
|
715
742
|
chip_size = kwargs.get("chip_size", self.chip_size)
|
|
716
743
|
nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
|
|
717
744
|
mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
)
|
|
745
|
+
min_object_area = kwargs.get("min_object_area", self.min_object_area)
|
|
746
|
+
max_object_area = kwargs.get("max_object_area", self.max_object_area)
|
|
721
747
|
simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
|
|
722
748
|
|
|
723
749
|
# Print parameters being used
|
|
@@ -727,14 +753,15 @@ class BuildingFootprintExtractor:
|
|
|
727
753
|
print(f"- Chip size: {chip_size}")
|
|
728
754
|
print(f"- NMS IoU threshold: {nms_iou_threshold}")
|
|
729
755
|
print(f"- Mask threshold: {mask_threshold}")
|
|
730
|
-
print(f"- Min
|
|
756
|
+
print(f"- Min object area: {min_object_area}")
|
|
757
|
+
print(f"- Max object area: {max_object_area}")
|
|
731
758
|
print(f"- Simplify tolerance: {simplify_tolerance}")
|
|
732
|
-
print(f"- Filter edge
|
|
759
|
+
print(f"- Filter edge objects: {filter_edges}")
|
|
733
760
|
if filter_edges:
|
|
734
761
|
print(f"- Edge buffer size: {edge_buffer} pixels")
|
|
735
762
|
|
|
736
763
|
# Create dataset
|
|
737
|
-
dataset =
|
|
764
|
+
dataset = CustomDataset(raster_path=raster_path, chip_size=chip_size)
|
|
738
765
|
self.raster_stats = dataset.raster_stats
|
|
739
766
|
|
|
740
767
|
# Custom collate function to handle Shapely objects
|
|
@@ -854,11 +881,12 @@ class BuildingFootprintExtractor:
|
|
|
854
881
|
binary_mask = mask[0] # Get binary mask
|
|
855
882
|
|
|
856
883
|
# Convert mask to polygon with custom parameters
|
|
857
|
-
contours = self.
|
|
884
|
+
contours = self.mask_to_polygons(
|
|
858
885
|
binary_mask,
|
|
859
886
|
simplify_tolerance=simplify_tolerance,
|
|
860
887
|
mask_threshold=mask_threshold,
|
|
861
|
-
|
|
888
|
+
min_object_area=min_object_area,
|
|
889
|
+
max_object_area=max_object_area,
|
|
862
890
|
)
|
|
863
891
|
|
|
864
892
|
# Skip if no valid polygons
|
|
@@ -896,24 +924,22 @@ class BuildingFootprintExtractor:
|
|
|
896
924
|
{
|
|
897
925
|
"geometry": all_polygons,
|
|
898
926
|
"confidence": all_scores,
|
|
899
|
-
"class": 1, #
|
|
927
|
+
"class": 1, # Object class
|
|
900
928
|
},
|
|
901
929
|
crs=dataset.crs,
|
|
902
930
|
)
|
|
903
931
|
|
|
904
932
|
# Remove overlapping polygons with custom threshold
|
|
905
|
-
gdf = self.
|
|
906
|
-
gdf, nms_iou_threshold=nms_iou_threshold
|
|
907
|
-
)
|
|
933
|
+
gdf = self.filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)
|
|
908
934
|
|
|
909
|
-
# Filter edge
|
|
935
|
+
# Filter edge objects if requested
|
|
910
936
|
if filter_edges:
|
|
911
|
-
gdf = self.
|
|
937
|
+
gdf = self.filter_edge_objects(gdf, raster_path, edge_buffer=edge_buffer)
|
|
912
938
|
|
|
913
939
|
# Save to file if requested
|
|
914
940
|
if output_path:
|
|
915
941
|
gdf.to_file(output_path, driver="GeoJSON")
|
|
916
|
-
print(f"Saved {len(gdf)}
|
|
942
|
+
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
917
943
|
|
|
918
944
|
return gdf
|
|
919
945
|
|
|
@@ -921,7 +947,7 @@ class BuildingFootprintExtractor:
|
|
|
921
947
|
self, raster_path, output_path=None, batch_size=4, verbose=False, **kwargs
|
|
922
948
|
):
|
|
923
949
|
"""
|
|
924
|
-
Process a raster file to extract
|
|
950
|
+
Process a raster file to extract object masks and save as GeoTIFF.
|
|
925
951
|
|
|
926
952
|
Args:
|
|
927
953
|
raster_path: Path to input raster file
|
|
@@ -955,7 +981,7 @@ class BuildingFootprintExtractor:
|
|
|
955
981
|
print(f"- Mask threshold: {mask_threshold}")
|
|
956
982
|
|
|
957
983
|
# Create dataset
|
|
958
|
-
dataset =
|
|
984
|
+
dataset = CustomDataset(
|
|
959
985
|
raster_path=raster_path, chip_size=chip_size, verbose=verbose
|
|
960
986
|
)
|
|
961
987
|
|
|
@@ -972,7 +998,7 @@ class BuildingFootprintExtractor:
|
|
|
972
998
|
output_profile = src.profile.copy()
|
|
973
999
|
output_profile.update(
|
|
974
1000
|
dtype=rasterio.uint8,
|
|
975
|
-
count=1, # Single band for
|
|
1001
|
+
count=1, # Single band for object mask
|
|
976
1002
|
compress="lzw",
|
|
977
1003
|
nodata=0,
|
|
978
1004
|
)
|
|
@@ -1144,10 +1170,10 @@ class BuildingFootprintExtractor:
|
|
|
1144
1170
|
# Write the final mask to the output file
|
|
1145
1171
|
dst.write(mask_array, 1)
|
|
1146
1172
|
|
|
1147
|
-
print(f"
|
|
1173
|
+
print(f"Object masks saved to {output_path}")
|
|
1148
1174
|
return output_path
|
|
1149
1175
|
|
|
1150
|
-
def
|
|
1176
|
+
def regularize_objects(
|
|
1151
1177
|
self,
|
|
1152
1178
|
gdf,
|
|
1153
1179
|
min_area=10,
|
|
@@ -1156,17 +1182,17 @@ class BuildingFootprintExtractor:
|
|
|
1156
1182
|
rectangularity_threshold=0.7,
|
|
1157
1183
|
):
|
|
1158
1184
|
"""
|
|
1159
|
-
Regularize
|
|
1185
|
+
Regularize objects to enforce right angles and rectangular shapes.
|
|
1160
1186
|
|
|
1161
1187
|
Args:
|
|
1162
|
-
gdf: GeoDataFrame with
|
|
1163
|
-
min_area: Minimum area in square units to keep
|
|
1188
|
+
gdf: GeoDataFrame with objects
|
|
1189
|
+
min_area: Minimum area in square units to keep an object
|
|
1164
1190
|
angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
|
|
1165
|
-
orthogonality_threshold: Percentage of angles that must be orthogonal for
|
|
1166
|
-
rectangularity_threshold: Minimum area ratio to
|
|
1191
|
+
orthogonality_threshold: Percentage of angles that must be orthogonal for an object to be regularized
|
|
1192
|
+
rectangularity_threshold: Minimum area ratio to the object's oriented bounding box for rectangular simplification
|
|
1167
1193
|
|
|
1168
1194
|
Returns:
|
|
1169
|
-
GeoDataFrame with regularized
|
|
1195
|
+
GeoDataFrame with regularized objects
|
|
1170
1196
|
"""
|
|
1171
1197
|
import numpy as np
|
|
1172
1198
|
from shapely.geometry import Polygon, MultiPolygon, box
|
|
@@ -1281,10 +1307,10 @@ class BuildingFootprintExtractor:
|
|
|
1281
1307
|
return rect
|
|
1282
1308
|
|
|
1283
1309
|
if gdf is None or len(gdf) == 0:
|
|
1284
|
-
print("No
|
|
1310
|
+
print("No Objects to regularize")
|
|
1285
1311
|
return gdf
|
|
1286
1312
|
|
|
1287
|
-
print(f"Regularizing {len(gdf)}
|
|
1313
|
+
print(f"Regularizing {len(gdf)} objects...")
|
|
1288
1314
|
print(f"- Angle threshold: {angle_threshold}° from 90°")
|
|
1289
1315
|
print(f"- Min orthogonality: {orthogonality_threshold*100}% of angles")
|
|
1290
1316
|
print(
|
|
@@ -1295,11 +1321,11 @@ class BuildingFootprintExtractor:
|
|
|
1295
1321
|
result_gdf = gdf.copy()
|
|
1296
1322
|
|
|
1297
1323
|
# Track statistics
|
|
1298
|
-
|
|
1324
|
+
total_objects = len(gdf)
|
|
1299
1325
|
regularized_count = 0
|
|
1300
1326
|
rectangularized_count = 0
|
|
1301
1327
|
|
|
1302
|
-
# Process each
|
|
1328
|
+
# Process each Object
|
|
1303
1329
|
for idx, row in tqdm(gdf.iterrows(), total=len(gdf)):
|
|
1304
1330
|
geom = row.geometry
|
|
1305
1331
|
|
|
@@ -1314,7 +1340,7 @@ class BuildingFootprintExtractor:
|
|
|
1314
1340
|
continue
|
|
1315
1341
|
geom = list(geom.geoms)[np.argmax(areas)]
|
|
1316
1342
|
|
|
1317
|
-
# Filter out tiny
|
|
1343
|
+
# Filter out tiny Objects
|
|
1318
1344
|
if geom.area < min_area:
|
|
1319
1345
|
continue
|
|
1320
1346
|
|
|
@@ -1331,33 +1357,33 @@ class BuildingFootprintExtractor:
|
|
|
1331
1357
|
|
|
1332
1358
|
# Decide how to regularize
|
|
1333
1359
|
if rectangularity >= rectangularity_threshold:
|
|
1334
|
-
#
|
|
1360
|
+
# Object is already quite rectangular, simplify to a rectangle
|
|
1335
1361
|
result_gdf.at[idx, "geometry"] = oriented_box
|
|
1336
1362
|
result_gdf.at[idx, "regularized"] = "rectangle"
|
|
1337
1363
|
rectangularized_count += 1
|
|
1338
1364
|
elif orthogonality >= orthogonality_threshold:
|
|
1339
|
-
#
|
|
1365
|
+
# Object has many orthogonal angles but isn't rectangular
|
|
1340
1366
|
# Could implement more sophisticated regularization here
|
|
1341
1367
|
# For now, we'll still use the oriented rectangle
|
|
1342
1368
|
result_gdf.at[idx, "geometry"] = oriented_box
|
|
1343
1369
|
result_gdf.at[idx, "regularized"] = "orthogonal"
|
|
1344
1370
|
regularized_count += 1
|
|
1345
1371
|
else:
|
|
1346
|
-
#
|
|
1372
|
+
# Object doesn't have clear orthogonal structure
|
|
1347
1373
|
# Keep original but flag as unmodified
|
|
1348
1374
|
result_gdf.at[idx, "regularized"] = "original"
|
|
1349
1375
|
|
|
1350
1376
|
# Report statistics
|
|
1351
1377
|
print(f"Regularization completed:")
|
|
1352
|
-
print(f"- Total
|
|
1378
|
+
print(f"- Total objects: {total_objects}")
|
|
1353
1379
|
print(
|
|
1354
|
-
f"- Rectangular
|
|
1380
|
+
f"- Rectangular objects: {rectangularized_count} ({rectangularized_count/total_objects*100:.1f}%)"
|
|
1355
1381
|
)
|
|
1356
1382
|
print(
|
|
1357
|
-
f"- Other regularized
|
|
1383
|
+
f"- Other regularized objects: {regularized_count} ({regularized_count/total_objects*100:.1f}%)"
|
|
1358
1384
|
)
|
|
1359
1385
|
print(
|
|
1360
|
-
f"- Unmodified
|
|
1386
|
+
f"- Unmodified objects: {total_objects-rectangularized_count-regularized_count} ({(total_objects-rectangularized_count-regularized_count)/total_objects*100:.1f}%)"
|
|
1361
1387
|
)
|
|
1362
1388
|
|
|
1363
1389
|
return result_gdf
|
|
@@ -1366,14 +1392,14 @@ class BuildingFootprintExtractor:
|
|
|
1366
1392
|
self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
|
|
1367
1393
|
):
|
|
1368
1394
|
"""
|
|
1369
|
-
Visualize
|
|
1395
|
+
Visualize object detection results with proper coordinate transformation.
|
|
1370
1396
|
|
|
1371
|
-
This function displays
|
|
1397
|
+
This function displays objects on top of the raster image,
|
|
1372
1398
|
ensuring proper alignment between the GeoDataFrame polygons and the image.
|
|
1373
1399
|
|
|
1374
1400
|
Args:
|
|
1375
1401
|
raster_path: Path to input raster
|
|
1376
|
-
gdf: GeoDataFrame with
|
|
1402
|
+
gdf: GeoDataFrame with object polygons (optional)
|
|
1377
1403
|
output_path: Path to save visualization (optional)
|
|
1378
1404
|
figsize: Figure size (width, height) in inches
|
|
1379
1405
|
|
|
@@ -1390,7 +1416,7 @@ class BuildingFootprintExtractor:
|
|
|
1390
1416
|
gdf = self.process_raster(raster_path)
|
|
1391
1417
|
|
|
1392
1418
|
if gdf is None or len(gdf) == 0:
|
|
1393
|
-
print("No
|
|
1419
|
+
print("No objects to visualize")
|
|
1394
1420
|
return False
|
|
1395
1421
|
|
|
1396
1422
|
# Check if confidence column exists in the GeoDataFrame
|
|
@@ -1531,7 +1557,7 @@ class BuildingFootprintExtractor:
|
|
|
1531
1557
|
print(f"Unsupported geometry type: {geometry.geom_type}")
|
|
1532
1558
|
return None
|
|
1533
1559
|
|
|
1534
|
-
# Plot each
|
|
1560
|
+
# Plot each object
|
|
1535
1561
|
for idx, row in gdf.iterrows():
|
|
1536
1562
|
try:
|
|
1537
1563
|
# Convert polygon to pixel coordinates
|
|
@@ -1593,7 +1619,7 @@ class BuildingFootprintExtractor:
|
|
|
1593
1619
|
# Remove axes
|
|
1594
1620
|
ax.set_xticks([])
|
|
1595
1621
|
ax.set_yticks([])
|
|
1596
|
-
ax.set_title(f"
|
|
1622
|
+
ax.set_title(f"objects (Found: {len(gdf)})")
|
|
1597
1623
|
|
|
1598
1624
|
# Save if requested
|
|
1599
1625
|
if output_path:
|
|
@@ -1603,21 +1629,21 @@ class BuildingFootprintExtractor:
|
|
|
1603
1629
|
|
|
1604
1630
|
plt.close()
|
|
1605
1631
|
|
|
1606
|
-
# Create a simpler visualization focused just on a subset of
|
|
1632
|
+
# Create a simpler visualization focused just on a subset of objects
|
|
1607
1633
|
if len(gdf) > 0:
|
|
1608
1634
|
plt.figure(figsize=figsize)
|
|
1609
1635
|
ax = plt.gca()
|
|
1610
1636
|
|
|
1611
1637
|
# Choose a subset of the image to show
|
|
1612
1638
|
with rasterio.open(raster_path) as src:
|
|
1613
|
-
# Get centroid of first
|
|
1639
|
+
# Get centroid of first object
|
|
1614
1640
|
sample_geom = gdf.iloc[0].geometry
|
|
1615
1641
|
centroid = sample_geom.centroid
|
|
1616
1642
|
|
|
1617
1643
|
# Convert to pixel coordinates
|
|
1618
1644
|
center_x, center_y = ~src.transform * (centroid.x, centroid.y)
|
|
1619
1645
|
|
|
1620
|
-
# Define a window around this
|
|
1646
|
+
# Define a window around this object
|
|
1621
1647
|
window_size = 500 # pixels
|
|
1622
1648
|
window = rasterio.windows.Window(
|
|
1623
1649
|
max(0, int(center_x - window_size / 2)),
|
|
@@ -1654,7 +1680,7 @@ class BuildingFootprintExtractor:
|
|
|
1654
1680
|
window_bounds = rasterio.windows.bounds(window, src.transform)
|
|
1655
1681
|
window_box = box(*window_bounds)
|
|
1656
1682
|
|
|
1657
|
-
# Filter
|
|
1683
|
+
# Filter objects that intersect with this window
|
|
1658
1684
|
visible_gdf = gdf[gdf.intersects(window_box)]
|
|
1659
1685
|
|
|
1660
1686
|
# Set up colors for sample view if confidence data exists
|
|
@@ -1676,7 +1702,7 @@ class BuildingFootprintExtractor:
|
|
|
1676
1702
|
except Exception as e:
|
|
1677
1703
|
print(f"Error setting up sample confidence visualization: {e}")
|
|
1678
1704
|
|
|
1679
|
-
# Plot
|
|
1705
|
+
# Plot objects in sample view
|
|
1680
1706
|
for idx, row in visible_gdf.iterrows():
|
|
1681
1707
|
try:
|
|
1682
1708
|
# Get window-relative pixel coordinates
|
|
@@ -1751,9 +1777,7 @@ class BuildingFootprintExtractor:
|
|
|
1751
1777
|
print(f"Error plotting polygon in sample view: {e}")
|
|
1752
1778
|
|
|
1753
1779
|
# Set title
|
|
1754
|
-
ax.set_title(
|
|
1755
|
-
f"Sample Area - Building Footprints (Showing: {len(visible_gdf)})"
|
|
1756
|
-
)
|
|
1780
|
+
ax.set_title(f"Sample Area - objects (Showing: {len(visible_gdf)})")
|
|
1757
1781
|
|
|
1758
1782
|
# Remove axes
|
|
1759
1783
|
ax.set_xticks([])
|
|
@@ -1769,3 +1793,111 @@ class BuildingFootprintExtractor:
|
|
|
1769
1793
|
plt.tight_layout()
|
|
1770
1794
|
plt.savefig(sample_output, dpi=300, bbox_inches="tight")
|
|
1771
1795
|
print(f"Sample visualization saved to {sample_output}")
|
|
1796
|
+
|
|
1797
|
+
|
|
1798
|
+
class BuildingFootprintExtractor(ObjectDetector):
    """
    Building footprint extraction using a pre-trained Mask R-CNN model.

    Subclass of `ObjectDetector` that defaults to the building-footprint
    weights file and exposes `regularize_buildings` as a building-specific
    entry point for regularization.
    """

    def __init__(
        self,
        model_path="building_footprints_usa.pth",
        repo_id=None,
        model=None,
        device=None,
    ):
        """
        Set up the extractor by delegating to `ObjectDetector.__init__`.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        # All arguments are passed through unchanged; only the default
        # model_path differs from the parent class.
        init_kwargs = {
            "model_path": model_path,
            "repo_id": repo_id,
            "model": model,
            "device": device,
        }
        super().__init__(**init_kwargs)

    def regularize_buildings(
        self,
        gdf,
        min_area=10,
        angle_threshold=15,
        orthogonality_threshold=0.3,
        rectangularity_threshold=0.7,
    ):
        """
        Regularize building footprints to enforce right angles and rectangular shapes.

        Thin wrapper: every argument is forwarded as-is to
        `self.regularize_objects` and its result is returned unchanged.

        Args:
            gdf: GeoDataFrame with building footprints
            min_area: Minimum area in square units to keep a building
            angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
            orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
            rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification

        Returns:
            GeoDataFrame with regularized building footprints
        """
        options = {
            "min_area": min_area,
            "angle_threshold": angle_threshold,
            "orthogonality_threshold": orthogonality_threshold,
            "rectangularity_threshold": rectangularity_threshold,
        }
        regularized = self.regularize_objects(gdf, **options)
        return regularized
|
|
1854
|
+
|
|
1855
|
+
|
|
1856
|
+
class CarDetector(ObjectDetector):
    """
    Car detection using a pre-trained Mask R-CNN model.

    Subclass of `ObjectDetector` whose only difference is the default
    `model_path`; all behaviour is delegated to the parent class.
    """

    def __init__(
        self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
    ):
        """
        Set up the detector by delegating to `ObjectDetector.__init__`.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        # Arguments are forwarded unchanged to the parent initializer.
        init_kwargs = {
            "model_path": model_path,
            "repo_id": repo_id,
            "model": model,
            "device": device,
        }
        super().__init__(**init_kwargs)
|
|
1879
|
+
|
|
1880
|
+
|
|
1881
|
+
class ShipDetector(ObjectDetector):
    """
    Ship detection using a pre-trained Mask R-CNN model.

    Subclass of `ObjectDetector` whose only difference is the default
    `model_path`; all behaviour is delegated to the parent class.
    """

    def __init__(
        self, model_path="ship_detection.pth", repo_id=None, model=None, device=None
    ):
        """
        Set up the detector by delegating to `ObjectDetector.__init__`.

        Args:
            model_path: Path to the .pth model file.
            repo_id: Repo ID for loading models from the Hub.
            model: Custom model to use for inference.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
        """
        # Arguments are forwarded unchanged to the parent initializer.
        init_kwargs = {
            "model_path": model_path,
            "repo_id": repo_id,
            "model": model,
            "device": device,
        }
        super().__init__(**init_kwargs)
|